Skip to content

Commit 63763d9

Browse files
committed
add currently supported random module
1 parent 7858f47 commit 63763d9

File tree

1 file changed

+188
-0
lines changed

1 file changed

+188
-0
lines changed
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
from functools import singledispatch
2+
3+
import mlx.core as mx
4+
from numpy.random import Generator
5+
6+
import pytensor.tensor.random.basic as ptr
7+
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify
8+
from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx, mlx_to_list_shape
9+
10+
11+
def _truncate_pcg64_state_to_uint64(rng: Generator) -> int:
12+
return int(rng.bit_generator.state["state"]["state"]) & 0xFFFFFFFFFFFFFFFF
13+
14+
15+
def numpy_generator_to_mlx_key(rng: Generator) -> mx.array:
    """Build an MLX random key seeded from a NumPy ``Generator``.

    MLX's RNG is functional: every sampling call consumes an explicit key
    instead of advancing hidden shared state. The key is seeded with the
    low 64 bits of the generator's PCG64 state integer.
    """
    seed = _truncate_pcg64_state_to_uint64(rng)
    return mx.random.key(seed)
23+
24+
25+
@mlx_typify.register(Generator)
def mlx_typify_Generator(rng, **kwargs):
    """Typify a NumPy ``Generator`` by converting it to an MLX random key."""
    key = numpy_generator_to_mlx_key(rng)
    return key
28+
29+
30+
@mlx_funcify.register(ptr.RandomVariable)
def mlx_funcify_RandomVariable(op, node, **kwargs):
    """Compile a PyTensor ``RandomVariable`` node into an MLX sampling function.

    The returned callable follows PyTensor's RNG convention: it takes the
    current key plus the op's parameters and returns ``(next_key, draws)``
    so that successive calls never reuse a key.
    """
    draws_var = node.outputs[1]
    draws_dtype = draws_var.type.dtype

    inner_sample = mlx_sample_fn(op, node)

    def sample_fn(rng, size, *parameters):
        # Split the incoming key: one half is carried forward as the new
        # RNG state, the other is consumed by this draw.
        keys = mx.random.split(rng, num=2)
        next_key = keys[0]
        draw_key = keys[1]
        draws = inner_sample(draw_key, size, draws_dtype, *parameters)
        return (next_key, draws)

    return sample_fn
45+
46+
47+
@singledispatch
def mlx_sample_fn(op, node):
    """Return an MLX sampler for ``op``; dispatch target per RV type.

    Raises
    ------
    NotImplementedError
        When no MLX sampler has been registered for the op's type.
    """
    message = f"No MLX implementation for the given distribution: {op.name}"
    raise NotImplementedError(message)
52+
53+
54+
@mlx_sample_fn.register(ptr.NormalRV)
def mlx_sample_fn_normal(op, node):
    """Sampler for the normal distribution via a location-scale transform."""

    def sample_fn(rng_key, size, dtype, mu, sigma):
        mlx_dtype = convert_dtype_to_mlx(dtype)
        loc = mx.array(mu, dtype=mlx_dtype)
        scale = mx.array(sigma, dtype=mlx_dtype)
        # With no explicit size, the draw shape is the broadcast of the params.
        if size is not None:
            shape = mlx_to_list_shape(size)
        else:
            shape = mx.broadcast_arrays(loc, scale)[0].shape
        standard = mx.random.normal(shape=shape, dtype=mlx_dtype, key=rng_key)
        return loc + scale * standard

    return sample_fn
68+
69+
70+
@mlx_sample_fn.register(ptr.UniformRV)
def mlx_sample_fn_uniform(op, node):
    """Sampler for the continuous uniform distribution on ``[low, high)``."""

    def sample_fn(rng_key, size, dtype, low, high):
        mlx_dtype = convert_dtype_to_mlx(dtype)
        lo = mx.array(low, dtype=mlx_dtype)
        hi = mx.array(high, dtype=mlx_dtype)
        # Fall back to the broadcast shape of the bounds when no size is given.
        shape = (
            mx.broadcast_arrays(lo, hi)[0].shape
            if size is None
            else mlx_to_list_shape(size)
        )
        return mx.random.uniform(
            low=lo, high=hi, shape=shape, dtype=mlx_dtype, key=rng_key
        )

    return sample_fn
85+
86+
87+
@mlx_sample_fn.register(ptr.BernoulliRV)
def mlx_sample_fn_bernoulli(op, node):
    """Sampler for the Bernoulli distribution with success probability ``p``."""

    def sample_fn(rng_key, size, dtype, p):
        prob = mx.array(p)
        # NOTE(review): ``dtype`` is unused here — mx.random.bernoulli picks
        # its own output dtype; confirm downstream casting handles this.
        shape = prob.shape if size is None else mlx_to_list_shape(size)
        return mx.random.bernoulli(p=prob, shape=shape, key=rng_key)

    return sample_fn
