Skip to content

Commit b620bb1

Browse files
committed
handle rng input
1 parent 9ff1207 commit b620bb1

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

pytensor/link/mlx/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
import pytensor.link.mlx.dispatch.sort
1515
import pytensor.link.mlx.dispatch.slinalg
1616
import pytensor.link.mlx.dispatch.nlinalg
17+
import pytensor.link.mlx.dispatch.random
1718
# isort: on

pytensor/link/mlx/dispatch/core.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,16 @@ def _coerce_to_int(value):
300300
)
301301

302302

303+
def mlx_to_list_shape(size) -> list[int]:
304+
"""Convert a size value (mx.array, np.ndarray, or sequence) to a plain Python list of ints.
305+
306+
Used by random variable dispatch to normalise the ``size`` argument, which
307+
PyTensor may pass as an ``mx.array`` or ``np.ndarray`` rather than a plain
308+
Python list.
309+
"""
310+
return [_coerce_to_int(x) for x in size]
311+
312+
303313
def _rethrow_dynamic_shape_error(exc):
304314
msg = str(exc)
305315
if "[eval] Attempting to eval an array during function transformations" in msg:

pytensor/link/mlx/linker.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,16 @@ def create_thunk_inputs(self, storage_map):
6969
list
7070
The inputs for the thunk
7171
"""
72+
from numpy.random import Generator
73+
74+
from pytensor.link.mlx.dispatch import mlx_typify
75+
7276
thunk_inputs = []
7377
for n in self.fgraph.inputs:
7478
sinput = storage_map[n]
79+
if isinstance(sinput[0], Generator):
80+
# Convert Generator into MLX PRNG key
81+
sinput[0] = mlx_typify(sinput[0])
7582
thunk_inputs.append(sinput)
7683

7784
return thunk_inputs

0 commit comments

Comments
 (0)