Pass value arg to optax, allowing use of reduce_on_plateau #1974
fehiepsi merged 8 commits into pyro-ppl:master
Conversation
@fehiepsi Let me know what you think! I haven't tested this code, and I might need some help with adding types as well.
| """ | ||
| i, opt_state = state | ||
| opt_state = self.update_fn(i, g, opt_state) | ||
| opt_state = self.update_fn(i, g, opt_state, value=value) |
You can introduce an attribute update_with_value (False by default) to control the behavior here:

    if self.update_with_value:
        opt_state = self.update_fn(i, g, opt_state, value)
    else:
        opt_state = self.update_fn(i, g, opt_state)
You can then use the typing:

    self.update_fn: Union[
        Callable[[ArrayLike, _Params, _OptState], _OptState],
        Callable[[ArrayLike, _Params, _OptState, ArrayLike], _OptState],
    ]
In the optax optimizer, you can set it to True:

    numpyro_optim = _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn)
    numpyro_optim.update_with_value = True
    return numpyro_optim
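For illustration, here is a minimal, self-contained sketch of that dispatch pattern. MiniOptim is a hypothetical stand-in for _NumPyroOptim, and the lambdas stand in for real optax update functions; only the update_with_value branching mirrors the suggestion above.

```python
# Simplified stand-in for the optimizer wrapper discussed above. It only
# illustrates the update_with_value dispatch; the real class wraps optax
# init/update/get_params triples.
from typing import Any, Callable, Optional, Tuple


class MiniOptim:
    # False by default, so existing update_fn signatures keep working.
    update_with_value: bool = False

    def __init__(self, update_fn: Callable[..., Any]) -> None:
        self.update_fn = update_fn

    def update(
        self, g: float, state: Tuple[int, Any], value: Optional[float] = None
    ) -> Tuple[int, Any]:
        i, opt_state = state
        if self.update_with_value:
            # Value-aware optimizers (e.g. reduce_on_plateau) need the loss.
            opt_state = self.update_fn(i, g, opt_state, value)
        else:
            opt_state = self.update_fn(i, g, opt_state)
        return i + 1, opt_state


# A plain SGD-style update ignores the loss value entirely
# (step size 0.25 keeps the arithmetic exact for the demo).
sgd = MiniOptim(lambda i, g, s: s - 0.25 * g)
print(sgd.update(2.0, (0, 1.0)))  # -> (1, 0.5)

# A value-aware update also receives the current loss.
aware = MiniOptim(lambda i, g, s, value: (s - 0.25 * g, value))
aware.update_with_value = True
print(aware.update(2.0, (0, 1.0), value=5.0))  # -> (1, (0.5, 5.0))
```

The attribute flag keeps the common path untouched: optimizers that never see a loss value are unaffected, and only the optax wrapper opts in.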
@fehiepsi I can't easily get this typing to pass mypy. I'd need to use some kind of custom TypeGuard to narrow this Union type down to the specific Callable signature at runtime before it is called. Do you want me to do that? Or would you rather I changed this to simply Callable[..., _OptState]?
Sorry, I missed the message. I think you can define a protocol for the purpose:

    class _UpdateFn(Protocol):
        def __call__(
            self, arr: ArrayLike, params: _Params, opt_state: _OptState, *args: ArrayLike
        ) -> _OptState: ...
It might be possible to reuse the optax protocol here, although the OptStates may not be compatible.
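As a self-contained sketch of that Protocol approach (the Any aliases stand in for numpyro's _Params/_OptState; whether this exact shape satisfies mypy in the real codebase is untested):

```python
# Sketch of the Protocol-based typing suggested above. The *args slot
# lets one structural type cover both update_fn call shapes: with and
# without a trailing loss value. _Params and _OptState are stand-ins.
from typing import Any, Protocol

_Params = Any
_OptState = Any


class _UpdateFn(Protocol):
    def __call__(
        self, i: int, g: _Params, opt_state: _OptState, *args: Any
    ) -> _OptState: ...


def plain_update(i: int, g: Any, opt_state: Any, *args: Any) -> Any:
    # A value-unaware update: any extra positional args are ignored.
    return opt_state


fn: _UpdateFn = plain_update
# Both call shapes are valid against the Protocol:
print(fn(0, [1.0], "state"))        # -> state
print(fn(0, [1.0], "state", 3.14))  # -> state
```

One caveat: a concrete update_fn must itself accept the trailing *args to be assignable to this Protocol, which may be one reason reusing optax's own protocol does not line up directly.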
@fehiepsi Thanks for your comments! Want to take another look?
fehiepsi left a comment:
Could you add a simple test for this?
@zmbc, to test the optimizer, you can simply add the new one to this line: numpyro/test/test_optimizers.py, line 29 (d393f5a).
@fehiepsi I think I've got this all working, except that one test fails. My understanding of JAX and JIT is pretty basic; any ideas why this might be happening?
I think the init state and the later states of the new optimizer have different dtypes. Could you double-check? If that is the case, you can assert <= 2 for that specific case.
@fehiepsi You are correct, and I'm glad I asked; that would have taken me a long time to figure out! The difference in type is small (see comment):

    (
      Traced<ShapedArray(int32[], weak_type=True)> with <DynamicJaxprTrace>,
      (
        Traced<ShapedArray(float32[10])> with <DynamicJaxprTrace>,
        (
          (
            ScaleByAdamState(
              count=Traced<ShapedArray(int32[])> with <DynamicJaxprTrace>,
              mu=Traced<ShapedArray(float32[10])> with <DynamicJaxprTrace>,
              nu=Traced<ShapedArray(float32[10])> with <DynamicJaxprTrace>
            ),
            EmptyState()
          ),
          ReduceLROnPlateauState(
            scale=Traced<ShapedArray(float32[])> with <DynamicJaxprTrace>,
            best_value=Traced<ShapedArray(float32[], weak_type=True)> with <DynamicJaxprTrace>,
            plateau_count=Traced<ShapedArray(int32[])> with <DynamicJaxprTrace>,
            cooldown_count=Traced<ShapedArray(int32[])> with <DynamicJaxprTrace>,
            count=Traced<ShapedArray(int32[])> with <DynamicJaxprTrace>,
            avg_value=Traced<ShapedArray(
              float32[],
              weak_type=True,  # This is only present on the first call
            )> with <DynamicJaxprTrace>
          )
        )
      )
    )

Does that seem correct to you? I'll update the test.
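The weak_type mismatch can be demonstrated in isolation. The sketch below (a hedged illustration, not numpyro code; assumes a recent JAX) shows that an array built from a Python scalar carries weak_type=True in its abstract value while one built with an explicit dtype does not, and that jit retraces when the abstract values differ, which is why a compile-count test would see 2 instead of 1.

```python
# Hedged sketch of the extra compilation: weak_type is part of the
# abstract value jit uses as its cache key, so a weak_type init state
# and a strong-typed later state trigger two traces.
import jax
import jax.numpy as jnp

init_leaf = jnp.asarray(0.0)             # from a Python float: weak_type=True
later_leaf = jnp.zeros((), jnp.float32)  # explicit dtype: weak_type=False

print(init_leaf.aval.weak_type, later_leaf.aval.weak_type)  # True False

n_traces = []

@jax.jit
def step(x):
    n_traces.append(1)  # side effect runs only when jit (re)traces
    return x + 1.0

step(init_leaf)
step(init_leaf)   # same abstract value: cached, no retrace
step(later_leaf)  # weak_type differs -> a second trace
print(len(n_traces))  # -> 2
```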
fehiepsi left a comment:
Could you add value=loss_val here: numpyro/numpyro/contrib/einstein/steinvi.py, line 470 (374cd89).
There is a conflict. Could you merge with the master branch?
Sorry about that, should be resolved now!
Addresses #1955 using the approach from this comment: #1955 (comment)