Skip to content

Commit 9a19304

Browse files
GWealecopybara-github
authored andcommitted
fix: preserve interaction ids for interactions SSE tool calls
Close #5169 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 897334695
1 parent 5fab983 commit 9a19304

File tree

2 files changed

+236
-3
lines changed

2 files changed

+236
-3
lines changed

src/google/adk/models/interactions_utils.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,36 @@
5656
_NEW_LINE = '\n'
5757

5858

59+
def _extract_stream_interaction_id(
60+
event: 'InteractionSSEEvent',
61+
) -> Optional[str]:
62+
"""Extract the interaction ID from an Interactions SSE event.
63+
64+
Different SSE lifecycle events expose the interaction ID on different
65+
attributes. We normalize them here so streamed ADK responses consistently
66+
carry the chain identifier needed for follow-up tool calls. Older
67+
google-genai builds may also yield a legacy ``interaction`` event with a
68+
top-level ``id``.
69+
"""
70+
from google.genai._interactions.types.interaction_complete_event import InteractionCompleteEvent
71+
from google.genai._interactions.types.interaction_start_event import InteractionStartEvent
72+
from google.genai._interactions.types.interaction_status_update import InteractionStatusUpdate
73+
74+
if isinstance(event, InteractionStatusUpdate):
75+
return event.interaction_id
76+
77+
if isinstance(event, (InteractionStartEvent, InteractionCompleteEvent)):
78+
return event.interaction.id
79+
80+
try:
81+
if event.event_type == 'interaction':
82+
return event.id
83+
except AttributeError:
84+
pass
85+
86+
return None
87+
88+
5989
def convert_part_to_interaction_content(part: types.Part) -> Optional[dict]:
6090
"""Convert a types.Part to an interaction content dict.
6191
@@ -1013,9 +1043,9 @@ async def generate_content_via_interactions(
10131043
# Log the streaming event
10141044
logger.debug(build_interactions_event_log(event))
10151045

1016-
# Extract interaction ID from event if available
1017-
if hasattr(event, 'id') and event.id:
1018-
current_interaction_id = event.id
1046+
interaction_id = _extract_stream_interaction_id(event)
1047+
if interaction_id:
1048+
current_interaction_id = interaction_id
10191049
llm_response = convert_interaction_event_to_llm_response(
10201050
event, aggregated_parts, current_interaction_id
10211051
)

tests/unittests/models/test_interactions_utils.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,183 @@
1414

1515
"""Tests for interactions_utils.py conversion functions."""
1616

17+
import asyncio
1718
import base64
19+
from collections.abc import Callable
20+
from datetime import datetime
21+
from datetime import timezone
1822
import json
23+
from types import SimpleNamespace
1924
from unittest.mock import MagicMock
2025

2126
from google.adk.models import interactions_utils
2227
from google.adk.models.llm_request import LlmRequest
2328
from google.genai import types
29+
from google.genai._interactions.types.interaction import Interaction
30+
from google.genai._interactions.types.interaction_complete_event import InteractionCompleteEvent
31+
from google.genai._interactions.types.interaction_start_event import InteractionStartEvent
32+
from google.genai._interactions.types.interaction_status_update import InteractionStatusUpdate
33+
import pytest
34+
35+
36+
class _MockAsyncIterator:
37+
"""Simple async iterator for streaming interaction events."""
38+
39+
def __init__(self, sequence: list[object]):
40+
self._iterator = iter(sequence)
41+
42+
def __aiter__(self):
43+
return self
44+
45+
async def __anext__(self):
46+
try:
47+
return next(self._iterator)
48+
except StopIteration as exc:
49+
raise StopAsyncIteration from exc
50+
51+
52+
class _FakeInteractions:
53+
"""Minimal fake interactions resource for streaming tests."""
54+
55+
def __init__(self, events: list[object]):
56+
self._events = events
57+
58+
async def create(self, **_kwargs):
59+
return _MockAsyncIterator(self._events)
60+
61+
62+
class _FakeAio:
63+
"""Namespace matching the expected api_client.aio shape."""
64+
65+
def __init__(self, events: list[object]):
66+
self.interactions = _FakeInteractions(events)
67+
68+
69+
class _FakeApiClient:
70+
"""Minimal fake API client for generate_content_via_interactions tests."""
71+
72+
def __init__(self, events: list[object]):
73+
self.aio = _FakeAio(events)
74+
75+
76+
def _build_function_call_delta_event(
77+
*, function_id: str, name: str, arguments: dict[str, object]
78+
) -> SimpleNamespace:
79+
"""Build a version-agnostic content.delta event for a function call."""
80+
return SimpleNamespace(
81+
event_type='content.delta',
82+
delta=SimpleNamespace(
83+
type='function_call',
84+
id=function_id,
85+
name=name,
86+
arguments=arguments,
87+
),
88+
)
89+
90+
91+
def _build_llm_request() -> LlmRequest:
92+
"""Build a minimal request for interactions streaming tests."""
93+
return LlmRequest(
94+
model='gemini-2.5-flash',
95+
contents=[
96+
types.Content(
97+
role='user',
98+
parts=[types.Part(text='Weather in Tokyo?')],
99+
)
100+
],
101+
config=types.GenerateContentConfig(),
102+
)
103+
104+
105+
def _build_lifecycle_streamed_events() -> list[object]:
106+
"""Build streamed events with lifecycle updates carrying the ID."""
107+
now = datetime.now(timezone.utc)
108+
return [
109+
InteractionStartEvent(
110+
event_type='interaction.start',
111+
interaction=Interaction(
112+
id='interaction_123',
113+
created=now,
114+
updated=now,
115+
status='in_progress',
116+
),
117+
),
118+
_build_function_call_delta_event(
119+
function_id='call_1',
120+
name='get_weather',
121+
arguments={'city': 'Tokyo'},
122+
),
123+
InteractionStatusUpdate(
124+
event_type='interaction.status_update',
125+
interaction_id='interaction_123',
126+
status='requires_action',
127+
),
128+
]
129+
130+
131+
def _build_complete_streamed_events() -> list[object]:
132+
"""Build streamed events with the ID on an interaction.complete event."""
133+
now = datetime.now(timezone.utc)
134+
return [
135+
_build_function_call_delta_event(
136+
function_id='call_1',
137+
name='get_weather',
138+
arguments={'city': 'Tokyo'},
139+
),
140+
InteractionCompleteEvent(
141+
event_type='interaction.complete',
142+
interaction=Interaction(
143+
id='interaction_complete_123',
144+
created=now,
145+
updated=now,
146+
status='requires_action',
147+
),
148+
),
149+
]
150+
151+
152+
def _build_legacy_streamed_events() -> list[object]:
153+
"""Build streamed events with the ID on the legacy interaction event."""
154+
return [
155+
_build_function_call_delta_event(
156+
function_id='call_1',
157+
name='get_weather',
158+
arguments={'city': 'Tokyo'},
159+
),
160+
SimpleNamespace(
161+
event_type='interaction',
162+
id='interaction_legacy_123',
163+
status='requires_action',
164+
error=None,
165+
outputs=None,
166+
usage=None,
167+
),
168+
]
169+
170+
171+
async def _collect_function_call_interaction_ids(
172+
streamed_events: list[object],
173+
) -> list[str | None]:
174+
"""Collect non-partial function call interaction IDs from streamed events."""
175+
responses = [
176+
response
177+
async for response in (
178+
interactions_utils.generate_content_via_interactions(
179+
api_client=_FakeApiClient(streamed_events),
180+
llm_request=_build_llm_request(),
181+
stream=True,
182+
)
183+
)
184+
]
185+
186+
return [
187+
response.interaction_id
188+
for response in responses
189+
if response.partial is not True
190+
and response.content is not None
191+
and response.content.parts
192+
and response.content.parts[0].function_call is not None
193+
]
24194

25195

26196
class TestConvertPartToInteractionContent:
@@ -955,3 +1125,36 @@ def test_unknown_event_type_returns_none(self):
9551125

9561126
assert result is None
9571127
assert not aggregated_parts
1128+
1129+
1130+
@pytest.mark.parametrize(
1131+
('streamed_events_factory', 'expected_ids'),
1132+
[
1133+
pytest.param(
1134+
_build_lifecycle_streamed_events,
1135+
['interaction_123', 'interaction_123'],
1136+
id='lifecycle-events',
1137+
),
1138+
pytest.param(
1139+
_build_complete_streamed_events,
1140+
['interaction_complete_123'],
1141+
id='complete-event',
1142+
),
1143+
pytest.param(
1144+
_build_legacy_streamed_events,
1145+
['interaction_legacy_123'],
1146+
id='legacy-event',
1147+
),
1148+
],
1149+
)
1150+
def test_generate_content_via_interactions_stream_extracts_interaction_id(
1151+
streamed_events_factory: Callable[[], list[object]],
1152+
expected_ids: list[str],
1153+
):
1154+
"""Streamed interaction IDs should be preserved across event variants."""
1155+
streamed_events = streamed_events_factory()
1156+
1157+
assert (
1158+
asyncio.run(_collect_function_call_interaction_ids(streamed_events))
1159+
== expected_ids
1160+
)

0 commit comments

Comments
 (0)