I built GenMLX — a probabilistic programming language in ClojureScript implementing Gen’s Generative Function Interface, running on Apple Silicon GPUs via MLX.

Gen implementations exist for Julia and JAX. Now there’s one for Apple Silicon. The whole thing is about 2000 lines of ClojureScript.

What is the Generative Function Interface?

The GFI is the abstraction at the core of Gen. A generative function is a program that makes random choices. The interface defines four operations on these programs:

;; Forward sample — run the program, record everything
(p/simulate model args)

;; Constrained execution — force certain choices, get a weight
(p/generate model args constraints)

;; Update — change some choices in an existing execution
(p/update model trace new-constraints)

;; Resample — propose new values for selected choices
(p/regenerate model trace selection)

Every inference algorithm is built from these four operations. That’s the power of the design — you write a model once, and importance sampling, MH, HMC, SMC, and VI all just work.
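
As a sketch of how that composition works, self-normalized importance sampling needs nothing beyond generate. The helper below is illustrative rather than the actual GenMLX API: it assumes generate returns a map with :trace and :weight, and that the weight is an ordinary log-space number (in GenMLX it may be an MLX scalar that needs mx/item).

(defn importance-sampling
  "Sketch: self-normalized importance sampling built only from `generate`.
  Returns traces paired with their normalized importance weights."
  [model args observations n]
  (let [results (vec (repeatedly n #(p/generate model args observations)))
        log-ws  (mapv :weight results)
        max-w   (apply max log-ws)
        ws      (mapv #(Math/exp (- % max-w)) log-ws) ; stabilize before exp
        total   (reduce + ws)]
    (map (fn [r w] [(:trace r) (/ w total)]) results ws)))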

Why MLX?

MLX’s unified memory model turns out to be a natural fit for probabilistic programming. MCMC control flow runs on CPU while all the numerics — scores, gradients, leapfrog steps — stay on GPU. There’s zero data transfer cost because CPU and GPU share the same memory.
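
Concretely, a Metropolis-Hastings accept step can keep the score arithmetic on the GPU and only read back a single scalar for the accept/reject decision. The sketch below assumes helper names (tr/get-score, mx/subtract) and a symmetric proposal; it illustrates the pattern, not GenMLX's internal code.

;; Sketch of one MH accept/reject step with a symmetric proposal.
;; tr/get-score and mx/subtract are assumed names. The log-acceptance
;; ratio stays an MLX array until the final scalar comparison on the CPU.
(defn mh-accept [current-trace proposed-trace]
  (let [log-alpha (mx/subtract (tr/get-score proposed-trace)
                               (tr/get-score current-trace))]
    (mx/eval! log-alpha)
    (if (< (Math/log (rand)) (mx/item log-alpha))
      proposed-trace
      current-trace)))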

ClojureScript on Node.js gives direct access to MLX through node-mlx with ~20ns call overhead, and nbb provides a fast REPL for interactive model development.

A model

Bayesian linear regression in GenMLX:

(ns my-model
  (:require [genmlx.mlx :as mx]
            [genmlx.dist :as dist]
            [genmlx.dynamic :as dyn]
            [genmlx.inference.mcmc :as mcmc]
            [genmlx.choicemap :as cm]
            [genmlx.selection :as sel]
            [genmlx.trace :as tr])
  (:require-macros [genmlx.gen :refer [gen]]))

(def model
  (gen [xs]
    (let [slope     (dyn/trace :slope (dist/gaussian 0 10))
          intercept (dyn/trace :intercept (dist/gaussian 0 10))]
      (mx/eval! slope intercept)
      (let [s (mx/item slope) i (mx/item intercept)]
        (doseq [[j x] (map-indexed vector xs)]
          (dyn/trace (keyword (str "y" j))
                     (dist/gaussian (+ (* s x) i) 1)))
        [s i]))))

Each dyn/trace call names a random choice. The model defines a joint distribution over the slope, the intercept, and the observations. Given observed y-values, inference recovers the posterior over slope and intercept.
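
A forward run makes that structure visible. Here p stands for the GFI protocols namespace from the interface snippet above, and the printed result is only indicative:

;; Forward-sample the model and inspect its recorded choices.
(def t (p/simulate model [[1.0 2.0 3.0]]))

(tr/get-choices t)
;; => choicemap with :slope, :intercept, :y0, :y1, :y2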

Inference

(def xs [1.0 2.0 3.0 4.0 5.0])
(def observations
  (reduce (fn [cm [j y]]
            (cm/set-choice cm [(keyword (str "y" j))] (mx/scalar y)))
          cm/EMPTY
          (map-indexed vector [2.1 3.9 6.2 7.8 10.1])))

;; Metropolis-Hastings
(def traces
  (mcmc/mh {:samples 500 :burn 100
            :selection (sel/select :slope :intercept)}
           model [xs] observations))

(let [slopes (mapv (fn [t]
                     (let [v (cm/get-value
                               (cm/get-submap (tr/get-choices t) :slope))]
                       (mx/eval! v) (mx/item v)))
                   traces)]
  (println "Posterior slope mean:"
           (/ (reduce + slopes) (count slopes))))
;; => ~2.0

The same model works with any inference algorithm: swap mcmc/mh for mcmc/hmc, mcmc/nuts, smc/smc, or vi/vi without changing the model.
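
For example, a gradient-based run keeps the same call shape. The option keys for step size and leapfrog count below are assumptions about the options map, so check the actual mcmc/hmc signature:

;; Hypothetical HMC call with the same model, arguments, and observations.
;; :step-size and :leapfrog-steps are assumed option names.
(def hmc-traces
  (mcmc/hmc {:samples 500 :burn 100
             :selection (sel/select :slope :intercept)
             :step-size 0.05 :leapfrog-steps 20}
            model [xs] observations))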

The full stack

Layer 0: MLX Foundation     — tensors, autograd, GPU
Layer 1: Core Data Types    — choicemaps, traces, selections
Layer 2: GFI Protocols      — simulate, generate, update, regenerate
Layer 3: DSL                — gen macro, dynamic tracing
Layer 4: Distributions      — 13 distributions (Gaussian, Beta, MVN, ...)
Layer 5: Combinators        — Map, Unfold, Switch
Layer 6: Inference          — IS, MH, MALA, HMC, NUTS, SMC, VI

Seven inference algorithms, thirteen distributions, three combinators. A suite of 165 compatibility tests adapted from Gen.clj verifies that GenMLX produces matching results.
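
To give a flavor of the combinator layer, the per-observation loop in the regression model could in principle be written as a Map of a one-observation kernel. The constructor name comb/map-gf below is an assumption, not the actual GenMLX API:

;; Hypothetical Map-combinator sketch; comb/map-gf is an assumed constructor.
;; The kernel models a single observation given plain-number slope/intercept,
;; and Map broadcasts it across a vector of x values.
(def obs-kernel
  (gen [x slope intercept]
    (dyn/trace :y (dist/gaussian (+ (* slope x) intercept) 1))))

(def obs-model (comb/map-gf obs-kernel))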

What makes this different from prob-cljs + MLX

Earlier today I wrote about adding MLX as an acceleration backend to prob-cljs. That gives you autograd and gradient-based inference, but the programming model is manual — you write log-posterior functions by hand.

GenMLX is a different thing: a full probabilistic programming language where you write generative models and the system handles trace management, score computation, and the GFI operations that inference algorithms are built on. The model is the program; inference is an operation on the program.

The code is at github.com/robert-johansson/genmlx.