Skip to content

Commit d11c8dc

Browse files
authored
Fix: Tracer leaks during NUTS sampling on random_nnx_module when NNX module contains mutable state (#2162)
* fix * rm skip
1 parent 3f9cdea commit d11c8dc

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

numpyro/handlers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,8 @@ def process_message(self, msg: Message) -> None:
971971
value = self.substitute_fn(msg)
972972

973973
if value is not None:
974+
if msg["type"] == "mutable" and isinstance(value, dict):
975+
value = value.copy()
974976
msg["value"] = value
975977

976978

test/contrib/test_module.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,39 @@ def model(data, labels=None):
435435
assert f"nn{scope_divider}layers.1.bias" in samples
436436

437437

438+
@pytest.mark.parametrize("use_deterministic", [True, False])
439+
def test_random_nnx_module_mcmc_with_mutable_state(use_deterministic):
440+
from flax import nnx
441+
442+
class NNXModel(nnx.Module):
443+
def __init__(self):
444+
self.linear = nnx.Linear(10, 1, rngs=nnx.Rngs(0))
445+
self.mutable = nnx.Variable(0)
446+
447+
def __call__(self, x):
448+
return self.linear(x)
449+
450+
nn_module = NNXModel()
451+
452+
def model(x, y=None):
453+
random_model = random_nnx_module("model", nn_module, dist.Normal(0, 1))
454+
pred = random_model(x)
455+
with numpyro.plate("plate", size=x.shape[0]):
456+
if use_deterministic:
457+
pred = numpyro.deterministic("pred", pred)
458+
numpyro.sample("obs", dist.Normal(pred, 1.0).to_event(1), obs=y)
459+
460+
x = jax.random.uniform(jax.random.key(0), shape=(10, 10))
461+
y = jax.random.uniform(jax.random.key(0), shape=(10, 1))
462+
463+
mcmc = MCMC(NUTS(model), num_warmup=5, num_samples=5, progress_bar=False)
464+
with jax.check_tracer_leaks(True):
465+
mcmc.run(jax.random.key(0), x, y)
466+
samples = mcmc.get_samples()
467+
assert "model/linear.kernel" in samples
468+
assert "model/linear.bias" in samples
469+
470+
438471
@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
439472
def test_eqx_module():
440473
import equinox as eqx

0 commit comments

Comments
 (0)