Nicholas M. Boffi
nmboffi.bsky.social
Building generative models for high-dimensional science and engineering.

Assistant prof. @CarnegieMellon & affiliated faculty @mldcmu, previously instructor @NYU_Courant, PhD jointly @Harvard and @MIT

https://nmboffi.github.io
thanks for this -- it's a very useful resource. it occupies a slightly different niche than the one i'm looking for, though. what if you want to train on CIFAR-10, ImageNet 64x64, etc., so the latent space isn't needed? are there any standard UNet implementations around in jax, like huggingface diffusers?
February 18, 2025 at 2:18 PM
PS: i was on Vikas's team at brain back in the day, so nice to interact with you @froystig.bsky.social! (7/7)
February 5, 2025 at 3:15 PM
but certainly interested in tpus too.

anyways, thanks for listening! i really think it would help the ai + scientific computing community to have standard models implemented in jax. jax is just so much easier for almost all tasks except actually using SoTA models, in my opinion. (6/n)
February 5, 2025 at 3:15 PM
note that these comments on graph networks also apply to applications of generative models to molecular sciences, which is now a huge area.

i mostly use gpus, but have played with tpus through google research cloud. i'd probably prioritize gpus (they're mostly what's available in academia) (5/n)
February 5, 2025 at 3:14 PM
applications of diffusion-based techniques to particle systems in statistical physics and quantum mechanics, where graph networks are very natural. jraph is not nearly as mature as pytorch geometric, which makes it hard to use jax. otherwise, jax is clearly much better for the entire workflow. (4/n)
February 5, 2025 at 3:13 PM
this meant we couldn't really use jax even for cifar-10, and had to switch to pytorch. personally i release all my code in jax (nmboffi.github.io/code/) so this was a big disappointment to me.

another example is graph neural networks. there's jraph, but it's very "roll your own". i work on (3/n)
February 5, 2025 at 3:12 PM
or the huggingface diffusers implementation of a u-net. these are often used in large-scale experiments by people at meta, in some of yang song's works, etc. when i was playing with the huggingface jax implementation of these models last year, the same model with the same parameter settings would perform worse than the pytorch version. (2/n)
February 5, 2025 at 3:11 PM
i work mostly on generative models and scientific applications right now, so for me, good implementations of things like u-nets, transformers, etc. would help a lot. for example, if you want to train a diffusion model in pytorch, there's the lucidrains implementation, (1/n)
February 5, 2025 at 3:10 PM
i absolutely love jax, but an issue that often comes up in practice is the lack of available pre-implemented models (SoTA UNets, transformers, etc).

is there any plan for google to release a package with model implementations? their absence seems to be the dominant issue for scaling jax in research
February 5, 2025 at 12:49 AM
it's a bummer because i much prefer jax, but writing your own everything is not always a viable option
February 3, 2025 at 1:48 PM
very nice post!
February 1, 2025 at 3:37 PM
here the l_2 norm picks up a \sqrt{d} dependence while the infinity norm does not. this is where choosing the mirror map to be the entropy comes into play, recovering exponential weights methods
January 29, 2025 at 1:16 PM
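to make the norm counting in that post explicit (a hedged sketch; the O(1) per-coordinate gradient bound is illustrative, not from the thread): if each gradient component is bounded, $|g_i| \le 1$, then
\[
\|g\|_2 = \Big(\textstyle\sum_{i=1}^d g_i^2\Big)^{1/2} \le \sqrt{d}, \qquad \|g\|_\infty = \max_i |g_i| \le 1,
\]
so a euclidean analysis pays a factor of $\sqrt{d}$ while an $\ell_\infty$-based one does not. the entropy mirror map is strongly convex with respect to $\ell_1$, whose dual norm is exactly $\ell_\infty$, and its contribution over the simplex scales only like $\log d$.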
the point is that it doesn't change the 1/t convergence rate of gradient descent but it can change the dimensionality dependence. the canonical example is if you take a problem where the gradients tend to be the same in each component (such as problems over the probability simplex)
January 29, 2025 at 1:16 PM
isn't the basic idea that you want the mirror map (the potential generating the bregman divergence) to be strongly convex with respect to a norm \Vert\cdot\Vert such that the gradients are bounded in the corresponding dual norm? see below for a slide from my phd thesis -- this is covered in the cited textbook by Nemirovsky and Yudin
January 29, 2025 at 1:14 PM
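the exponential-weights method recovered by the entropy mirror map can be sketched in a few lines (a minimal illustration of the standard update, not code from the thread; the step size and toy gradient are made up):

```python
import numpy as np

def exponentiated_gradient_step(w, grad, eta):
    """mirror-descent step on the probability simplex with the entropy
    mirror map: the multiplicative / exponential-weights update
    w_i <- w_i * exp(-eta * grad_i), followed by renormalization."""
    w_new = w * np.exp(-eta * grad)
    return w_new / w_new.sum()

# toy example: start from the uniform distribution and take one step
w = np.ones(4) / 4
g = np.array([1.0, 0.5, 0.5, 1.0])  # hypothetical gradient, bounded in l_inf
w = exponentiated_gradient_step(w, g, eta=0.1)
# the iterate stays on the simplex, and mass shifts toward the
# coordinates with smaller gradient
```

note that the update only ever touches the gradient componentwise, which is why the analysis depends on \Vert g\Vert_\infty rather than \Vert g\Vert_2.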
how do these techniques, which approach the HJB equation directly, compare to more traditional RL algorithms? can they be used in RL pipelines, such as for problems in robotics? can RL algorithms be flipped on their head and used to solve classes of high-d PDEs? (3/n)
January 16, 2025 at 2:07 PM
we know high-d HJB equations characterize solutions to optimal control problems, and this is precisely what RL aims to solve. so, RL must be implicitly approximating the solution to a high-dimensional HJB equation. (2/n)
January 16, 2025 at 2:06 PM
in the spirit of using this platform for scientific discussion, i'll post a question i've been wondering about that may or may not be very well formulated

techniques like the above method can be used, in principle, to solve high-dimensional HJB equations (1/n)
January 16, 2025 at 2:05 PM
mark is a hero -- extremely kind, unpretentious individual too.
January 15, 2025 at 6:44 PM
this was what i thought as well; morally speaking, you could say it's learning surrogate models for the "ground-truth" MD sampler?

i've read your recent ITO papers to try to learn more about this: what's the right assumption on the data? do we have it, or do we need to sample given U but with no data?
January 10, 2025 at 1:50 PM