@@ -435,6 +435,39 @@ def model(data, labels=None):
435435 assert f"nn{ scope_divider } layers.1.bias" in samples
436436
437437
438+ @pytest .mark .parametrize ("use_deterministic" , [True , False ])
439+ def test_random_nnx_module_mcmc_with_mutable_state (use_deterministic ):
440+ from flax import nnx
441+
442+ class NNXModel (nnx .Module ):
443+ def __init__ (self ):
444+ self .linear = nnx .Linear (10 , 1 , rngs = nnx .Rngs (0 ))
445+ self .mutable = nnx .Variable (0 )
446+
447+ def __call__ (self , x ):
448+ return self .linear (x )
449+
450+ nn_module = NNXModel ()
451+
452+ def model (x , y = None ):
453+ random_model = random_nnx_module ("model" , nn_module , dist .Normal (0 , 1 ))
454+ pred = random_model (x )
455+ with numpyro .plate ("plate" , size = x .shape [0 ]):
456+ if use_deterministic :
457+ pred = numpyro .deterministic ("pred" , pred )
458+ numpyro .sample ("obs" , dist .Normal (pred , 1.0 ).to_event (1 ), obs = y )
459+
460+ x = jax .random .uniform (jax .random .key (0 ), shape = (10 , 10 ))
461+ y = jax .random .uniform (jax .random .key (0 ), shape = (10 , 1 ))
462+
463+ mcmc = MCMC (NUTS (model ), num_warmup = 5 , num_samples = 5 , progress_bar = False )
464+ with jax .check_tracer_leaks (True ):
465+ mcmc .run (jax .random .key (0 ), x , y )
466+ samples = mcmc .get_samples ()
467+ assert "model/linear.kernel" in samples
468+ assert "model/linear.bias" in samples
469+
470+
438471@pytest .mark .skipif (sys .version_info [:2 ] == (3 , 9 ), reason = "Skipping on Python 3.9" )
439472def test_eqx_module ():
440473 import equinox as eqx
0 commit comments