Skip to content

Add pupport forward-mode differentiation with autoguide#2163

Merged
fehiepsi merged 1 commit intopyro-ppl:masterfrom
juanitorduz:forward-diff-autoguide
Mar 30, 2026
Merged

Add pupport forward-mode differentiation with autoguide#2163
fehiepsi merged 1 commit intopyro-ppl:masterfrom
juanitorduz:forward-diff-autoguide

Conversation

@juanitorduz
Copy link
Copy Markdown
Collaborator

Closes #2068


Summary

Adds forward_mode_differentiation parameter to all autoguide classes, fixing #2068.

AutoGuide._setup_prototype calls initialize_model to inspect the model and find valid initial parameters. Previously, it always used reverse-mode differentiation (the default), which fails for models containing JAX primitives that don't support reverse-mode AD (e.g., jax.lax.while_loop, jax.numpy.linalg.eigh). Users could pass forward_mode_differentiation=True to SVI.run/SVI.update to control differentiation of the ELBO loss, but the autoguide's internal initialize_model call was unaffected.

This PR adds a forward_mode_differentiation keyword argument (default False) to AutoGuide.__init__ and threads it through to all subclasses:

  • AutoGuide (base class) — stores the flag and passes it to initialize_model in _setup_prototype
  • AutoGuideList
  • AutoNormal
  • AutoDelta
  • AutoDAIS
  • AutoSurrogateLikelihoodDAIS — also passes the flag to the second initialize_model call for the surrogate model
  • AutoSemiDAIS

AutoBatchedMixin uses *args, **kwargs and forwards automatically.

Usage

guide = AutoNormal(model, forward_mode_differentiation=True)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi.run(rng_key, num_steps, forward_mode_differentiation=True)

@juanitorduz juanitorduz marked this pull request as draft March 29, 2026 17:37
@juanitorduz juanitorduz self-assigned this Mar 29, 2026
@juanitorduz juanitorduz marked this pull request as ready for review March 29, 2026 18:35
@juanitorduz juanitorduz requested a review from fehiepsi March 29, 2026 18:35
Copy link
Copy Markdown
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @juanitorduz!

@fehiepsi fehiepsi merged commit 781b0e7 into pyro-ppl:master Mar 30, 2026
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support forward-mode differentiation with autoguide.

2 participants