
Pass value arg to optax, allowing use of reduce_on_plateau #1974

Merged
fehiepsi merged 8 commits into pyro-ppl:master from zmbc:optax_pass_value_arg
Mar 10, 2025

Conversation

zmbc (Contributor) commented Feb 12, 2025:

Addresses #1955 using the approach from this comment: #1955 (comment)

zmbc (Contributor, author) commented Feb 12, 2025:

@fehiepsi Let me know what you think! I haven't tested this code, and I might need some help with adding types as well.

Review thread on numpyro/optim.py (outdated)
"""
i, opt_state = state
opt_state = self.update_fn(i, g, opt_state)
opt_state = self.update_fn(i, g, opt_state, value=value)
fehiepsi (Member) commented:

You can introduce an attribute update_with_value (False by default) to control the behavior here:

if self.update_with_value:
    opt_state = self.update_fn(i, g, opt_state, value)
else:
    opt_state = self.update_fn(i, g, opt_state)

You can then use the typing:

self.update_fn: Union[Callable[[ArrayLike, _Params, _OptState], _OptState], Callable[[ArrayLike, _Params, _OptState, ArrayLike], _OptState]]

In the optax optimizer, you can set it to True:

numpyro_optim = _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn)
numpyro_optim.update_with_value = True
return numpyro_optim
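The suggested dispatch can be sketched as a small self-contained example. ToyOptim, plain_update, and value_aware_update are illustrative stand-ins here, not numpyro's actual _NumPyroOptim API:

```python
# Toy stand-in for the suggested update_with_value dispatch (names are
# illustrative; numpyro's real _NumPyroOptim differs).
class ToyOptim:
    def __init__(self, update_fn, update_with_value=False):
        self.update_fn = update_fn
        self.update_with_value = update_with_value

    def update(self, g, state, value=None):
        i, opt_state = state
        if self.update_with_value:
            opt_state = self.update_fn(i, g, opt_state, value)
        else:
            opt_state = self.update_fn(i, g, opt_state)
        return i + 1, opt_state

def plain_update(i, g, opt_state):
    # ordinary SGD-style step, ignores the loss value
    return opt_state - 0.1 * g

def value_aware_update(i, g, opt_state, value):
    # e.g. reduce_on_plateau-style logic would inspect `value` here
    scale = 0.5 if value > 1.0 else 1.0
    return opt_state - scale * 0.1 * g

opt_a = ToyOptim(plain_update)
opt_b = ToyOptim(value_aware_update, update_with_value=True)
print(opt_a.update(1.0, (0, 2.0)))             # (1, 1.9)
print(opt_b.update(1.0, (0, 2.0), value=5.0))  # smaller step, value > 1.0
```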

zmbc (Contributor, author) replied:

@fehiepsi I can't easily get this typing to pass mypy. I'd need to use some kind of custom TypeGuard to narrow this Union type down to the specific Callable signature at runtime before it is called. Do you want me to do that? Or would you rather I changed this to simply Callable[..., _OptState]?

zmbc (Contributor, author):

@fehiepsi Any pointers here?

fehiepsi (Member):

Sorry, I missed the message. I think you can define a protocol for this purpose:

class _UpdateFn(Protocol):
    def __call__(self, arr: ArrayLike, params: _Params, opt_state: _OptState, *args: ArrayLike) -> _OptState:
        ...
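A minimal runnable sketch of that Protocol idea, using Any as a stand-in for numpyro's _Params / _OptState type variables. Note that a concrete update function must itself accept the trailing *args for the assignment to type-check under mypy:

```python
from typing import Any, Protocol

_Params = Any
_OptState = Any

class _UpdateFn(Protocol):
    def __call__(self, step: int, grads: _Params, opt_state: _OptState,
                 *args: float) -> _OptState:
        ...

def update_no_value(step: int, grads: _Params, opt_state: _OptState,
                    *args: float) -> _OptState:
    # ignores any extra `value` argument
    return opt_state

def update_with_value(step: int, grads: _Params, opt_state: _OptState,
                      *args: float) -> _OptState:
    # consumes the optional `value` argument when present
    value = args[0] if args else None
    return (opt_state, value)

fn: _UpdateFn = update_no_value     # both assignments satisfy the protocol
fn2: _UpdateFn = update_with_value
```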

tillahoffmann (Collaborator) commented Mar 7, 2025:

It might be possible to reuse the optax protocol here, although the OptStates may not be compatible.

zmbc (Contributor, author) commented Feb 20, 2025:

@fehiepsi thanks for your comments, want to take another look?

fehiepsi (Member) left a comment:

Could you add a simple test for this?

Review thread on numpyro/optim.py (outdated)

fehiepsi (Member) commented Mar 6, 2025:

@zmbc, to test the optimizer, you can simply add the new one to this line

(optax.sgd, (1e-2,), {}),
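For illustration, the new entry would presumably pair an optax base optimizer with optax.contrib.reduce_on_plateau via optax.chain. The exact factory and hyperparameters below are guesses, not the PR's actual test code:

```python
# hypothetical extra entry for the parametrized list in test/test_optimizers.py
(lambda lr: optax.chain(optax.adam(lr), optax.contrib.reduce_on_plateau()), (1e-2,), {}),
```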

@fehiepsi fehiepsi added this to the 0.18 milestone Mar 7, 2025
zmbc (Contributor, author) commented Mar 7, 2025:

@fehiepsi I think I've got this all working except that the test test_numpyrooptim_no_double_jit is failing on this optimizer with

>       assert my_fn_calls == 1
E       assert 2 == 1

test/test_optimizers.py:125: AssertionError

My understanding of JAX and JIT is pretty basic, any ideas why this might be happening?

fehiepsi (Member) left a review:

Thanks for supporting those new optimizers, @zmbc!

fehiepsi (Member) commented Mar 7, 2025:

> test_numpyrooptim_no_double_jit is failing

I think the init state and the later states of the new optimizer have different dtypes. Could you double-check? If that is the case, you can assert <= 2 for that specific case.
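The re-trace hypothesis is easy to check in isolation with a trace counter. This is a generic JAX sketch (not numpyro's test): jit re-traces whenever an argument's abstract value changes, and weak_type is part of that signature.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def step(x):
    global trace_count
    trace_count += 1  # executes only while tracing, not on cached calls
    return x * 2.0

step(1.0)            # Python scalar -> float32[] with weak_type=True
step(jnp.zeros(()))  # committed array -> float32[] with weak_type=False: re-trace
```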

zmbc (Contributor, author) commented Mar 8, 2025:

@fehiepsi You are correct and I'm glad I asked; that would have taken me a long time to figure out!

The difference in type is small (see comment):

(
    Traced<ShapedArray(int32[], weak_type=True)> with <DynamicJaxprTrace>,
    (
        Traced<ShapedArray(float32[10])> with <DynamicJaxprTrace>,
        (
            (
                ScaleByAdamState(
                    count=Traced<ShapedArray(int32[])> with <DynamicJaxprTrace>,
                    mu=Traced<ShapedArray(float32[10])> with <DynamicJaxprTrace>,
                    nu=Traced<ShapedArray(float32[10])> with <DynamicJaxprTrace>
                ),
                EmptyState()
            ),
            ReduceLROnPlateauState(
                scale=Traced<ShapedArray(float32[])> with <DynamicJaxprTrace>,
                best_value=Traced<ShapedArray(float32[], weak_type=True)> with <DynamicJaxprTrace>,
                plateau_count=Traced<ShapedArray(int32[])> with <DynamicJaxprTrace>,
                cooldown_count=Traced<ShapedArray(int32[])> with <DynamicJaxprTrace>,
                count=Traced<ShapedArray(int32[])> with <DynamicJaxprTrace>,
                avg_value=Traced<ShapedArray(float32[],
                  weak_type=True, # This is only present on the first call
                )> with <DynamicJaxprTrace>
            )
        )
    )
)

Does that seem correct to you? I'll update the test.

fehiepsi (Member) left a comment:

Could you add value=loss_val here, to make the tests pass?

optim_state = self.optim.update(grads, optim_state)
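In isolation, the requested change amounts to threading the loss from value_and_grad into the optimizer update. This is a toy sketch; update here is a hypothetical value-aware step, not numpyro's SVI code:

```python
import jax

def loss_fn(p):
    return (p - 3.0) ** 2

def update(grads, params, value=None):
    # a value-aware rule: take a smaller step while the loss is still large
    lr = 0.1 if value is not None and value > 1.0 else 0.2
    return params - lr * grads

params = 0.0
loss_val, grads = jax.value_and_grad(loss_fn)(params)
params = update(grads, params, value=loss_val)  # pass value= alongside grads
```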

fehiepsi (Member) commented:

There is a conflict. Could you merge with the master branch?

zmbc (Contributor, author) commented Mar 10, 2025:

Sorry about that, should be resolved now!

@fehiepsi fehiepsi merged commit bdb3329 into pyro-ppl:master Mar 10, 2025
@zmbc zmbc deleted the optax_pass_value_arg branch March 10, 2025 21:12

3 participants