Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion notebooks/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@

if "READTHEDOCS" not in os.environ:
# if developing locally, use numpyro.__version__ as version
from numpyro import __version__ # noqaE402
from numpyro import __version__ # noqa: E402

version = __version__

Expand Down
15 changes: 12 additions & 3 deletions numpyro/infer/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ def init_to_median(site=None, num_samples=15):
samples = site["fn"](
sample_shape=(num_samples,) + sample_shape, rng_key=rng_key
)
return jnp.median(samples, axis=0)
from numpyro.infer.util import helpful_support_errors
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, you put the import here because of circular import?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I borrowed that line from init_to_uniform here.

# XXX: we import here to avoid circular import
from numpyro.infer.util import helpful_support_errors


with helpful_support_errors(site):
transform = biject_to(site["fn"].support)
unconstrained = transform.inv(samples)
median_unconstrained = jnp.median(unconstrained, axis=0)
return transform(median_unconstrained)
except NotImplementedError:
return init_to_uniform(site)

Expand All @@ -66,11 +72,14 @@ def init_to_mean(site=None):
)
return site["value"]
try:
# Try .mean property.
value = site["fn"].mean
# Try .mean property. We multiply by 1.0 to promote the mean to a floating
# point value. This is required to calculate gradients with respect to an
# initialized parameter.
value = 1.0 * site["fn"].mean
sample_shape = site["kwargs"].get("sample_shape")
if sample_shape:
value = jnp.broadcast_to(value, sample_shape + jnp.shape(value))
return value
except (NotImplementedError, ValueError):
return init_to_median(site)

Expand Down
9 changes: 7 additions & 2 deletions test/contrib/einstein/test_steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_init_auto_guide(auto_class, init_loc_fn, num_particles):
latent_dim = 3

def model(obs):
a = numpyro.sample("a", Normal(0, 1).expand((latent_dim,)).to_event(1))
a = numpyro.sample("a", Normal(0.2, 1).expand((latent_dim,)).to_event(1))
return numpyro.sample("obs", Bernoulli(logits=a), obs=obs)

obs = Bernoulli(0.5).sample(random.PRNGKey(0), (10, latent_dim))
Expand Down Expand Up @@ -280,7 +280,12 @@ def model(obs):
assert init_value.shape == expected_shape
if "auto_loc" in name or name == "b":
assert np.all(init_value != np.zeros(expected_shape))
assert np.unique(init_value).shape == init_value.reshape(-1).shape
# Check that all values are unique except when `init_to_mean` is used
# because all initial values will be equal to the mean.
assert (
np.unique(init_value).shape == init_value.reshape(-1).shape
or init_loc_fn is init_to_mean
)
elif "scale" in name:
assert_allclose(init_value[init_value != 0.0], 0.1, rtol=1e-6)

Expand Down
25 changes: 25 additions & 0 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,31 @@ def model():
assert_allclose(mcmc.get_samples()["x"].mean(), 0.0, atol=0.15)


@pytest.mark.parametrize(
"init_strategy",
[
init_to_feasible(),
init_to_median(num_samples=2),
init_to_sample(),
init_to_uniform(radius=3),
init_to_value(values={"tau": 0.7}),
init_to_feasible,
init_to_mean,
init_to_median,
init_to_sample,
init_to_uniform,
init_to_value,
],
)
def test_init_to_valid(init_strategy):
with handlers.trace() as trace, handlers.seed(rng_seed=3):
numpyro.sample("x", dist.ZeroSumNormal(1, (3,)))
site = trace["x"]
site["value"] = None
init = init_strategy(site)
assert site["fn"].support(init)


@pytest.mark.parametrize(
"init_strategy",
[
Expand Down