Fix handling of type var syntax and types.GenericAlias#962
Fix handling of type var syntax and types.GenericAlias#962provinzkraut wants to merge 3 commits intojcrist:mainfrom
types.GenericAlias#962Conversation
63b9a18 to
57bf7e6
Compare
|
Okay, maybe this isn't a good solution. I've discovered a slight issue with the It does not cache itself when called directly (i.g. doing Since msgspec relies on that cache being preserved, the solution as currently proposed only maintains the cache within a decoder instance, meaning that calling Note that this only affects generic types bound by a |
43d8d35 to
4285b96
Compare
|
Alright, I think I managed to find a solution. Essentially, we're now reproducing a It's basically doing typing._GenericAlias(
alias.__origin__,
alias.__origin__.__parameters__
).__getitem__(*alias.__args__)which is functionally the same as alias = typing._SpecialGenericAlias(list, 1)[int]
# this is the same as
alias = typing.List[int]I've also added some more test cases to ensure the caching works properly in all newly supported cases. |
|
Another friendly ping to @ofek for a review :) |
65f7010 to
866b311
Compare
866b311 to
33ba0ed
Compare
| return _get_type_hints(obj, include_extras=True) | ||
|
|
||
|
|
||
| PY_31PLUS = sys.version_info >= (3, 12) |
There was a problem hiding this comment.
Typo in variable name: PY_31PLUS reads as "Python 3.1+", but the comparison is >= (3, 12). Should be PY_312PLUS (consistent with PY312_PLUS used in _core.c).
Performance & correctness reviewTested on a dedicated VPS, pinned to a single core via VPS specs
Benchmark results
Conclusion: no measurable performance regression. Correctness verification6/6 tests pass on the PR branch:
Memory leak test10,000 iterations of
No memory leak detected. Benchmark scriptimport timeit
import statistics
import json
import typing
import msgspec
from msgspec import Struct
T = typing.TypeVar("T")
class Point(Struct):
x: int
y: int
z: float
class Box(Struct, typing.Generic[T]):
value: T
enc = msgspec.json.Encoder()
dec_point = msgspec.json.Decoder(Point)
dec_box_int = msgspec.json.Decoder(Box[int])
point_msg = enc.encode(Point(1, 2, 3.14))
box_msg = enc.encode(Box(value=42))
def bench_point_encode():
p = Point(1, 2, 3.14)
for _ in range(1000):
enc.encode(p)
def bench_point_decode():
for _ in range(1000):
dec_point.decode(point_msg)
def bench_generic_struct_decode():
for _ in range(1000):
dec_box_int.decode(box_msg)
def bench_decoder_creation_simple():
for _ in range(200):
msgspec.json.Decoder(Point)
def bench_decoder_creation_generic():
for _ in range(200):
msgspec.json.Decoder(Box[int])
benchmarks = [
("point_encode", bench_point_encode),
("point_decode", bench_point_decode),
("generic_struct_decode", bench_generic_struct_decode),
("decoder_creation_simple", bench_decoder_creation_simple),
("decoder_creation_generic", bench_decoder_creation_generic),
]
# Warmup
for name, func in benchmarks:
timeit.repeat(func, number=1, repeat=10)
# 3 full runs of 100 rounds each
all_results = {}
for run in range(3):
for name, func in benchmarks:
times = timeit.repeat(func, number=1, repeat=100)
times_sorted = sorted(times)
trimmed = times_sorted[10:-10] # drop top/bottom 10%
if name not in all_results:
all_results[name] = []
all_results[name].extend(trimmed)
results = {}
for name, times in all_results.items():
results[name] = {
"mean": statistics.mean(times),
"stdev": statistics.stdev(times),
"min": min(times),
"p50": statistics.median(times),
}
print(json.dumps(results, indent=2))Verification scriptimport sys
import typing
import collections.abc
import dataclasses
import gc
import msgspec
from msgspec import Struct
results = []
def test(name, func):
try:
func()
results.append((name, "PASS", ""))
except Exception as e:
results.append((name, "FAIL", str(e)))
def test_dataclass_mapping():
@dataclasses.dataclass
class Bar(typing.Generic[typing.T], collections.abc.Mapping[str, typing.T]):
data: dict[str, typing.T]
def __getitem__(self, x): return self.data[x]
def __len__(self): return len(self.data)
def __iter__(self): return iter(self.data)
x = Bar(data={"x": 3})
encoded = msgspec.msgpack.encode(x)
decoded = msgspec.msgpack.decode(encoded, type=Bar[int])
assert decoded == x
test("dataclass_mapping_generic", test_dataclass_mapping)
def test_struct_mapping():
import abc
class CombinedMeta(msgspec.structs.StructMeta, abc.ABCMeta):
pass
T = typing.TypeVar("T")
class Foo(collections.abc.Mapping[str, T], Struct, typing.Generic[T], metaclass=CombinedMeta):
data: dict[str, T]
def __getitem__(self, x): return self.data[x]
def __len__(self): return len(self.data)
def __iter__(self): return iter(self.data)
encoded = msgspec.msgpack.encode(Foo(data={"x": 1}))
decoded = msgspec.msgpack.decode(encoded, type=Foo[int])
assert decoded.data == {"x": 1}
try:
msgspec.msgpack.decode(
msgspec.msgpack.encode(Foo(data={"x": "foo"})), type=Foo[int]
)
assert False, "Should have raised ValidationError"
except msgspec.ValidationError:
pass
test("struct_mapping_generic", test_struct_mapping)
def test_typing_mapping():
T = typing.TypeVar("T")
@dataclasses.dataclass
class Bar(typing.Generic[T], typing.Mapping[str, T]):
data: dict[str, T]
def __getitem__(self, x): return self.data[x]
def __len__(self): return len(self.data)
def __iter__(self): return iter(self.data)
x = Bar(data={"x": 3})
encoded = msgspec.msgpack.encode(x)
decoded = msgspec.msgpack.decode(encoded, type=Bar[int])
assert decoded == x
test("typing_mapping_generic", test_typing_mapping)
if sys.version_info >= (3, 12):
def test_typevar_syntax():
code = '''
from msgspec import Struct
class Ex[T](Struct):
x: T
y: list[T]
'''
ns = {}
exec(code, ns)
Ex = ns["Ex"]
msg = msgspec.json.encode(Ex(1, [1, 2]))
res = msgspec.json.decode(msg, type=Ex[int])
assert res.x == 1 and res.y == [1, 2]
test("pep695_typevar_syntax", test_typevar_syntax)
def test_refcount_cache():
T = typing.TypeVar("T")
@dataclasses.dataclass
class Foo(typing.Generic[T], collections.abc.Mapping[str, T]):
data: dict[str, T]
def __getitem__(self, x): return self.data[x]
def __len__(self): return len(self.data)
def __iter__(self): return iter(self.data)
typ = Foo[int]
dec1 = msgspec.json.Decoder(typ)
dec2 = msgspec.json.Decoder(typ)
msg = msgspec.json.encode(Foo(data={"a": 1}))
r1 = dec1.decode(msg)
r2 = dec2.decode(msg)
assert r1 == r2
del dec1, dec2
gc.collect()
test("refcount_cache", test_refcount_cache)
def test_cache_persistence():
T = typing.TypeVar("T")
@dataclasses.dataclass
class Foo(typing.Generic[T], collections.abc.Mapping[str, T]):
data: dict[str, T]
def __getitem__(self, x): return self.data[x]
def __len__(self): return len(self.data)
def __iter__(self): return iter(self.data)
typ = Foo[int]
msg = msgspec.msgpack.encode(Foo(data={"a": 1}))
for i in range(10):
dec = msgspec.msgpack.Decoder(typ)
result = dec.decode(msg)
assert result.data == {"a": 1}
test("cache_persistence", test_cache_persistence)
for name, status, detail in results:
marker = "OK" if status == "PASS" else "FAIL"
print(f"[{marker}] {name}" + (f" - {detail}" if detail else ""))
print(f"\n{len(results)} tests, {sum(1 for _, s, _ in results if s == 'FAIL')} failed")Memory leak test scriptimport gc
import sys
import typing
import collections.abc
import dataclasses
import tracemalloc
import msgspec
from msgspec import Struct
T = typing.TypeVar("T")
@dataclasses.dataclass
class Foo(typing.Generic[T], collections.abc.Mapping[str, T]):
data: dict[str, T]
def __getitem__(self, x): return self.data[x]
def __len__(self): return len(self.data)
def __iter__(self): return iter(self.data)
tracemalloc.start()
gc.collect()
with open("/proc/self/status") as f:
rss_before = int([l for l in f if l.startswith("VmRSS:")][0].split()[1])
snap1 = tracemalloc.take_snapshot()
for i in range(10000):
dec = msgspec.json.Decoder(Foo[int])
msg = msgspec.json.encode(Foo(data={"a": 1}))
dec.decode(msg)
del dec
gc.collect()
with open("/proc/self/status") as f:
rss_after = int([l for l in f if l.startswith("VmRSS:")][0].split()[1])
snap2 = tracemalloc.take_snapshot()
print(f"RSS diff: {rss_after - rss_before} KB")
for stat in snap2.compare_to(snap1, "lineno")[:5]:
print(f" {stat}") |
Fix #957.
This is a few fixes combined into one, since they were tightly coupled.
Handling of "new style" generics during type resolution
When subscribing a "new style" generic (such as
collections.abc.Mapping), it produces atypes.GenericAlias(vs. the "old style"typing._GenericAlias), which msgspec did not handle correctly during inspectionHandling of type var syntax
When dealing with builtin generics that resolve to
typing.TypeAlias, msgspec did not account for type var syntax at correctly in all cases, so type information would get lost during the conversion processType conversions on
types.GenericAliasDuring type conversion, msgspec caches certain information on the type objects themselves, if the types are complex (i.e.
Structs ordataclass-like).When decoding into a
Foo[int], msgspec will set an__msgspec_cache__attribute on theFoo[int]alias type.For
typing._GenericAlias, this work, since it has a__dict__, so you can just assign attributes to it. However,types.GenericAliasdoes not allow assigning arbitrary attributes to it.I fix this by downtyping
types.GenericAliasinto atyping._GenericAlias, when encountering a genericStructordataclasstype. This allows to keep the existing caching mechanism in place.This seemed like the most reasonable fix to me, as the other alternatives would like incur some sort of performance penalty; By storing the typing info directly on the alias, msgspec can forego maintaining a dedicated cache, making lookups very fast. It also allows to not care about invalidating a cache, since it will just be gce'd when the alias isn't referenced anymore.
One thing to not here though is that in the future,
typing._GenericAliasmight just go away (at least from the stdlib), in which case we'll have to find another way to deal with this.