Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 72 additions & 9 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,26 @@ class AutoGuide(ABC):
``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate`
or iterable of plates. Plates not returned will be created
automatically as usual. This is useful for data subsampling.
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
during model initialization. Defaults to False. This is useful for models that
contain JAX primitives which are not supported by reverse-mode differentiation
(e.g. :func:`jax.lax.while_loop`).
"""

def __init__(
self, model, *, prefix="auto", init_loc_fn=init_to_uniform, create_plates=None
self,
model,
*,
prefix="auto",
init_loc_fn=init_to_uniform,
create_plates=None,
forward_mode_differentiation=False,
):
self.model = model
self.prefix = prefix
self.init_loc_fn = init_loc_fn
self.create_plates = create_plates
self._forward_mode_differentiation = forward_mode_differentiation
self.prototype_trace = None
self._prototype_frames = {}
self._prototype_frame_full_sizes = {}
Expand Down Expand Up @@ -164,6 +175,7 @@ def _setup_prototype(self, *args, **kwargs):
dynamic_args=True,
model_args=args,
model_kwargs=kwargs,
forward_mode_differentiation=self._forward_mode_differentiation,
)
self._potential_fn = self._potential_fn_gen(*args, **kwargs)
postprocess_fn = postprocess_fn_gen(*args, **kwargs)
Expand Down Expand Up @@ -246,14 +258,26 @@ class AutoGuideList(AutoGuide):
params = svi.get_params(svi_state)

:param callable model: a NumPyro model
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
during model initialization. Defaults to False.
"""

def __init__(
self, model, *, prefix="auto", init_loc_fn=init_to_uniform, create_plates=None
self,
model,
*,
prefix="auto",
init_loc_fn=init_to_uniform,
create_plates=None,
forward_mode_differentiation=False,
):
self._guides = []
super().__init__(
model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates
model,
prefix=prefix,
init_loc_fn=init_loc_fn,
create_plates=create_plates,
forward_mode_differentiation=forward_mode_differentiation,
)

def append(self, part):
Expand Down Expand Up @@ -363,6 +387,8 @@ class AutoNormal(AutoGuide):
``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate`
or iterable of plates. Plates not returned will be created
automatically as usual. This is useful for data subsampling.
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
during model initialization. Defaults to False.
"""

scale_constraint = constraints.softplus_positive
Expand All @@ -375,11 +401,16 @@ def __init__(
init_loc_fn=init_to_uniform,
init_scale=0.1,
create_plates=None,
forward_mode_differentiation=False,
):
self._init_scale = init_scale
self._event_dims = {}
super().__init__(
model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates
model,
prefix=prefix,
init_loc_fn=init_loc_fn,
create_plates=create_plates,
forward_mode_differentiation=forward_mode_differentiation,
)

def _setup_prototype(self, *args, **kwargs):
Expand Down Expand Up @@ -516,14 +547,26 @@ class AutoDelta(AutoGuide):
``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate`
or iterable of plates. Plates not returned will be created
automatically as usual. This is useful for data subsampling.
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
during model initialization. Defaults to False.
"""

def __init__(
self, model, *, prefix="auto", init_loc_fn=init_to_median, create_plates=None
self,
model,
*,
prefix="auto",
init_loc_fn=init_to_median,
create_plates=None,
forward_mode_differentiation=False,
):
self._event_dims = {}
super().__init__(
model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates
model,
prefix=prefix,
init_loc_fn=init_loc_fn,
create_plates=create_plates,
forward_mode_differentiation=forward_mode_differentiation,
)

def _setup_prototype(self, *args, **kwargs):
Expand Down Expand Up @@ -853,6 +896,8 @@ class AutoDAIS(AutoContinuous):
:param float init_scale: Initial scale for the standard deviation of
the base variational distribution for each (unconstrained transformed)
latent variable. Defaults to 0.1.
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
during model initialization. Defaults to False.
"""

def __init__(
Expand All @@ -867,6 +912,7 @@ def __init__(
prefix="auto",
init_loc_fn=init_to_uniform,
init_scale=0.1,
forward_mode_differentiation=False,
):
if K < 1:
raise ValueError("K must satisfy K >= 1 (got K = {})".format(K))
Expand All @@ -889,7 +935,12 @@ def __init__(
self.K = K
self.base_dist = base_dist
self._init_scale = init_scale
super().__init__(model, prefix=prefix, init_loc_fn=init_loc_fn)
super().__init__(
model,
prefix=prefix,
init_loc_fn=init_loc_fn,
forward_mode_differentiation=forward_mode_differentiation,
)

def _setup_prototype(self, *args, **kwargs):
super()._setup_prototype(*args, **kwargs)
Expand Down Expand Up @@ -1083,6 +1134,8 @@ def surrogate_model(X_surr, Y_surr):
:param float init_scale: Initial scale for the standard deviation of
the base variational distribution for each (unconstrained transformed)
latent variable. Defaults to 0.1.
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
during model initialization. Defaults to False.
"""

def __init__(
Expand All @@ -1098,6 +1151,7 @@ def __init__(
base_dist="diagonal",
init_loc_fn=init_to_uniform,
init_scale=0.1,
forward_mode_differentiation=False,
):
super().__init__(
model,
Expand All @@ -1109,6 +1163,7 @@ def __init__(
init_loc_fn=init_loc_fn,
init_scale=init_scale,
base_dist=base_dist,
forward_mode_differentiation=forward_mode_differentiation,
)

self.surrogate_model = surrogate_model
Expand All @@ -1127,6 +1182,7 @@ def _setup_prototype(self, *args, **kwargs):
dynamic_args=False,
model_args=(),
model_kwargs={},
forward_mode_differentiation=self._forward_mode_differentiation,
)
)

Expand Down Expand Up @@ -1299,6 +1355,8 @@ def local_model(theta):
data points in the subsample plate) or local (i.e. each data point in the
subsample plate has individual parameters). Note that we do not use global
parameters for the base distribution.
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
during model initialization. Defaults to False.
"""

def __init__(
Expand All @@ -1316,9 +1374,14 @@ def __init__(
init_scale=0.1,
subsample_plate=None,
use_global_dais_params=False,
forward_mode_differentiation=False,
):
# init_loc_fn is only used to inspect the model.
super().__init__(model, prefix=prefix, init_loc_fn=init_to_uniform)
super().__init__(
model,
prefix=prefix,
init_loc_fn=init_to_uniform,
forward_mode_differentiation=forward_mode_differentiation,
)
if K < 1:
raise ValueError("K must satisfy K >= 1 (got K = {})".format(K))
if eta_init <= 0.0 or eta_init >= eta_max:
Expand Down
12 changes: 12 additions & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,3 +1378,15 @@ def model(n: int, x: jnp.ndarray):
)
state = svi.init(jax.random.key(2), x=x)
svi.update(state, x=subset)


@pytest.mark.parametrize("auto_class", [AutoNormal, AutoDelta])
def test_autoguide_forward_mode_differentiation(auto_class):
def model():
x = numpyro.sample("x", dist.Normal(0, 1))
y = lax.while_loop(lambda x: x < 10, lambda x: x + 1, x)
numpyro.sample("obs", dist.Normal(y, 1), obs=1.0)

guide = auto_class(model, forward_mode_differentiation=True)
svi = SVI(model, guide, optim.Adam(0.01), loss=Trace_ELBO())
svi.run(random.key(0), 10, forward_mode_differentiation=True)
Loading