Skip to content

Commit 0ef3d8c

Browse files
committed
fix(auth): persist refreshed OAuth2 credentials to store
After `OAuth2CredentialRefresher.refresh()` rotates the tokens in memory, the updated credential was never written back to the credential store. On the next tool invocation, `get_credential()` deserialized the stale pre-refresh dict and returned expired tokens. For providers that rotate refresh_tokens on each refresh (Salesforce, many OIDC providers), this caused the subsequent refresh attempt to fail — the old refresh_token was already invalidated — forcing a full re-authorization flow. The fix calls `_store_credential()` immediately after a successful refresh so that the new access_token and refresh_token are persisted. Fixes #5329
1 parent ce113a8 commit 0ef3d8c

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,12 @@ async def _get_existing_credential(
242242
existing_credential = await refresher.refresh(
243243
existing_credential, self.auth_scheme
244244
)
245+
# Persist the refreshed credential so the next invocation
246+
# reads the new tokens instead of the stale pre-refresh ones.
247+
# Without this, providers that rotate refresh_tokens on each
248+
# refresh (e.g. Salesforce, many OIDC providers) will fail
249+
# because the old refresh_token has already been invalidated.
250+
self._store_credential(existing_credential)
245251
return existing_credential
246252
return None
247253

tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,75 @@ async def test_openid_connect_existing_oauth2_token_refresh(
292292
assert result.state == 'done'
293293
# The result should contain the refreshed credential after exchange
294294
assert result.auth_credential is not None
295+
296+
297+
@patch(
298+
'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialRefresher'
299+
)
300+
@pytest.mark.asyncio
301+
async def test_refreshed_credential_is_persisted_to_store(
302+
mock_oauth2_refresher, openid_connect_scheme, openid_connect_credential
303+
):
304+
"""Test that refreshed OAuth2 credentials are persisted back to the store.
305+
306+
Regression test for https://github.com/google/adk-python/issues/5329.
307+
Without persisting, the next invocation reads stale pre-refresh tokens from
308+
state. Providers that rotate refresh_tokens on each refresh (e.g.
309+
Salesforce, many OIDC providers) will then fail because the old
310+
refresh_token has already been invalidated.
311+
"""
312+
# Create existing OAuth2 credential with an "old" refresh token.
313+
existing_credential = AuthCredential(
314+
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
315+
oauth2=OAuth2Auth(
316+
client_id='test_client_id',
317+
client_secret='test_client_secret',
318+
access_token='old_access_token',
319+
refresh_token='old_refresh_token',
320+
),
321+
)
322+
323+
# The refresher will return a credential with rotated tokens.
324+
refreshed_credential = AuthCredential(
325+
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
326+
oauth2=OAuth2Auth(
327+
client_id='test_client_id',
328+
client_secret='test_client_secret',
329+
access_token='new_access_token',
330+
refresh_token='new_refresh_token',
331+
),
332+
)
333+
334+
from unittest.mock import AsyncMock
335+
336+
mock_refresher_instance = MagicMock()
337+
mock_refresher_instance.is_refresh_needed = AsyncMock(return_value=True)
338+
mock_refresher_instance.refresh = AsyncMock(return_value=refreshed_credential)
339+
mock_oauth2_refresher.return_value = mock_refresher_instance
340+
341+
tool_context = create_mock_tool_context()
342+
credential_store = ToolContextCredentialStore(tool_context=tool_context)
343+
344+
# Store the existing (stale) credential.
345+
key = credential_store.get_credential_key(
346+
openid_connect_scheme, openid_connect_credential
347+
)
348+
credential_store.store_credential(key, existing_credential)
349+
350+
handler = ToolAuthHandler(
351+
tool_context,
352+
openid_connect_scheme,
353+
openid_connect_credential,
354+
credential_store=credential_store,
355+
)
356+
357+
await handler.prepare_auth_credentials()
358+
359+
# The critical assertion: the *refreshed* credential must now be in the
360+
# store so that the next invocation reads the new tokens, not the old ones.
361+
persisted = credential_store.get_credential(
362+
openid_connect_scheme, openid_connect_credential
363+
)
364+
assert persisted is not None
365+
assert persisted.oauth2.access_token == 'new_access_token'
366+
assert persisted.oauth2.refresh_token == 'new_refresh_token'

0 commit comments

Comments
 (0)