1616
1717import asyncio
1818from collections import deque
19- import concurrent .futures
2019from contextlib import AsyncExitStack
21- from dataclasses import dataclass
2220from datetime import timedelta
2321import functools
2422import hashlib
2725import sys
2826import threading
2927from typing import Any
30- from typing import Callable
31- from typing import cast
3228from typing import Dict
3329from typing import Optional
3430from typing import Protocol
3531from typing import runtime_checkable
3632from typing import TextIO
37- from typing import TYPE_CHECKING
38- from typing import TypeVar
3933from typing import Union
4034
41- if TYPE_CHECKING :
42- from .session_context import SessionContext
43-
4435from mcp import ClientSession
4536from mcp import SamplingCapability
4637from mcp import StdioServerParameters
5344from pydantic import BaseModel
5445from pydantic import ConfigDict
5546
47+ from .session_context import SessionContext
48+
5649logger = logging .getLogger ('google_adk.' + __name__ )
5750
5851
@@ -153,10 +146,7 @@ class StreamableHTTPConnectionParams(BaseModel):
153146 httpx_client_factory : CheckableMcpHttpClientFactory = create_mcp_http_client
154147
155148
156- _F = TypeVar ('_F' , bound = Callable [..., Any ])
157-
158-
159- def retry_on_errors (func : _F ) -> _F :
149+ def retry_on_errors (func ):
160150 """Decorator to automatically retry action when MCP session errors occur.
161151
162152 When MCP session errors occur, the decorator will automatically retry the
@@ -175,7 +165,7 @@ def retry_on_errors(func: _F) -> _F:
175165 """
176166
177167 @functools .wraps (func ) # Preserves original function metadata
178- async def wrapper (self : Any , * args : Any , ** kwargs : Any ) -> Any :
168+ async def wrapper (self , * args , ** kwargs ) :
179169 try :
180170 return await func (self , * args , ** kwargs )
181171 except Exception as e :
@@ -192,17 +182,7 @@ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
192182 logger .info ('Retrying %s due to error: %s' , func .__name__ , e )
193183 return await func (self , * args , ** kwargs )
194184
195- return cast (_F , wrapper )
196-
197-
198- @dataclass
199- class _SessionEntry :
200- """A dataclass to hold session information."""
201-
202- session : ClientSession
203- exit_stack : AsyncExitStack
204- loop : asyncio .AbstractEventLoop
205- context : SessionContext
185+ return wrapper
206186
207187
208188class MCPSessionManager :
@@ -225,7 +205,7 @@ def __init__(
225205 * ,
226206 sampling_callback : Optional [SamplingFnT ] = None ,
227207 sampling_capabilities : Optional [SamplingCapability ] = None ,
228- ) -> None :
208+ ):
229209 """Initializes the MCP session manager.
230210
231211 Args:
@@ -257,8 +237,10 @@ def __init__(
257237 self ._connection_params = connection_params
258238 self ._errlog = errlog
259239
260- # Session pool: maps session keys to _SessionEntry objects
261- self ._sessions : Dict [str , _SessionEntry ] = {}
240+ # Session pool: maps session keys to (session, exit_stack, loop) tuples
241+ self ._sessions : Dict [
242+ str , tuple [ClientSession , AsyncExitStack , asyncio .AbstractEventLoop ]
243+ ] = {}
262244
263245 # Map of event loops to their respective locks to prevent race conditions
264246 # across different event loops in session creation.
@@ -330,66 +312,35 @@ def _merge_headers(
330312
331313 return base_headers
332314
333- def _is_session_disconnected (
334- self ,
335- entry : _SessionEntry ,
336- ) -> bool :
315+ def _is_session_disconnected (self , session : ClientSession ) -> bool :
337316 """Checks if a session is disconnected or closed.
338317
339318 Args:
340- entry : The _SessionEntry to check.
319+ session : The ClientSession to check.
341320
342321 Returns:
343322 True if the session is disconnected, False otherwise.
344323 """
345- if (
346- entry .session ._read_stream ._closed
347- or entry .session ._write_stream ._closed
348- ):
349- return True
350- if entry .context is not None and not entry .context ._is_task_alive : # pylint: disable=protected-access
351- return True
352- return False
353-
354- def _get_session_context (
355- self , headers : Optional [Dict [str , str ]] = None
356- ) -> Optional ['SessionContext' ]:
357- """Returns the SessionContext for the session matching the given headers.
358-
359- Note: This method reads from the session pool without acquiring
360- ``_session_lock``. This is safe because it is called immediately after
361- ``create_session()`` (which populates the entry under the lock) within
362- the same task, and dict reads are atomic in CPython.
363-
364- Args:
365- headers: Optional headers used to identify the session.
366-
367- Returns:
368- The SessionContext if a matching session exists, None otherwise.
369- """
370- merged_headers = self ._merge_headers (headers )
371- session_key = self ._generate_session_key (merged_headers )
372- entry = self ._sessions .get (session_key )
373- if entry is not None :
374- return entry .context
375- return None
324+ return session ._read_stream ._closed or session ._write_stream ._closed
376325
377326 async def _cleanup_session (
378327 self ,
379328 session_key : str ,
380- entry : _SessionEntry ,
381- ) -> None :
329+ exit_stack : AsyncExitStack ,
330+ stored_loop : asyncio .AbstractEventLoop ,
331+ ):
382332 """Cleans up a session, handling different event loops safely.
383333
384334 Args:
385335 session_key: The session key to clean up.
386- entry: The _SessionEntry managing the session resources.
336+ exit_stack: The AsyncExitStack managing the session resources.
337+ stored_loop: The event loop on which the session was created.
387338 """
388339 current_loop = asyncio .get_running_loop ()
389340 try :
390- if entry . loop is current_loop :
391- await entry . exit_stack .aclose ()
392- elif entry . loop .is_closed ():
341+ if stored_loop is current_loop :
342+ await exit_stack .aclose ()
343+ elif stored_loop .is_closed ():
393344 logger .warning (
394345 f'Error cleaning up session { session_key } : original event loop'
395346 ' is closed, resources may be leaked.'
@@ -402,11 +353,11 @@ async def _cleanup_session(
402353 ' event loop.'
403354 )
404355 future = asyncio .run_coroutine_threadsafe (
405- entry . exit_stack .aclose (), entry . loop
356+ exit_stack .aclose (), stored_loop
406357 )
407358
408359 # Attach a callback so errors don't go unnoticed
409- def cleanup_done (f : 'concurrent.futures. Future[Any]' ) -> None :
360+ def cleanup_done (f : asyncio . Future ) :
410361 try :
411362 if f .exception ():
412363 logger .warning (
@@ -428,9 +379,7 @@ def cleanup_done(f: 'concurrent.futures.Future[Any]') -> None:
428379 if session_key in self ._sessions :
429380 del self ._sessions [session_key ]
430381
431- def _create_client (
432- self , merged_headers : Optional [Dict [str , str ]] = None
433- ) -> Any :
382+ def _create_client (self , merged_headers : Optional [Dict [str , str ]] = None ):
434383 """Creates an MCP client based on the connection parameters.
435384
436385 Args:
@@ -502,22 +451,22 @@ async def create_session(
502451 async with self ._session_lock :
503452 # Check if we have an existing session
504453 if session_key in self ._sessions :
505- entry = self ._sessions [session_key ]
454+ session , exit_stack , stored_loop = self ._sessions [session_key ]
506455
507456 # Check if the existing session is still connected and bound to the current loop
508457 current_loop = asyncio .get_running_loop ()
509- if entry . loop is current_loop and not self ._is_session_disconnected (
510- entry
458+ if stored_loop is current_loop and not self ._is_session_disconnected (
459+ session
511460 ):
512461 # Session is still good, return it
513- return entry . session
462+ return session
514463 else :
515464 # Session is disconnected or from a different loop, clean it up
516465 logger .info (
517466 'Cleaning up session (disconnected or different loop): %s' ,
518467 session_key ,
519468 )
520- await self ._cleanup_session (session_key , entry )
469+ await self ._cleanup_session (session_key , exit_stack , stored_loop )
521470
522471 # Create a new session (either first time or replacing disconnected one)
523472 exit_stack = AsyncExitStack ()
@@ -533,30 +482,28 @@ async def create_session(
533482 )
534483
535484 try :
536- from .session_context import SessionContext
537-
538485 client = self ._create_client (merged_headers )
539486 is_stdio = isinstance (self ._connection_params , StdioConnectionParams )
540487
541- session_context = SessionContext (
542- client = client ,
543- timeout = timeout_in_seconds ,
544- sse_read_timeout = sse_read_timeout_in_seconds ,
545- is_stdio = is_stdio ,
546- sampling_callback = self ._sampling_callback ,
547- sampling_capabilities = self ._sampling_capabilities ,
548- )
549488 session = await asyncio .wait_for (
550- exit_stack .enter_async_context (session_context ),
489+ exit_stack .enter_async_context (
490+ SessionContext (
491+ client = client ,
492+ timeout = timeout_in_seconds ,
493+ sse_read_timeout = sse_read_timeout_in_seconds ,
494+ is_stdio = is_stdio ,
495+ sampling_callback = self ._sampling_callback ,
496+ sampling_capabilities = self ._sampling_capabilities ,
497+ )
498+ ),
551499 timeout = timeout_in_seconds ,
552500 )
553501
554- # Store session, exit stack, loop, and context in the pool
555- self ._sessions [session_key ] = _SessionEntry (
556- session = session ,
557- exit_stack = exit_stack ,
558- loop = asyncio .get_running_loop (),
559- context = session_context ,
502+ # Store session, exit stack, and loop in the pool
503+ self ._sessions [session_key ] = (
504+ session ,
505+ exit_stack ,
506+ asyncio .get_running_loop (),
560507 )
561508 logger .debug ('Created new session: %s' , session_key )
562509 return session
@@ -572,7 +519,7 @@ async def create_session(
572519 )
573520 raise ConnectionError (f'Failed to create MCP session: { e } ' ) from e
574521
575- def __getstate__ (self ) -> Dict [ str , Any ] :
522+ def __getstate__ (self ):
576523 """Custom pickling to exclude non-picklable runtime objects."""
577524 state = self .__dict__ .copy ()
578525 # Remove unpicklable entries or those that shouldn't persist across pickle
@@ -585,7 +532,7 @@ def __getstate__(self) -> Dict[str, Any]:
585532
586533 return state
587534
588- def __setstate__ (self , state : Dict [ str , Any ]) -> None :
535+ def __setstate__ (self , state ) :
589536 """Custom unpickling to restore state."""
590537 self .__dict__ .update (state )
591538 # Re-initialize members that were not pickled
@@ -596,12 +543,12 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
596543 if not hasattr (self , '_errlog' ) or self ._errlog is None :
597544 self ._errlog = sys .stderr
598545
599- async def close (self ) -> None :
546+ async def close (self ):
600547 """Closes all sessions and cleans up resources."""
601548 async with self ._session_lock :
602549 for session_key in list (self ._sessions .keys ()):
603- entry = self ._sessions [session_key ]
604- await self ._cleanup_session (session_key , entry )
550+ _ , exit_stack , stored_loop = self ._sessions [session_key ]
551+ await self ._cleanup_session (session_key , exit_stack , stored_loop )
605552
606553
607554SseServerParams = SseConnectionParams
0 commit comments