|
3 | 3 |
|
4 | 4 | import pytensor |
5 | 5 | import pytensor.tensor as pt |
| 6 | +from pytensor.compile.function import function |
6 | 7 | from pytensor.compile.mode import MLX, Mode |
| 8 | +from pytensor.compile.sharedvalue import shared |
7 | 9 | from pytensor.link.mlx.linker import MLXLinker |
8 | 10 | from pytensor.tensor.random.utils import RandomStream |
9 | 11 |
|
@@ -173,3 +175,82 @@ def test_beta_not_implemented(): |
173 | 175 | rv = srng.beta(alpha=2.0, beta=5.0, size=(3,)) |
174 | 176 | with pytest.raises(NotImplementedError, match="No MLX implementation"): |
175 | 177 | pytensor.function([], rv, mode="MLX", updates=srng.updates()) |
| 178 | + |
| 179 | + |
def compile_shared_rng_function(*args, mode="MLX", **kwargs):
    """Compile with ``pytensor.function`` while asserting the expected warning.

    The MLX backend warns that RandomType SharedVariables are not used; every
    test in this module compiles through this wrapper so the warning is both
    expected and swallowed.
    """
    expected_warning = r"The RandomType SharedVariables \[.+\] will not be used"
    with pytest.warns(UserWarning, match=expected_warning):
        compiled = function(*args, mode=mode, **kwargs)
    return compiled
| 185 | + |
| 186 | + |
def test_random_updates():
    """RNG updates advance the compiled RNG without clobbering the original.

    Compiles a normal draw whose ``updates`` swap the shared RNG for its
    successor, then checks that (1) successive calls produce different draws
    and (2) the Generator object originally handed to ``shared`` (with
    ``borrow=False``) still has its initial state.
    """
    original_value = np.random.default_rng(seed=98)
    rng = shared(original_value, name="original_rng", borrow=False)
    next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs

    f = compile_shared_rng_function([], [x], updates={rng: next_rng})
    assert f() != f()

    # Check that the original shared variable was not overwritten when typifying.
    # BUGFIX: zipping the state dicts directly iterates only their *keys*,
    # which are always equal — making the assertion vacuous. Compare the
    # state *values* instead (some bit generators store ndarrays, hence the
    # array_equal branch).
    assert all(
        a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b)
        for a, b in zip(
            rng.get_value().bit_generator.state.values(),
            original_value.bit_generator.state.values(),
            strict=True,
        )
    )
| 204 | + |
| 205 | + |
@pytest.mark.parametrize("noise_first", (False, True))
def test_replaced_shared_rng_storage_order(noise_first):
    """Replacing the RNG variable in the linker must not disalign the
    compiled graph and the storage_map, whichever side the noise is on."""
    scale = pytensor.shared(np.array(1.0), name="mu")
    rng = pytensor.shared(np.random.default_rng(123))
    next_rng, noise = pt.random.normal(rng=rng).owner.outputs

    if noise_first:
        out = noise * scale
    else:
        out = scale * noise

    f = compile_shared_rng_function(
        [], [out], updates={scale: pt.grad(out, scale), rng: next_rng}
    )

    # input_storage entries and fgraph inputs must pair up type-for-type
    for container, graph_input in zip(
        f.input_storage, f.maker.fgraph.inputs, strict=True
    ):
        assert container.type == graph_input.type

    # The gradient update must actually run and change the shared value
    assert scale.get_value() == 1
    f()
    assert scale.get_value() != 1
| 232 | + |
| 233 | + |
def test_replaced_shared_rng_storage_ordering_equality():
    """Test that storage identity comparison works when numpy arrays precede
    the RNG in input_storage (regression test for issue #314)."""
    srng = RandomStream(1)

    shape = (3, 1, 4, 4)
    inp_shared = pytensor.shared(
        np.zeros(shape, dtype="float64"), name="inp_shared"
    )

    inp = pt.tensor4(dtype="float64", name="inp")
    inp_update = inp + srng.normal(size=inp.shape, loc=5, scale=1e-5)

    fn = compile_shared_rng_function(
        inputs=[],
        outputs=[],
        updates={inp_shared: inp_update},
        givens={inp: inp_shared},
    )

    # Each call adds ~N(5, 1e-5) to every element, so after k calls the
    # shared value sits near 5 * k.
    for expected in (5, 10):
        fn()
        np.testing.assert_allclose(
            np.array(inp_shared.get_value()), expected, rtol=1e-2
        )
0 commit comments