Goal: Recreate the DiT paper from “scratch” (only the PyTorch API for now), using only the research paper and other papers online. No online code, GitHub implementations, or LLMs (including Copilot) allowed.
Scalable Diffusion Models with Transformers
My approach to research papers has always been to 1) understand the paper at a high level, then 2) read the code implementing it to get a deep understanding (i.e., if I don’t understand a concept, I look at the code and keep going lower in the hierarchy until I do).
A problem that’s come up recently is that it’s actually harder to find open-source implementations of certain papers as things get more closed-source, plus there’s just something fun about doing it yourself! I hadn’t built a diffusion model from scratch before (unlike language models, graph neural nets, or RL), so I’m taking this as a challenge to break through that abstraction layer and get more mathy!
I’m also planning on implementing this in WGSL on the web, so this whole process should help me understand the fundamentals a lot better.
Check it out above! It’s a ~130-line implementation of DiTs where the code is (hopefully) really readable and a good learning resource for people! It took a couple of after-work sessions to get everything working, but I think I came out understanding a lot more about diffusion than I did before.
I could probably get the line count down a lot, but readability is more important to me.
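For context on what those lines have to cover, here is a rough summary of the two core DiT pieces described in the paper: patchify, which turns an I×I×C latent into a sequence of T tokens, and the adaLN-Zero block, which conditions every transformer block on the timestep and class. This is a paraphrase in simplified notation (e_t and e_y for the timestep and class embeddings, FFN for the pointwise feedforward), not a transcription of the paper's equations.

```latex
% Patchify: an I x I x C latent with patch size p becomes T tokens
T = \left(\frac{I}{p}\right)^2, \qquad \text{e.g. } I = 32,\; p = 2 \;\Rightarrow\; T = 256

% adaLN-Zero: regress per-dimension parameters from c = e_t + e_y
(\gamma_1, \beta_1, \alpha_1, \gamma_2, \beta_2, \alpha_2) = \mathrm{MLP}(c)
h = x + \alpha_1 \odot \mathrm{Attn}\big(\gamma_1 \odot \mathrm{LN}(x) + \beta_1\big)
y = h + \alpha_2 \odot \mathrm{FFN}\big(\gamma_2 \odot \mathrm{LN}(h) + \beta_2\big)
```

Halving p quadruples T. Because the MLP is initialized so every α is zero, each block starts out as the identity function; in the paper this variant gave the best FID among the conditioning schemes it compared.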
TODOs:
Playing around with noise and generation gave me what I think is an interesting way to think about diffusion: treating diffusion models as “guides” in latent space instead of as next-noise predictors.
This way, you can manipulate what the model outputs at sampling time, using it as a vibe measure of what the real image should look like (e.g., you can use negative controls, or even think of sampling as an equation that slowly traverses the gradients to reach the correct answer!).
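One concrete instance of the “guide” framing is classifier-free guidance, which the DiT paper uses at sampling time: the model predicts noise once with the class label and once with a learned null label, and the prediction is pushed along the difference. Below is a minimal pure-Python sketch where lists stand in for tensors. Treating a “negative control” as swapping the null-label prediction for a prediction made with an unwanted label is an assumption (it is the common negative-prompt trick, not necessarily what was done here), and `guided_prediction`, `negative_guided_prediction`, and the toy numbers are hypothetical.

```python
from typing import List

Vector = List[float]


def guided_prediction(cond: Vector, uncond: Vector, scale: float) -> Vector:
    """Classifier-free-guidance-style combination of two noise predictions.

    guided = uncond + scale * (cond - uncond)
    scale = 1 gives back `cond`; scale > 1 exaggerates how `cond` differs
    from the anchor, i.e. it pushes the sample further in that direction.
    """
    if len(cond) != len(uncond):
        raise ValueError("predictions must have the same length")
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


def negative_guided_prediction(cond: Vector, negative: Vector, scale: float) -> Vector:
    """Same formula, but the anchor we push away from is a prediction made
    with an unwanted ("negative") label instead of the null label.
    Hypothetical helper illustrating the negative-prompt trick."""
    return guided_prediction(cond, negative, scale)


# Toy stand-ins for flattened noise predictions eps(x_t, t, y).
eps_cond = [0.2, -0.4, 1.0]     # hypothetical model(x_t, t, y="hippo")
eps_uncond = [0.0, 0.0, 0.5]    # hypothetical model(x_t, t, y=null)
eps_negative = [0.4, 0.4, 1.0]  # hypothetical model(x_t, t, y="blurry")

print(guided_prediction(eps_cond, eps_uncond, 4.0))
print(negative_guided_prediction(eps_cond, eps_negative, 4.0))
```

The model is never retrained here; only how its outputs are combined at sampling time changes, which is what makes it feel like steering rather than prediction.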
Once you think of it that way, you can cross-apply a lot of ideas from backpropagation and navigating gradients (e.g., a learning rate for how much to update the image at each step). The sample above uses a mini technique where I tried “batch estimating” gradients; I’m experimenting a lot here!
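To make the learning-rate analogy concrete, here is a toy 1-D sketch. It assumes the simpler variance-exploding noising x = x0 + sigma * eps rather than DiT's DDPM schedule; under that assumption the gradient of the log-density of noisy data is -E[eps | x] / sigma, so a noise prediction doubles as a scaled gradient and each sampling step is a gradient-ascent step with step size `lr`. The "model" `toy_eps_model` is a hypothetical closed-form predictor for Gaussian data with mean 3, not a trained DiT, and the step-size schedule is an arbitrary illustrative choice.

```python
from typing import Callable, Sequence

EpsModel = Callable[[float, float], float]


def toy_eps_model(x: float, sigma: float, mu: float = 3.0, s: float = 0.5) -> float:
    """Hypothetical stand-in for a trained noise predictor eps_hat(x, sigma).

    If clean data were 1-D Gaussian N(mu, s^2) and x = x0 + sigma * eps,
    the best possible prediction E[eps | x] has this closed form.
    """
    return sigma * (x - mu) / (s ** 2 + sigma ** 2)


def guide_step(x: float, sigma: float, lr: float, eps_model: EpsModel) -> float:
    """One gradient-ascent step on log p_sigma(x).

    Under x = x0 + sigma * eps, grad log p_sigma(x) = -E[eps | x] / sigma.
    """
    score = -eps_model(x, sigma) / sigma
    return x + lr * score


def sample(x: float, sigmas: Sequence[float], lr_scale: float, eps_model: EpsModel) -> float:
    """Walk from high noise to low noise, taking one guided step per level."""
    for sigma in sigmas:
        # Step size shrinks with the noise level, like a decaying learning rate.
        lr = lr_scale * sigma ** 2
        x = guide_step(x, sigma, lr, eps_model)
    return x


# 50 noise levels spaced geometrically from 10 down to 0.01.
sigmas = [10.0 * (0.01 / 10.0) ** (i / 49) for i in range(50)]
x_final = sample(-5.0, sigmas, lr_scale=0.5, eps_model=toy_eps_model)
print(f"started at -5.0, ended at {x_final:.4f}")  # lands very close to mu = 3.0
```

Because this loop is fully deterministic, it slides to the most likely point (mu) instead of producing varied samples; samplers like DDPM also re-inject noise at each step, which is what keeps outputs diverse.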
<aside> 🔥 pt. 2 coming soon…
*I’m about 30% of the way through porting all these layers to WebGPU WGSL shaders/kernels, and I’m planning on having hippo favicon diffusion on my website soon!!
My reasoning for publishing pt. 1 and pt. 2 separately is that I wanted to give the research-y side its own post; me learning WGSL/shaders outside of WebGL would crowd out the rest.*
</aside>