98+
99+
100+
@mlx_sample_fn.register(ptr.CategoricalRV)
def mlx_sample_fn_categorical(op, node):
    """Sampler for the categorical distribution over the last axis of ``p``."""

    def sample_fn(rng_key, size, dtype, p):
        # mx.random.categorical expects logits, so move probabilities to log space.
        logits = mx.log(mx.array(p))
        if size is None:
            shape = None
        else:
            shape = mlx_to_list_shape(size)
        return mx.random.categorical(logits=logits, axis=-1, shape=shape, key=rng_key)

    return sample_fn
108+
109+
110+
@mlx_sample_fn.register(ptr.MvNormalRV)
def mlx_sample_fn_mvnormal(op, node):
    """Sampler for the multivariate normal distribution given mean and covariance."""

    def sample_fn(rng_key, size, dtype, mean, cov):
        mlx_dtype = convert_dtype_to_mlx(dtype)
        shape = [] if size is None else mlx_to_list_shape(size)
        # multivariate_normal uses SVD internally, which requires mx.cpu in MLX.
        return mx.random.multivariate_normal(
            mean=mean,
            cov=cov,
            shape=shape,
            dtype=mlx_dtype,
            key=rng_key,
            stream=mx.cpu,
        )

    return sample_fn
126+
127+
128+
@mlx_sample_fn.register(ptr.LaplaceRV)
def mlx_sample_fn_laplace(op, node):
    """Sampler for the Laplace distribution via a location-scale transform."""

    def sample_fn(rng_key, size, dtype, loc, scale):
        mlx_dtype = convert_dtype_to_mlx(dtype)
        location = mx.array(loc, dtype=mlx_dtype)
        spread = mx.array(scale, dtype=mlx_dtype)
        # With no explicit size, the draw shape is the broadcast of the params.
        if size is not None:
            shape = mlx_to_list_shape(size)
        else:
            shape = mx.broadcast_arrays(location, spread)[0].shape
        standard = mx.random.laplace(shape=shape, dtype=mlx_dtype, key=rng_key)
        return location + spread * standard

    return sample_fn
142+
143+
144+
@mlx_sample_fn.register(ptr.GumbelRV)
def mlx_sample_fn_gumbel(op, node):
    """Sampler for the Gumbel distribution via a location-scale transform."""

    def sample_fn(rng_key, size, dtype, loc, scale):
        mlx_dtype = convert_dtype_to_mlx(dtype)
        location = mx.array(loc, dtype=mlx_dtype)
        spread = mx.array(scale, dtype=mlx_dtype)
        # With no explicit size, the draw shape is the broadcast of the params.
        if size is not None:
            shape = mlx_to_list_shape(size)
        else:
            shape = mx.broadcast_arrays(location, spread)[0].shape
        standard = mx.random.gumbel(shape=shape, dtype=mlx_dtype, key=rng_key)
        return location + spread * standard

    return sample_fn
158+
159+
160+
@mlx_sample_fn.register(ptr.PermutationRV)
def mlx_sample_fn_permutation(op, node):
    """Sampler that returns a random permutation of ``x`` (unbatched only)."""
    batch_ndim = op.batch_ndim(node)

    def sample_fn(rng_key, size, dtype, x):
        # Only the unbatched case is supported; reject anything else up front.
        if not batch_ndim:
            return mx.random.permutation(x, key=rng_key)
        raise NotImplementedError(
            "MLX random.permutation does not support batch dimensions."
        )

    return sample_fn
172+
173+
174+
@mlx_sample_fn.register(ptr.IntegersRV)
def mlx_sample_fn_integers(op, node):
    """Sampler for random integers drawn from ``[low, high)``."""

    def sample_fn(rng_key, size, dtype, low, high):
        mlx_dtype = convert_dtype_to_mlx(dtype)
        lo = mx.array(low, dtype=mlx_dtype)
        hi = mx.array(high, dtype=mlx_dtype)
        # Fall back to the broadcast shape of the bounds when no size is given.
        shape = (
            mx.broadcast_arrays(lo, hi)[0].shape
            if size is None
            else mlx_to_list_shape(size)
        )
        return mx.random.randint(
            low=lo, high=hi, shape=shape, dtype=mlx_dtype, key=rng_key
        )

    return sample_fn

0 commit comments

Comments
 (0)