Skip to content

Commit 9ff1207

Browse files
committed
add test suite
1 parent 63763d9 commit 9ff1207

File tree

1 file changed

+175
-0
lines changed

1 file changed

+175
-0
lines changed

tests/link/mlx/test_random.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import numpy as np
2+
import pytest
3+
4+
import pytensor
5+
import pytensor.tensor as pt
6+
from pytensor.compile.mode import MLX, Mode
7+
from pytensor.link.mlx.linker import MLXLinker
8+
from pytensor.tensor.random.utils import RandomStream
9+
10+
11+
mx = pytest.importorskip("mlx.core")
12+
13+
# MLX mode without mx.compile — needed for ops that use CPU streams internally
14+
# (e.g. multivariate_normal, which uses SVD via mx.cpu stream and is
15+
# incompatible with mx.compile's tracing).
16+
MLX_NO_COMPILE = Mode(linker=MLXLinker(use_compile=False), optimizer=MLX.optimizer)
17+
18+
19+
def test_normal_cumsum():
20+
out = pt.random.normal(size=(52,)).cumsum()
21+
result = out.eval(mode="MLX")
22+
assert isinstance(result, mx.array)
23+
assert result.shape == (52,)
24+
25+
26+
def check_shape_and_dtype(
27+
make_rv, expected_shape, expected_dtype=None, n_evals=2, mode="MLX"
28+
):
29+
"""Compile and run an RV under MLX, assert shape and dtype, and verify
30+
that two successive draws differ (RNG state is properly threaded).
31+
32+
Parameters
33+
----------
34+
make_rv : callable(srng) -> rv_var
35+
Factory that creates the RV using the provided RandomStream.
36+
expected_shape : tuple
37+
expected_dtype : str or None
38+
n_evals : int
39+
mode : str or Mode
40+
"""
41+
srng = RandomStream(seed=12345)
42+
rv = make_rv(srng)
43+
f = pytensor.function([], rv, mode=mode, updates=srng.updates())
44+
results = [np.array(f()) for _ in range(n_evals)]
45+
46+
for r in results:
47+
assert r.shape == expected_shape, (
48+
f"Expected shape {expected_shape}, got {r.shape}"
49+
)
50+
if expected_dtype is not None:
51+
assert r.dtype == np.dtype(expected_dtype), (
52+
f"Expected dtype {expected_dtype}, got {r.dtype}"
53+
)
54+
55+
assert not np.array_equal(results[0], results[1]), (
56+
"Two draws were identical — RNG not advancing"
57+
)
58+
59+
return results
60+
61+
62+
def test_normal_shape_dtype():
63+
check_shape_and_dtype(
64+
lambda srng: srng.normal(loc=0.0, scale=1.0, size=(3, 4)),
65+
(3, 4),
66+
"float32",
67+
)
68+
69+
70+
def test_normal_scalar():
71+
check_shape_and_dtype(
72+
lambda srng: srng.normal(loc=2.0, scale=0.5),
73+
(),
74+
)
75+
76+
77+
def test_normal_array_params():
78+
result = pt.random.normal(loc=[0, 1], scale=[1.0, 0.3], size=(100, 2)).eval(
79+
mode="MLX"
80+
)
81+
assert result.shape == (100, 2)
82+
means = np.array(result).mean(axis=0)
83+
assert abs(means[0]) < 0.3
84+
assert abs(means[1] - 1.0) < 0.3
85+
86+
87+
def test_uniform_shape_dtype():
88+
results = check_shape_and_dtype(
89+
lambda srng: srng.uniform(low=0.0, high=1.0, size=(10,)),
90+
(10,),
91+
"float32",
92+
)
93+
r = np.array(results[0])
94+
assert np.all(r >= 0.0)
95+
assert np.all(r < 1.0)
96+
97+
98+
def test_bernoulli_shape():
99+
check_shape_and_dtype(
100+
lambda srng: srng.bernoulli(p=0.7, size=(5, 5)),
101+
(5, 5),
102+
)
103+
104+
105+
def test_categorical_shape():
106+
probs = np.array([0.1, 0.4, 0.5], dtype=np.float32)
107+
results = check_shape_and_dtype(
108+
lambda srng: srng.categorical(p=probs, size=(8,)),
109+
(8,),
110+
)
111+
r = np.array(results[0])
112+
assert np.all(r < 3)
113+
assert np.all(r >= 0)
114+
115+
116+
def test_mvnormal_shape():
117+
mean = np.zeros(4, dtype=np.float32)
118+
cov = np.eye(4, dtype=np.float32)
119+
# multivariate_normal uses SVD internally (CPU-only in MLX), which is
120+
# incompatible with mx.compile — use the no-compile mode.
121+
check_shape_and_dtype(
122+
lambda srng: srng.multivariate_normal(mean=mean, cov=cov, size=(6,)),
123+
(6, 4),
124+
"float32",
125+
mode=MLX_NO_COMPILE,
126+
)
127+
128+
129+
def test_laplace_shape_dtype():
130+
check_shape_and_dtype(
131+
lambda srng: srng.laplace(loc=0.0, scale=1.0, size=(7,)),
132+
(7,),
133+
"float32",
134+
)
135+
136+
137+
def test_gumbel_shape_dtype():
138+
check_shape_and_dtype(
139+
lambda srng: srng.gumbel(loc=0.0, scale=1.0, size=(6,)),
140+
(6,),
141+
"float32",
142+
)
143+
144+
145+
def test_integers_shape():
146+
results = check_shape_and_dtype(
147+
lambda srng: srng.integers(low=0, high=10, size=(12,)),
148+
(12,),
149+
)
150+
r = np.array(results[0])
151+
assert np.all(r >= 0)
152+
assert np.all(r < 10)
153+
154+
155+
def test_permutation_shape():
156+
x = np.arange(8, dtype=np.int32)
157+
results = check_shape_and_dtype(
158+
lambda srng: srng.permutation(x),
159+
(8,),
160+
)
161+
assert sorted(np.array(results[0]).tolist()) == list(range(8))
162+
163+
164+
def test_gamma_not_implemented():
165+
srng = RandomStream(seed=1)
166+
rv = srng.gamma(shape=1.0, scale=1.0, size=(3,))
167+
with pytest.raises(NotImplementedError, match="No MLX implementation"):
168+
pytensor.function([], rv, mode="MLX", updates=srng.updates())
169+
170+
171+
def test_beta_not_implemented():
172+
srng = RandomStream(seed=1)
173+
rv = srng.beta(alpha=2.0, beta=5.0, size=(3,))
174+
with pytest.raises(NotImplementedError, match="No MLX implementation"):
175+
pytensor.function([], rv, mode="MLX", updates=srng.updates())

0 commit comments

Comments
 (0)