1111
1212import asyncio
1313import contextvars
14- import dis
1514import inspect
1615import logging
1716import socket
2625
2726try :
2827 import sniffio # type: ignore[import-not-found]
29- except ImportError :
28+ except ImportError : # pragma: no cover
3029 sniffio = None # type: ignore[assignment]
3130
3231try :
3332 import trio # type: ignore[import-not-found]
34- from trio .lowlevel import current_trio_token # type: ignore[import-not-found]
35- except ImportError :
33+ except ImportError : # pragma: no cover
3634 trio = None # type: ignore[assignment]
37- current_trio_token = None # type: ignore[assignment]
3835
3936try :
4037 import curio # type: ignore[import-not-found,import-untyped]
@@ -223,7 +220,6 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
223220 Returns:
224221 The result of the function.
225222 """
226- # Check if f is a partial wrapping an async function
227223 func_to_check = f .func if isinstance (f , partial ) else f
228224 is_async = inspect .iscoroutinefunction (func_to_check )
229225 signature = inspect .signature (f )
@@ -252,8 +248,6 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
252248 # First, we inspect the keyword arguments and try and pass in some arguments
253249 # by currying them in.
254250 for param in signature .parameters .values ():
255- # Try matching the parameter name, or if it starts with underscore,
256- # also try matching without the leading underscore.
257251 arg_key = None
258252 if param .name in args :
259253 arg_key = param .name
@@ -320,25 +314,20 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
320314 },
321315 )
322316
323- try :
324- if is_async :
325- result = f ()
326- if inspect .iscoroutine (result ):
327- return _run_async_coroutine (result )
328- return result
329- return f ()
330- except Exception :
331- logger .exception ("Error occurred while calling function %s" , f_name )
332- raise
333-
334-
335- def _run_async_coroutine (coro : Coroutine [Any , Any , _T ]) -> _T : # noqa: C901
317+ if is_async :
318+ result = f ()
319+ if inspect .iscoroutine (result ):
320+ return _run_async_coroutine (result )
321+ return result # pragma: no cover
322+ return f ()
323+
324+
325+ def _run_async_coroutine (coro : Coroutine [Any , Any , _T ]) -> _T :
336326 """
337327 Run a coroutine in an event loop.
338328
339- Detects the async runtime (asyncio, trio, or curio) and executes the
340- coroutine appropriately. Preserves ContextVars when creating a new event
341- loop, which is important when handlers are called from threads.
329+ Detects the current async runtime and runs the coroutine in it,
330+ preserving ContextVars across the dispatch.
342331
343332 Args:
344333 coro:
@@ -351,24 +340,13 @@ def _run_async_coroutine(coro: Coroutine[Any, Any, _T]) -> _T: # noqa: C901
351340 RuntimeError:
352341 If the detected runtime (trio or curio) is not installed.
353342 """
354- runtime = _detect_async_runtime_from_coroutine (coro )
343+ runtime = _detect_async_runtime (coro )
355344
356345 if runtime == "trio" :
357346 if trio is None :
358347 msg = "trio is not installed"
359348 raise RuntimeError (msg )
360349
361- if current_trio_token is not None :
362- try :
363- token = current_trio_token ()
364-
365- async def _run_with_token () -> _T :
366- return await coro
367-
368- return trio .from_thread .run_sync (_run_with_token , trio_token = token ) # type: ignore[return-value]
369- except RuntimeError :
370- pass
371-
372350 ctx = contextvars .copy_context ()
373351
374352 async def _run_trio () -> _T :
@@ -404,13 +382,17 @@ async def _run_curio() -> _T:
404382 return ctx .run (asyncio .run , coro ) # type: ignore[arg-type,return-value]
405383
406384
407- def _detect_async_runtime_from_coroutine (coro : Coroutine [Any , Any , _T ]) -> str : # noqa: C901
385+ def _detect_async_runtime (coro : Coroutine [Any , Any , _T ]) -> str :
408386 """
409- Detect async runtime by inspecting the coroutine object.
387+ Detect the async runtime to use for a given coroutine.
388+
389+ When called from within a running async context, `sniffio` is used to
390+ identify the library. Otherwise the coroutine's `co_names` is inspected
391+ for `trio` or `curio` references.
410392
411393 Args:
412394 coro:
413- The coroutine object to inspect.
395+ The coroutine to inspect.
414396
415397 Returns:
416398 The detected runtime: "asyncio", "trio", or "curio".
@@ -421,47 +403,10 @@ def _detect_async_runtime_from_coroutine(coro: Coroutine[Any, Any, _T]) -> str:
421403 except sniffio .AsyncLibraryNotFoundError :
422404 pass
423405
424- # Inspect bytecode to check for qualified attribute access (e.g., trio.sleep)
425- # This is more robust than just checking co_names for module and method separately
426- func_code = coro .cr_code # type: ignore[attr-defined]
427-
428- # Parse bytecode to find LOAD_GLOBAL/LOAD_NAME followed by LOAD_ATTR patterns
429- # This detects qualified accesses like `trio.sleep()` or `curio.spawn()`
430- bytecode = list (dis .get_instructions (func_code ))
431-
432- trio_detected = False
433- curio_detected = False
434-
435- for i , instr in enumerate (bytecode ):
436- # Check for module.attribute pattern (LOAD_GLOBAL/LOAD_NAME + LOAD_ATTR)
437- if instr .opname in ("LOAD_GLOBAL" , "LOAD_NAME" ) and i + 1 < len (bytecode ):
438- next_instr = bytecode [i + 1 ]
439- if next_instr .opname == "LOAD_ATTR" :
440- module_name = instr .argval
441- attr_name = next_instr .argval
442-
443- # Check for trio-specific qualified access
444- if module_name == "trio" :
445- trio_indicators = {
446- "sleep" ,
447- "open_nursery" ,
448- "CancelScope" ,
449- "current_trio_token" ,
450- }
451- if attr_name in trio_indicators :
452- trio_detected = True
453-
454- # Check for curio-specific qualified access
455- elif module_name == "curio" :
456- curio_indicators = {"sleep" , "spawn" , "TaskGroup" , "AWAIT" }
457- if attr_name in curio_indicators :
458- curio_detected = True
459-
460- # Trio takes precedence if both are detected
461- if trio_detected :
406+ names = set (coro .cr_code .co_names ) # type: ignore[attr-defined]
407+ if trio is not None and "trio" in names :
462408 return "trio"
463- if curio_detected :
409+ if curio is not None and "curio" in names :
464410 return "curio"
465411
466- # Default to asyncio as it's the most common
467412 return "asyncio"
0 commit comments