Skip to content

Commit e6f7371

Browse files
committed
address review: shared rng dance in linker, xor-fold 128-bit pcg64 seed
1 parent d602268 commit e6f7371

File tree

3 files changed

+146
-8
lines changed

pytensor/link/mlx/dispatch/random.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,18 @@
88
from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx, mlx_to_list_shape
99

1010

11-
def _truncate_pcg64_state_to_uint64(rng: Generator) -> int:
12-
return int(rng.bit_generator.state["state"]["state"]) & 0xFFFFFFFFFFFFFFFF
13-
14-
1511
def numpy_generator_to_mlx_key(rng: Generator) -> mx.array:
1612
"""Convert a NumPy Generator to an MLX random key.
1713
1814
MLX uses a functional RNG model where each random call takes an explicit
19-
key rather than mutating shared state. This extracts the lower 64 bits of
20-
the PCG64 state integer as a seed for the MLX key.
15+
key rather than mutating shared state. The PCG64 state is 128 bits, which
16+
MLX cannot accept directly. We fold both 64-bit halves together via XOR
17+
to use all 128 bits of entropy in a single 64-bit seed.
2118
"""
22-
return mx.random.key(_truncate_pcg64_state_to_uint64(rng))
19+
state_128 = int(rng.bit_generator.state["state"]["state"])
20+
upper = (state_128 >> 64) & 0xFFFFFFFFFFFFFFFF
21+
lower = state_128 & 0xFFFFFFFFFFFFFFFF
22+
return mx.random.key(upper ^ lower)
2323

2424

2525
@mlx_typify.register(Generator)

pytensor/link/mlx/linker.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import warnings
2+
3+
from pytensor.compile.sharedvalue import SharedVariable, shared
14
from pytensor.link.basic import JITLinker
25

36

@@ -17,7 +20,7 @@ def __init__(self, use_compile=True, *args, **kwargs):
1720
self.gen_functors = []
1821
self.use_compile = use_compile
1922

20-
def fgraph_convert(self, fgraph, **kwargs):
23+
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
2124
"""Convert a PyTensor FunctionGraph to an MLX-compatible function.
2225
2326
Parameters
@@ -31,9 +34,63 @@ def fgraph_convert(self, fgraph, **kwargs):
3134
An MLX-compatible function
3235
"""
3336
from pytensor.link.mlx.dispatch import mlx_funcify
37+
from pytensor.tensor.random.type import RandomType
38+
39+
shared_rng_inputs = [
40+
inp
41+
for inp in fgraph.inputs
42+
if (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType))
43+
]
44+
45+
# Replace any shared RNG inputs so that their values can be updated in place
46+
# without affecting the original RNG container. This is necessary because
47+
# MLX does not accept Generators as inputs, and they will have to
48+
# be typified
49+
if shared_rng_inputs:
50+
warnings.warn(
51+
f"The RandomType SharedVariables {shared_rng_inputs} will not be used "
52+
f"in the compiled MLX graph. Instead a copy will be used.",
53+
UserWarning,
54+
)
55+
new_shared_rng_inputs = [
56+
shared(inp.get_value(borrow=False)) for inp in shared_rng_inputs
57+
]
58+
59+
fgraph.replace_all(
60+
zip(shared_rng_inputs, new_shared_rng_inputs, strict=True),
61+
import_missing=True,
62+
reason="MLXLinker.fgraph_convert",
63+
)
64+
65+
for old_inp, new_inp in zip(
66+
shared_rng_inputs, new_shared_rng_inputs, strict=True
67+
):
68+
new_inp_storage = [new_inp.get_value(borrow=True)]
69+
storage_map[new_inp] = new_inp_storage
70+
old_inp_storage = storage_map.pop(old_inp)
71+
# Find index of old_inp_storage in input_storage
72+
for input_storage_idx, input_storage_item in enumerate(input_storage):
73+
# We have to establish equality based on identity because input_storage may contain numpy arrays
74+
if input_storage_item is old_inp_storage:
75+
break
76+
else: # no break
77+
raise ValueError()
78+
input_storage[input_storage_idx] = new_inp_storage
79+
# We need to change the order of the inputs of the FunctionGraph
80+
# so that the new input is in the same position as the old one,
81+
# to align with the storage_map. We hope this is safe!
82+
old_inp_fgraph_index = fgraph.inputs.index(old_inp)
83+
fgraph.remove_input(
84+
old_inp_fgraph_index,
85+
reason="MLXLinker.fgraph_convert",
86+
)
87+
fgraph.inputs.remove(new_inp)
88+
fgraph.inputs.insert(old_inp_fgraph_index, new_inp)
3489

