feat(mlx): pt.random support with mlx backend #1979
williambdean wants to merge 8 commits into pymc-devs:v3
Conversation
force-pushed b620bb1 to d602268
ricardoV94 left a comment:
Missing rng outputs/updates (so consecutive calls get an updated rng).
There should be tests in numba/jax you can use as a template. JAX is going to be more similar.
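To illustrate the point about rng updates, here is a minimal sketch of the functional rng-update pattern the JAX backend follows. The `normal_sample` helper is hypothetical (not from this PR); it uses NumPy in place of MLX to show the idea: each sampling call returns the draw together with a fresh generator, so consecutive calls advance the random state.

```python
import numpy as np

def normal_sample(rng, size=None):
    # Hypothetical helper: return the draw together with a *new* Generator,
    # so consecutive calls advance the random state instead of reusing it.
    # This mirrors the rng-output/update pattern of the JAX backend.
    next_rng = np.random.default_rng(rng.integers(0, 2**63))
    draw = rng.normal(size=size)
    return draw, next_rng

rng = np.random.default_rng(0)
x1, rng = normal_sample(rng, size=3)
x2, rng = normal_sample(rng, size=3)
# x1 and x2 come from different states, so they differ
```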
```python
thunk_inputs = []
for n in self.fgraph.inputs:
    sinput = storage_map[n]
    if isinstance(sinput[0], Generator):
```
You need to do the same dance the JAX linker does with shared Generator variables.
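A rough sketch of what that "dance" looks like, with a hypothetical `mlx_typify_rng` helper standing in for whatever conversion the MLX linker would use (the real JAX linker's conversion differs in detail). The idea: when a storage cell holds a shared NumPy `Generator`, swap it in place for a backend-friendly representation before building the thunk inputs.

```python
from numpy.random import Generator, default_rng

def mlx_typify_rng(value):
    # Hypothetical conversion: replace a numpy Generator with a
    # backend-friendly representation (here, the raw bit-generator state).
    if isinstance(value, Generator):
        return {
            "bit_generator": type(value.bit_generator).__name__,
            "state": value.bit_generator.state,
        }
    return value

# Storage cells as the linker sees them: one-element lists
storage_map_cells = [[default_rng(42)], [3.14]]
thunk_inputs = []
for cell in storage_map_cells:
    cell[0] = mlx_typify_rng(cell[0])  # in-place swap, mirroring the JAX linker
    thunk_inputs.append(cell)
```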
#2010 caused conflicts for this PR. You will need to rebase.
force-pushed e6f7371 to 0b4fb85
```python
def sample_fn(rng_key, size, dtype, p):
    p = mx.array(p)
    if size is None:
        shape = p.shape
```
Do you always need the shape? You didn't need it in the categorical. I would assume you only need it when one of the parameters doesn't go into the random function. If so, that would take a lot of boilerplate away from your dispatches.
My comment wasn't about Bernoulli specifically. I would expect you don't need to define the shape explicitly (when the user didn't do it themselves) most of the time.
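One way to remove that per-dispatch boilerplate is a shared helper along these lines (a sketch, not from this PR; `infer_size` is a hypothetical name): when the user gave no `size`, fall back to the broadcast shape of the distribution parameters.

```python
import numpy as np

def infer_size(size, *params):
    # Hypothetical helper: if the user didn't give a size, fall back to the
    # broadcast shape of the distribution parameters, so individual
    # dispatches don't each have to compute a default shape.
    if size is not None:
        return tuple(size)
    return np.broadcast_shapes(*(np.shape(p) for p in params))
```

A dispatch could then call `infer_size(size, p)` once instead of branching on `size is None` in every `sample_fn`.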
…ethods, permutation error at dispatch time
force-pushed 0af680e to 5d450de
why?
Accident
Description

Basic support for mlx random generation. MLX has limited support; the Gamma distribution is missing. Additional distributions could be supported with basic transformations, e.g. pt.abs(pt.random.normal(...)) ~ Half-Normal.

MLX Reference: https://ml-explore.github.io/mlx/build/html/python/random.html
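To make the transformation idea concrete, here is the Half-Normal example in plain NumPy (standing in for the PyTensor expression `pt.abs(pt.random.normal(...))`): taking the absolute value of standard normal draws yields half-normal draws, whose mean is sqrt(2/pi).

```python
import numpy as np

rng = np.random.default_rng(0)
# |N(0, 1)| is Half-Normal distributed, analogous to
# pt.abs(pt.random.normal(...)) in PyTensor.
draws = np.abs(rng.normal(size=10_000))
```

The same trick covers other missing distributions whenever a simple transform of a supported one suffices.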
Related Issue
Checklist
Type of change