Skip to content

Commit 7744cfe

Browse files
google-genai-bot authored and copybara-github committed
feat(mcp): gracefully handle tool execution errors and transport crashes
Previously, if an MCP tool returned a JSON-RPC error (e.g. 403 Forbidden) or if the underlying transport connection crashed, the resulting exceptions (McpError and ConnectionError) would bubble up and crash the entire ADK runner.

This change introduces robust error boundaries for MCP tools:

- `McpTool.run_async()` now catches `McpError` and general exceptions, returning them as structured error dictionaries `{"error": ...}` to the LLM agent so the conversation can continue gracefully.
- `SessionContext` races tool calls against the background session task so transport crashes surface immediately instead of hanging.
- Fixes an AnyIO cancellation scope bug ("Attempted to exit cancel scope in a different task") by removing redundant `asyncio.wait_for` wrappers around exit stack context entry.
- Connection errors trigger automatic retries via `@retry_on_errors` before finally surfacing the failure to the agent.

Fixes #4901, #4162

PiperOrigin-RevId: 902981801
1 parent 7ae83b2 commit 7744cfe

File tree

9 files changed

+227
-760
lines changed

9 files changed

+227
-760
lines changed

src/google/adk/tools/load_mcp_resource_tool.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,8 @@
2929
from .base_tool import BaseTool
3030

3131
if TYPE_CHECKING:
32-
from .mcp_tool.mcp_toolset import McpToolset
32+
from mcp_toolset import McpToolset
33+
3334
from .tool_context import ToolContext
3435

3536
logger = logging.getLogger("google_adk." + __name__)
@@ -38,7 +39,7 @@
3839
class LoadMcpResourceTool(BaseTool):
3940
"""A tool that loads the MCP resources and adds them to the session."""
4041