3590
return mlx_funcify(
3691
fgraph,
92+
input_storage=input_storage,
93+
storage_map=storage_map,
3794
**kwargs,
3895
)
3996

tests/link/mlx/test_random.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
import pytensor
55
import pytensor.tensor as pt
6+
from pytensor.compile.function import function
67
from pytensor.compile.mode import MLX, Mode
8+
from pytensor.compile.sharedvalue import shared
79
from pytensor.link.mlx.linker import MLXLinker
810
from pytensor.tensor.random.utils import RandomStream
911

@@ -173,3 +175,82 @@ def test_beta_not_implemented():
173175
rv = srng.beta(alpha=2.0, beta=5.0, size=(3,))
174176
with pytest.raises(NotImplementedError, match="No MLX implementation"):
175177
pytensor.function([], rv, mode="MLX", updates=srng.updates())
178+
179+
180+
def compile_shared_rng_function(*args, mode="MLX", **kwargs):
181+
with pytest.warns(
182+
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
183+
):
184+
return function(*args, mode=mode, **kwargs)
185+
186+
187+
def test_random_updates():
188+
original_value = np.random.default_rng(seed=98)
189+
rng = shared(original_value, name="original_rng", borrow=False)
190+
next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs
191+
192+
f = compile_shared_rng_function([], [x], updates={rng: next_rng})
193+
assert f() != f()
194+
195+
# Check that the original shared variable was not overwritten when typifying
196+
assert all(
197+
a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b)
198+
for a, b in zip(
199+
rng.get_value().bit_generator.state,
200+
original_value.bit_generator.state,
201+
strict=True,
202+
)
203+
)
204+
205+
206+
@pytest.mark.parametrize("noise_first", (False, True))
207+
def test_replaced_shared_rng_storage_order(noise_first):
208+
# Test that replacing the RNG variable in the linker does not cause
209+
# a misalignment between the compiled graph and the storage_map.
210+
211+
mu = pytensor.shared(np.array(1.0), name="mu")
212+
rng = pytensor.shared(np.random.default_rng(123))
213+
next_rng, noise = pt.random.normal(rng=rng).owner.outputs
214+
215+
out = noise * mu if noise_first else mu * noise
216+
217+
updates = {
218+
mu: pt.grad(out, mu),
219+
rng: next_rng,
220+
}
221+
f = compile_shared_rng_function([], [out], updates=updates)
222+
223+
# Confirm that input_storage type and fgraph input order are aligned
224+
for storage, fgraph_input in zip(
225+
f.input_storage, f.maker.fgraph.inputs, strict=True
226+
):
227+
assert storage.type == fgraph_input.type
228+
229+
assert mu.get_value() == 1
230+
f()
231+
assert mu.get_value() != 1
232+
233+
234+
def test_replaced_shared_rng_storage_ordering_equality():
235+
"""Test that storage identity comparison works when numpy arrays precede
236+
the RNG in input_storage (regression test for issue #314)."""
237+
pt_rng = RandomStream(1)
238+
239+
batchshape = (3, 1, 4, 4)
240+
inp_shared = pytensor.shared(
241+
np.zeros(batchshape, dtype="float64"), name="inp_shared"
242+
)
243+
244+
inp = pt.tensor4(dtype="float64", name="inp")
245+
inp_update = inp + pt_rng.normal(size=inp.shape, loc=5, scale=1e-5)
246+
247+
fn = compile_shared_rng_function(
248+
inputs=[],
249+
outputs=[],
250+
updates={inp_shared: inp_update},
251+
givens={inp: inp_shared},
252+
)
253+
fn()
254+
np.testing.assert_allclose(np.array(inp_shared.get_value()), 5, rtol=1e-2)
255+
fn()
256+
np.testing.assert_allclose(np.array(inp_shared.get_value()), 10, rtol=1e-2)

0 commit comments

Comments
 (0)