File tree Expand file tree Collapse file tree 3 files changed +18
-0
lines changed
Expand file tree Collapse file tree 3 files changed +18
-0
lines changed Original file line number Diff line number Diff line change 1414import pytensor .link .mlx .dispatch .sort
1515import pytensor .link .mlx .dispatch .slinalg
1616import pytensor .link .mlx .dispatch .nlinalg
17+ import pytensor .link .mlx .dispatch .random
1718# isort: on
Original file line number Diff line number Diff line change @@ -300,6 +300,16 @@ def _coerce_to_int(value):
300300 )
301301
302302
303+ def mlx_to_list_shape (size ) -> list [int ]:
304+ """Convert a size value (mx.array, np.ndarray, or sequence) to a plain Python list of ints.
305+
306+ Used by random variable dispatch to normalise the ``size`` argument, which
307+ PyTensor may pass as an ``mx.array`` or ``np.ndarray`` rather than a plain
308+ Python list.
309+ """
310+ return [_coerce_to_int (x ) for x in size ]
311+
312+
303313def _rethrow_dynamic_shape_error (exc ):
304314 msg = str (exc )
305315 if "[eval] Attempting to eval an array during function transformations" in msg :
Original file line number Diff line number Diff line change @@ -69,9 +69,16 @@ def create_thunk_inputs(self, storage_map):
6969 list
7070 The inputs for the thunk
7171 """
72+ from numpy .random import Generator
73+
74+ from pytensor .link .mlx .dispatch import mlx_typify
75+
7276 thunk_inputs = []
7377 for n in self .fgraph .inputs :
7478 sinput = storage_map [n ]
79+ if isinstance (sinput [0 ], Generator ):
80+ # Convert Generator into MLX PRNG key
81+ sinput [0 ] = mlx_typify (sinput [0 ])
7582 thunk_inputs .append (sinput )
7683
7784 return thunk_inputs
You can’t perform that action at this time.
0 commit comments