41-
def __init__(self, mcp_toolset: McpToolset) -> None:
42+
def __init__(self, mcp_toolset: McpToolset):
4243
super().__init__(
4344
name="load_mcp_resource",
4445
description="""Loads resources from the MCP server.

src/google/adk/tools/mcp_tool/__init__.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -17,19 +17,17 @@
1717
try:
1818
from .conversion_utils import adk_to_mcp_tool_type
1919
from .conversion_utils import gemini_to_json_schema
20-
from .mcp_session_manager import MCPSessionManager as MCPSessionManager
21-
from .mcp_session_manager import SseConnectionParams as SseConnectionParams
22-
from .mcp_session_manager import StdioConnectionParams as StdioConnectionParams
23-
from .mcp_session_manager import StreamableHTTPConnectionParams as StreamableHTTPConnectionParams
24-
from .mcp_tool import MCPTool as MCPTool
25-
from .mcp_tool import McpTool as McpTool
26-
from .mcp_toolset import MCPToolset as MCPToolset
27-
from .mcp_toolset import McpToolset as McpToolset
20+
from .mcp_session_manager import SseConnectionParams
21+
from .mcp_session_manager import StdioConnectionParams
22+
from .mcp_session_manager import StreamableHTTPConnectionParams
23+
from .mcp_tool import MCPTool
24+
from .mcp_tool import McpTool
25+
from .mcp_toolset import MCPToolset
26+
from .mcp_toolset import McpToolset
2827

2928
__all__.extend([
3029
'adk_to_mcp_tool_type',
3130
'gemini_to_json_schema',
32-
'MCPSessionManager',
3331
'McpTool',
3432
'MCPTool',
3533
'McpToolset',

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 49 additions & 102 deletions
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,7 @@
1616

1717
import asyncio
1818
from collections import deque
19-
import concurrent.futures
2019
from contextlib import AsyncExitStack
21-
from dataclasses import dataclass
2220
from datetime import timedelta
2321
import functools
2422
import hashlib
@@ -27,20 +25,13 @@
2725
import sys
2826
import threading
2927
from typing import Any
30-
from typing import Callable
31-
from typing import cast
3228
from typing import Dict
3329
from typing import Optional
3430
from typing import Protocol
3531
from typing import runtime_checkable
3632
from typing import TextIO
37-
from typing import TYPE_CHECKING
38-
from typing import TypeVar
3933
from typing import Union
4034

41-
if TYPE_CHECKING:
42-
from .session_context import SessionContext
43-
4435
from mcp import ClientSession
4536
from mcp import SamplingCapability
4637
from mcp import StdioServerParameters
@@ -53,6 +44,8 @@
5344
from pydantic import BaseModel
5445
from pydantic import ConfigDict
5546

47+
from .session_context import SessionContext
48+
5649
logger = logging.getLogger('google_adk.' + __name__)
5750

5851

@@ -153,10 +146,7 @@ class StreamableHTTPConnectionParams(BaseModel):
153146
httpx_client_factory: CheckableMcpHttpClientFactory = create_mcp_http_client
154147

155148

156-
_F = TypeVar('_F', bound=Callable[..., Any])
157-
158-
159-
def retry_on_errors(func: _F) -> _F:
149+
def retry_on_errors(func):
160150
"""Decorator to automatically retry action when MCP session errors occur.
161151
162152
When MCP session errors occur, the decorator will automatically retry the
@@ -175,7 +165,7 @@ def retry_on_errors(func: _F) -> _F:
175165
"""
176166

177167
@functools.wraps(func) # Preserves original function metadata
178-
async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
168+
async def wrapper(self, *args, **kwargs):
179169
try:
180170
return await func(self, *args, **kwargs)
181171
except Exception as e:
@@ -192,17 +182,7 @@ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
192182
logger.info('Retrying %s due to error: %s', func.__name__, e)
193183
return await func(self, *args, **kwargs)
194184

195-
return cast(_F, wrapper)
196-
197-
198-
@dataclass
199-
class _SessionEntry:
200-
"""A dataclass to hold session information."""
201-
202-
session: ClientSession
203-
exit_stack: AsyncExitStack
204-
loop: asyncio.AbstractEventLoop
205-
context: SessionContext
185+
return wrapper
206186

207187

208188
class MCPSessionManager:
@@ -225,7 +205,7 @@ def __init__(
225205
*,
226206
sampling_callback: Optional[SamplingFnT] = None,
227207
sampling_capabilities: Optional[SamplingCapability] = None,
228-
) -> None:
208+
):
229209
"""Initializes the MCP session manager.
230210
231211
Args:
@@ -257,8 +237,10 @@ def __init__(
257237
self._connection_params = connection_params
258238
self._errlog = errlog
259239

260-
# Session pool: maps session keys to _SessionEntry objects
261-
self._sessions: Dict[str, _SessionEntry] = {}
240+
# Session pool: maps session keys to (session, exit_stack, loop) tuples
241+
self._sessions: Dict[
242+
str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop]
243+
] = {}
262244

263245
# Map of event loops to their respective locks to prevent race conditions
264246
# across different event loops in session creation.
@@ -330,66 +312,35 @@ def _merge_headers(
330312

331313
return base_headers
332314

333-
def _is_session_disconnected(
334-
self,
335-
entry: _SessionEntry,
336-
) -> bool:
315+
def _is_session_disconnected(self, session: ClientSession) -> bool:
337316
"""Checks if a session is disconnected or closed.
338317
339318
Args:
340-
entry: The _SessionEntry to check.
319+
session: The ClientSession to check.
341320
342321
Returns:
343322
True if the session is disconnected, False otherwise.
344323
"""
345-
if (
346-
entry.session._read_stream._closed
347-
or entry.session._write_stream._closed
348-
):
349-
return True
350-
if entry.context is not None and not entry.context._is_task_alive: # pylint: disable=protected-access
351-
return True
352-
return False
353-
354-
def _get_session_context(
355-
self, headers: Optional[Dict[str, str]] = None
356-
) -> Optional['SessionContext']:
357-
"""Returns the SessionContext for the session matching the given headers.
358-
359-
Note: This method reads from the session pool without acquiring
360-
``_session_lock``. This is safe because it is called immediately after
361-
``create_session()`` (which populates the entry under the lock) within
362-
the same task, and dict reads are atomic in CPython.
363-
364-
Args:
365-
headers: Optional headers used to identify the session.
366-
367-
Returns:
368-
The SessionContext if a matching session exists, None otherwise.
369-
"""
370-
merged_headers = self._merge_headers(headers)
371-
session_key = self._generate_session_key(merged_headers)
372-
entry = self._sessions.get(session_key)
373-
if entry is not None:
374-
return entry.context
375-
return None
324+
return session._read_stream._closed or session._write_stream._closed
376325

377326
async def _cleanup_session(
378327
self,
379328
session_key: str,
380-
entry: _SessionEntry,
381-
) -> None:
329+
exit_stack: AsyncExitStack,
330+
stored_loop: asyncio.AbstractEventLoop,
331+
):
382332
"""Cleans up a session, handling different event loops safely.
383333
384334
Args:
385335
session_key: The session key to clean up.
386-
entry: The _SessionEntry managing the session resources.
336+
exit_stack: The AsyncExitStack managing the session resources.
337+
stored_loop: The event loop on which the session was created.
387338
"""
388339
current_loop = asyncio.get_running_loop()
389340
try:
390-
if entry.loop is current_loop:
391-
await entry.exit_stack.aclose()
392-
elif entry.loop.is_closed():
341+
if stored_loop is current_loop:
342+
await exit_stack.aclose()
343+
elif stored_loop.is_closed():
393344
logger.warning(
394345
f'Error cleaning up session {session_key}: original event loop'
395346
' is closed, resources may be leaked.'
@@ -402,11 +353,11 @@ async def _cleanup_session(
402353
' event loop.'
403354
)
404355
future = asyncio.run_coroutine_threadsafe(
405-
entry.exit_stack.aclose(), entry.loop
356+
exit_stack.aclose(), stored_loop
406357
)
407358

408359
# Attach a callback so errors don't go unnoticed
409-
def cleanup_done(f: 'concurrent.futures.Future[Any]') -> None:
360+
def cleanup_done(f: asyncio.Future):
410361
try:
411362
if f.exception():
412363
logger.warning(
@@ -428,9 +379,7 @@ def cleanup_done(f: 'concurrent.futures.Future[Any]') -> None:
428379
if session_key in self._sessions:
429380
del self._sessions[session_key]
430381

431-
def _create_client(
432-
self, merged_headers: Optional[Dict[str, str]] = None
433-
) -> Any:
382+
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
434383
"""Creates an MCP client based on the connection parameters.
435384
436385
Args:
@@ -502,22 +451,22 @@ async def create_session(
502451
async with self._session_lock:
503452
# Check if we have an existing session
504453
if session_key in self._sessions:
505-
entry = self._sessions[session_key]
454+
session, exit_stack, stored_loop = self._sessions[session_key]
506455

507456
# Check if the existing session is still connected and bound to the current loop
508457
current_loop = asyncio.get_running_loop()
509-
if entry.loop is current_loop and not self._is_session_disconnected(
510-
entry
458+
if stored_loop is current_loop and not self._is_session_disconnected(
459+
session
511460
):
512461
# Session is still good, return it
513-
return entry.session
462+
return session
514463
else:
515464
# Session is disconnected or from a different loop, clean it up
516465
logger.info(
517466
'Cleaning up session (disconnected or different loop): %s',
518467
session_key,
519468
)
520-
await self._cleanup_session(session_key, entry)
469+
await self._cleanup_session(session_key, exit_stack, stored_loop)
521470

522471
# Create a new session (either first time or replacing disconnected one)
523472
exit_stack = AsyncExitStack()
@@ -533,30 +482,28 @@ async def create_session(
533482
)
534483

535484
try:
536-
from .session_context import SessionContext
537-
538485
client = self._create_client(merged_headers)
539486
is_stdio = isinstance(self._connection_params, StdioConnectionParams)
540487

541-
session_context = SessionContext(
542-
client=client,
543-
timeout=timeout_in_seconds,
544-
sse_read_timeout=sse_read_timeout_in_seconds,
545-
is_stdio=is_stdio,
546-
sampling_callback=self._sampling_callback,
547-
sampling_capabilities=self._sampling_capabilities,
548-
)
549488
session = await asyncio.wait_for(
550-
exit_stack.enter_async_context(session_context),
489+
exit_stack.enter_async_context(
490+
SessionContext(
491+
client=client,
492+
timeout=timeout_in_seconds,
493+
sse_read_timeout=sse_read_timeout_in_seconds,
494+
is_stdio=is_stdio,
495+
sampling_callback=self._sampling_callback,
496+
sampling_capabilities=self._sampling_capabilities,
497+
)
498+
),
551499
timeout=timeout_in_seconds,
552500
)
553501

554-
# Store session, exit stack, loop, and context in the pool
555-
self._sessions[session_key] = _SessionEntry(
556-
session=session,
557-
exit_stack=exit_stack,
558-
loop=asyncio.get_running_loop(),
559-
context=session_context,
502+
# Store session, exit stack, and loop in the pool
503+
self._sessions[session_key] = (
504+
session,
505+
exit_stack,
506+
asyncio.get_running_loop(),
560507
)
561508
logger.debug('Created new session: %s', session_key)
562509
return session
@@ -572,7 +519,7 @@ async def create_session(
572519
)
573520
raise ConnectionError(f'Failed to create MCP session: {e}') from e
574521

575-
def __getstate__(self) -> Dict[str, Any]:
522+
def __getstate__(self):
576523
"""Custom pickling to exclude non-picklable runtime objects."""
577524
state = self.__dict__.copy()
578525
# Remove unpicklable entries or those that shouldn't persist across pickle
@@ -585,7 +532,7 @@ def __getstate__(self) -> Dict[str, Any]:
585532

586533
return state
587534

588-
def __setstate__(self, state: Dict[str, Any]) -> None:
535+
def __setstate__(self, state):
589536
"""Custom unpickling to restore state."""
590537
self.__dict__.update(state)
591538
# Re-initialize members that were not pickled
@@ -596,12 +543,12 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
596543
if not hasattr(self, '_errlog') or self._errlog is None:
597544
self._errlog = sys.stderr
598545

599-
async def close(self) -> None:
546+
async def close(self):
600547
"""Closes all sessions and cleans up resources."""
601548
async with self._session_lock:
602549
for session_key in list(self._sessions.keys()):
603-
entry = self._sessions[session_key]
604-
await self._cleanup_session(session_key, entry)
550+
_, exit_stack, stored_loop = self._sessions[session_key]
551+
await self._cleanup_session(session_key, exit_stack, stored_loop)
605552

606553

607554
SseServerParams = SseConnectionParams

0 commit comments

Comments
 (0)