Skip to content

Commit 62e6c5a

Browse files
Replace bytecode inspection with co_names and move sniffio to runtime deps
1 parent 25035b2 commit 62e6c5a

3 files changed

Lines changed: 65 additions & 386 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ dependencies = [
4343
# Pact dependencies
4444
"pact-python-ffi~=0.4.0",
4545
# External dependencies
46+
"sniffio~=1.0",
4647
"yarl~=1.0",
4748
"typing-extensions~=4.0 ; python_version < '3.13'",
4849
]
@@ -120,7 +121,6 @@ test = [
120121
"pytest-rerunfailures~=16.0",
121122
"pytest~=9.0",
122123
"requests~=2.0",
123-
"sniffio>=1.3",
124124
"testcontainers~=4.0",
125125
]
126126
types = [

src/pact/_util.py

Lines changed: 24 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import asyncio
1313
import contextvars
14-
import dis
1514
import inspect
1615
import logging
1716
import socket
@@ -26,15 +25,13 @@
2625

2726
try:
2827
import sniffio # type: ignore[import-not-found]
29-
except ImportError:
28+
except ImportError: # pragma: no cover
3029
sniffio = None # type: ignore[assignment]
3130

3231
try:
3332
import trio # type: ignore[import-not-found]
34-
from trio.lowlevel import current_trio_token # type: ignore[import-not-found]
35-
except ImportError:
33+
except ImportError: # pragma: no cover
3634
trio = None # type: ignore[assignment]
37-
current_trio_token = None # type: ignore[assignment]
3835

3936
try:
4037
import curio # type: ignore[import-not-found,import-untyped]
@@ -223,7 +220,6 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
223220
Returns:
224221
The result of the function.
225222
"""
226-
# Check if f is a partial wrapping an async function
227223
func_to_check = f.func if isinstance(f, partial) else f
228224
is_async = inspect.iscoroutinefunction(func_to_check)
229225
signature = inspect.signature(f)
@@ -252,8 +248,6 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
252248
# First, we inspect the keyword arguments and try and pass in some arguments
253249
# by currying them in.
254250
for param in signature.parameters.values():
255-
# Try matching the parameter name, or if it starts with underscore,
256-
# also try matching without the leading underscore.
257251
arg_key = None
258252
if param.name in args:
259253
arg_key = param.name
@@ -320,25 +314,20 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
320314
},
321315
)
322316

323-
try:
324-
if is_async:
325-
result = f()
326-
if inspect.iscoroutine(result):
327-
return _run_async_coroutine(result)
328-
return result
329-
return f()
330-
except Exception:
331-
logger.exception("Error occurred while calling function %s", f_name)
332-
raise
333-
334-
335-
def _run_async_coroutine(coro: Coroutine[Any, Any, _T]) -> _T: # noqa: C901
317+
if is_async:
318+
result = f()
319+
if inspect.iscoroutine(result):
320+
return _run_async_coroutine(result)
321+
return result # pragma: no cover
322+
return f()
323+
324+
325+
def _run_async_coroutine(coro: Coroutine[Any, Any, _T]) -> _T:
336326
"""
337327
Run a coroutine in an event loop.
338328
339-
Detects the async runtime (asyncio, trio, or curio) and executes the
340-
coroutine appropriately. Preserves ContextVars when creating a new event
341-
loop, which is important when handlers are called from threads.
329+
Detects the current async runtime and runs the coroutine in it,
330+
preserving ContextVars across the dispatch.
342331
343332
Args:
344333
coro:
@@ -351,24 +340,13 @@ def _run_async_coroutine(coro: Coroutine[Any, Any, _T]) -> _T: # noqa: C901
351340
RuntimeError:
352341
If the detected runtime (trio or curio) is not installed.
353342
"""
354-
runtime = _detect_async_runtime_from_coroutine(coro)
343+
runtime = _detect_async_runtime(coro)
355344

356345
if runtime == "trio":
357346
if trio is None:
358347
msg = "trio is not installed"
359348
raise RuntimeError(msg)
360349

361-
if current_trio_token is not None:
362-
try:
363-
token = current_trio_token()
364-
365-
async def _run_with_token() -> _T:
366-
return await coro
367-
368-
return trio.from_thread.run_sync(_run_with_token, trio_token=token) # type: ignore[return-value]
369-
except RuntimeError:
370-
pass
371-
372350
ctx = contextvars.copy_context()
373351

374352
async def _run_trio() -> _T:
@@ -404,13 +382,17 @@ async def _run_curio() -> _T:
404382
return ctx.run(asyncio.run, coro) # type: ignore[arg-type,return-value]
405383

406384

407-
def _detect_async_runtime_from_coroutine(coro: Coroutine[Any, Any, _T]) -> str: # noqa: C901
385+
def _detect_async_runtime(coro: Coroutine[Any, Any, _T]) -> str:
408386
"""
409-
Detect async runtime by inspecting the coroutine object.
387+
Detect the async runtime to use for a given coroutine.
388+
389+
When called from within a running async context, `sniffio` is used to
390+
identify the library. Otherwise the coroutine's `co_names` is inspected
391+
for `trio` or `curio` references.
410392
411393
Args:
412394
coro:
413-
The coroutine object to inspect.
395+
The coroutine to inspect.
414396
415397
Returns:
416398
The detected runtime: "asyncio", "trio", or "curio".
@@ -421,47 +403,10 @@ def _detect_async_runtime_from_coroutine(coro: Coroutine[Any, Any, _T]) -> str:
421403
except sniffio.AsyncLibraryNotFoundError:
422404
pass
423405

424-
# Inspect bytecode to check for qualified attribute access (e.g., trio.sleep)
425-
# This is more robust than just checking co_names for module and method separately
426-
func_code = coro.cr_code # type: ignore[attr-defined]
427-
428-
# Parse bytecode to find LOAD_GLOBAL/LOAD_NAME followed by LOAD_ATTR patterns
429-
# This detects qualified accesses like `trio.sleep()` or `curio.spawn()`
430-
bytecode = list(dis.get_instructions(func_code))
431-
432-
trio_detected = False
433-
curio_detected = False
434-
435-
for i, instr in enumerate(bytecode):
436-
# Check for module.attribute pattern (LOAD_GLOBAL/LOAD_NAME + LOAD_ATTR)
437-
if instr.opname in ("LOAD_GLOBAL", "LOAD_NAME") and i + 1 < len(bytecode):
438-
next_instr = bytecode[i + 1]
439-
if next_instr.opname == "LOAD_ATTR":
440-
module_name = instr.argval
441-
attr_name = next_instr.argval
442-
443-
# Check for trio-specific qualified access
444-
if module_name == "trio":
445-
trio_indicators = {
446-
"sleep",
447-
"open_nursery",
448-
"CancelScope",
449-
"current_trio_token",
450-
}
451-
if attr_name in trio_indicators:
452-
trio_detected = True
453-
454-
# Check for curio-specific qualified access
455-
elif module_name == "curio":
456-
curio_indicators = {"sleep", "spawn", "TaskGroup", "AWAIT"}
457-
if attr_name in curio_indicators:
458-
curio_detected = True
459-
460-
# Trio takes precedence if both are detected
461-
if trio_detected:
406+
names = set(coro.cr_code.co_names) # type: ignore[attr-defined]
407+
if trio is not None and "trio" in names:
462408
return "trio"
463-
if curio_detected:
409+
if curio is not None and "curio" in names:
464410
return "curio"
465411

466-
# Default to asyncio as it's the most common
467412
return "asyncio"

0 commit comments

Comments
 (0)