This morning’s post was a snapshot at 10,800 lines. Since then, another 23 commits went in — vectorizing the full inference stack, adding gradient estimators, neural network integration, and compiled Metal inference loops. The theme: make GenMLX fast for real workloads.

Vectorized splice

The previous vectorization work had a limitation: splice (sub-generative-function calls) didn’t work in batched mode. Now it does. DynamicGF sub-GFs run under batched handlers, with nested splice (3+ levels) supported. This means hierarchical models vectorize end-to-end.

Vectorized MCMC and SMC

N independent MH chains now run in parallel via MLX broadcasting — enabling parallel tempering and R-hat convergence diagnostics from a single call.

For SMC, the model body runs once per timestep for all N particles. Multi-step batched particle filtering gives 24.5x speedup over sequential SMC (50 particles, 5 steps).

Compiled Metal inference loops

The entire MH accept/reject chain compiles into a single Metal program via mx/compile-fn. Same for the VI optimization loop. No per-step JavaScript-to-Metal dispatch overhead — the full iteration is one cached GPU dispatch.

A benchmark suite comparing GenMLX vs handcoded MLX for importance sampling and HMC at various scales confirms the abstraction overhead is minimal.

Near-complete batch sampling

Native dist-sample-n now covers nearly every distribution. Added discrete-uniform, geometric, categorical, multivariate-normal, binomial, and student-t. Then a vectorized Marsaglia-Tsang implementation for gamma (with Ahrens-Dieter for alpha < 1) unlocked batch sampling for beta, inverse-gamma, and dirichlet. Only poisson remains sequential.

VIMCO and ADEV gradient estimators

VIMCO — multi-sample variational objective with leave-one-out baseline for tighter gradient estimates than single-sample ELBO.

ADEV — sound automatic differentiation of expected values (Lew et al., POPL 2023). Integrates reparameterization and score-function estimators with mx/grad, with a dedicated ADEV handler for tracing through stochastic computation graphs.

Prefix skipping for sequential models

Unfold and Scan now store per-step scores (and carries for Scan) as trace metadata. During update, the system finds the earliest constrained step and skips the unchanged prefix. This turns O(T^2) total work for SMC on time series into O(T).

Custom gradient generative functions

CustomGradientGF wraps user-supplied forward/backward passes. The IHasArgumentGrads protocol indicates which arguments are differentiable, so gradient-based inference avoids unnecessary computation.

Neural networks as generative functions

NeuralNetGF wraps MLX nn.Module as a deterministic generative function with full GFI support. Layer constructors (linear, sequential, relu, gelu, etc.) and training utilities using MLX’s native nn.valueAndGrad.

Building on this, an amortized inference module provides VAE-style reparameterized ELBO training and neural importance sampling — training a neural network guide, then using it as a proposal for importance sampling.

Complete GFI coverage

All 9 combinators now implement the full GFI. Mask got update and regenerate. Contramap and Dimap got update-with-diffs. No more partial implementations.

Running total

Metric Value
Commits today 23
Lines added +3,933
TODO items completed 17 (41 of 66 done)
New source files 5
New test files 9

The code is at github.com/robert-johansson/genmlx.