diff --git a/src/mcp_server_python_docs/cache/__init__.py b/src/mcp_server_python_docs/cache/__init__.py new file mode 100644 index 0000000..6e3801b --- /dev/null +++ b/src/mcp_server_python_docs/cache/__init__.py @@ -0,0 +1 @@ +"""Cache support utilities.""" diff --git a/src/mcp_server_python_docs/cache/codec.py b/src/mcp_server_python_docs/cache/codec.py new file mode 100644 index 0000000..552b3a1 --- /dev/null +++ b/src/mcp_server_python_docs/cache/codec.py @@ -0,0 +1,114 @@ +"""Versioned codecs for cache-at-rest payloads.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass + +import zstandard as zstd + +_SUPPORTED_CODECS = ["none", "zstd", "zstd-dict-v1"] + + +@dataclass(frozen=True) +class _Codec: + encode: Callable[[str, object | None], bytes] + decode: Callable[[bytes, object | None], str] + + +def list_supported() -> list[str]: + """Return codec ids in stable preference order.""" + + return list(_SUPPORTED_CODECS) + + +def encode(text: str, codec: str, *, dictionary: object | None = None) -> bytes: + """Encode text using a supported cache codec.""" + + try: + handler = _REGISTRY[codec] + except KeyError as e: + raise ValueError(f"Unsupported cache codec: {codec}") from e + return handler.encode(text, dictionary) + + +def decode(blob: bytes, codec: str, *, dictionary: object | None = None) -> str: + """Decode text using the codec stored with the cache row.""" + + try: + handler = _REGISTRY[codec] + except KeyError as e: + raise ValueError(f"Unsupported cache codec: {codec}") from e + return handler.decode(blob, dictionary) + + +def _encode_none(text: str, dictionary: object | None) -> bytes: + _reject_dictionary("none", dictionary) + return text.encode("utf-8") + + +def _decode_none(blob: bytes, dictionary: object | None) -> str: + _reject_dictionary("none", dictionary) + return blob.decode("utf-8") + + +def _encode_zstd(text: str, dictionary: object | None) -> bytes: + _reject_dictionary("zstd", dictionary) + try: + return zstd.ZstdCompressor().compress(text.encode("utf-8")) + except zstd.ZstdError as e: + raise ValueError(f"zstd encode failed: {e}") from e + + +def _decode_zstd(blob: bytes, dictionary: object | None) -> str: + _reject_dictionary("zstd", dictionary) + try: + return zstd.ZstdDecompressor().decompress(blob).decode("utf-8") + except zstd.ZstdError as e: + raise ValueError(f"zstd decode failed: {e}") from e + + +def _encode_zstd_dict(text: str, dictionary: object | None) -> bytes: + try: + return zstd.ZstdCompressor(dict_data=_coerce_dictionary(dictionary)).compress( + text.encode("utf-8") + ) + except zstd.ZstdError as e: + raise ValueError(f"zstd dictionary encode failed: {e}") from e + + +def _decode_zstd_dict(blob: bytes, dictionary: object | None) -> str: + try: + return ( + zstd.ZstdDecompressor(dict_data=_coerce_dictionary(dictionary)) + .decompress(blob) + .decode("utf-8") + ) + except zstd.ZstdError as e: + raise ValueError(f"zstd dictionary decode failed: {e}") from e + + +def _reject_dictionary(codec: str, dictionary: object | None) -> None: + if dictionary is not None: + raise ValueError(f"Codec {codec!r} does not use a dictionary") + + +def _coerce_dictionary(dictionary: object | None) -> zstd.ZstdCompressionDict: + if dictionary is None: + raise ValueError("Codec 'zstd-dict-v1' requires an explicit dictionary") + if isinstance(dictionary, zstd.ZstdCompressionDict): + return dictionary + if isinstance(dictionary, bytes): + return zstd.ZstdCompressionDict(dictionary) + if isinstance(dictionary, bytearray | memoryview): + return zstd.ZstdCompressionDict(bytes(dictionary)) + raise TypeError(f"Unsupported zstd dictionary object: {type(dictionary).__name__}") + + +_REGISTRY: dict[str, _Codec] = { + "none": _Codec(_encode_none, _decode_none), + "zstd": _Codec(_encode_zstd, _decode_zstd), + "zstd-dict-v1": _Codec(_encode_zstd_dict, _decode_zstd_dict), +} + +__all__ = ["decode", "encode", "list_supported"] diff --git a/src/mcp_server_python_docs/services/persistent_cache.py b/src/mcp_server_python_docs/services/persistent_cache.py index f1162d9..5931818 100644 --- a/src/mcp_server_python_docs/services/persistent_cache.py +++ b/src/mcp_server_python_docs/services/persistent_cache.py @@ -10,10 +10,13 @@ from pydantic import ValidationError +from mcp_server_python_docs.cache.codec import decode as decode_cache_payload +from mcp_server_python_docs.cache.codec import encode as encode_cache_payload from mcp_server_python_docs.models import GetDocsResult logger = logging.getLogger(__name__) _NO_ANCHOR_KEY = "\x00mcp-python-docs:no-anchor\x00" +DEFAULT_RETRIEVED_DOCS_CACHE_CODEC = "zstd" class CacheStats(NamedTuple): @@ -25,8 +28,15 @@ class CacheStats(NamedTuple): class PersistentDocsCache: """Persist get_docs results by index fingerprint, version, and request identity.""" - def __init__(self, cache_path: Path, index_path: Path) -> None: + def __init__( + self, + cache_path: Path, + index_path: Path, + *, + default_codec: str = DEFAULT_RETRIEVED_DOCS_CACHE_CODEC, + ) -> None: self._cache_path = Path(cache_path) + self._default_codec = default_codec # Set after fingerprint stat succeeds; stays "" if init fails so the # cache disables cleanly without leaking partial state. self._fingerprint = "" @@ -47,9 +57,11 @@ def __init__(self, cache_path: Path, index_path: Path) -> None: "CREATE TABLE IF NOT EXISTS retrieved_docs_cache (" "index_fingerprint TEXT NOT NULL, version TEXT NOT NULL, slug TEXT NOT NULL, " "anchor TEXT NOT NULL, max_chars INTEGER NOT NULL, start_index INTEGER NOT NULL, " - "result_json TEXT NOT NULL, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, " + "result_json TEXT NOT NULL, compression TEXT NOT NULL DEFAULT 'none', " + "created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, " "PRIMARY KEY (index_fingerprint, version, slug, anchor, max_chars, start_index))" ) + self._ensure_compression_column() self._conn.execute( "DELETE FROM retrieved_docs_cache WHERE index_fingerprint != ?", (self._fingerprint,), @@ -74,6 +86,18 @@ def _fingerprint_index(index_path: Path) -> str: def _anchor_key(anchor: str | None) -> str: return _NO_ANCHOR_KEY if anchor is None else anchor + def _ensure_compression_column(self) -> None: + if self._conn is None: + return + columns = { + row[1] for row in self._conn.execute("PRAGMA table_info(retrieved_docs_cache)") + } + if "compression" not in columns: + self._conn.execute( + "ALTER TABLE retrieved_docs_cache " + "ADD COLUMN compression TEXT NOT NULL DEFAULT 'none'" + ) + def stats(self) -> CacheStats: return CacheStats(self._hits, self._misses, self._writes) @@ -87,7 +111,8 @@ def get( with self._lock: try: row = self._conn.execute( - "SELECT result_json FROM retrieved_docs_cache WHERE index_fingerprint = ? " + "SELECT result_json, compression FROM retrieved_docs_cache " + "WHERE index_fingerprint = ? " "AND version = ? AND slug = ? AND anchor = ? AND max_chars = ? " "AND start_index = ?", ( @@ -107,8 +132,10 @@ def get( self._misses += 1 return None try: - result = GetDocsResult.model_validate_json(row[0]) - except (ValidationError, ValueError) as e: + payload = row[0].encode("utf-8") if isinstance(row[0], str) else bytes(row[0]) + result_json = decode_cache_payload(payload, row[1]) + result = GetDocsResult.model_validate_json(result_json) + except (ValidationError, ValueError, TypeError) as e: self._misses += 1 logger.warning("Persistent docs cache entry ignored: %s", e) return None @@ -123,7 +150,7 @@ def put(self, *, result: GetDocsResult, max_chars: int, start_index: int) -> Non self._conn.execute( "INSERT OR REPLACE INTO retrieved_docs_cache " "(index_fingerprint, version, slug, anchor, max_chars, start_index, " - "result_json) VALUES (?, ?, ?, ?, ?, ?, ?)", + "result_json, compression) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", ( self._fingerprint, result.version, @@ -131,7 +158,8 @@ def put(self, *, result: GetDocsResult, max_chars: int, start_index: int) -> Non self._anchor_key(result.anchor), max_chars, start_index, - result.model_dump_json(), + encode_cache_payload(result.model_dump_json(), self._default_codec), + self._default_codec, ), ) self._conn.commit() diff --git a/tests/cache/__init__.py b/tests/cache/__init__.py new file mode 100644 index 0000000..88f9452 --- /dev/null +++ b/tests/cache/__init__.py @@ -0,0 +1 @@ +"""Cache tests.""" diff --git a/tests/cache/test_codec.py b/tests/cache/test_codec.py new file mode 100644 index 0000000..c5c2c75 --- /dev/null +++ b/tests/cache/test_codec.py @@ -0,0 +1,49 @@ +"""Cache codec coverage.""" + +from __future__ import annotations + +import zstandard as zstd + +from mcp_server_python_docs.cache.codec import decode, encode, list_supported + + +def _test_dictionary() -> zstd.ZstdCompressionDict: + samples = [ + ( + f"Python documentation section {i}: json dumps loads encoder decoder " + "arguments return values exceptions examples. " + ).encode("utf-8") + * 8 + for i in range(64) + ] + return zstd.train_dictionary(512, samples) + + +def test_list_supported_is_stable() -> None: + assert list_supported() == ["none", "zstd", "zstd-dict-v1"] + + +def test_none_round_trips_text() -> None: + text = '{"content":"plain json payload","version":"3.13"}' + encoded = encode(text, "none") + assert encoded == text.encode("utf-8") + assert decode(encoded, "none") == text + + +def test_zstd_round_trips_text() -> None: + text = '{"content":"compressed json payload","version":"3.13"}' + encoded = encode(text, "zstd") + assert encoded != text.encode("utf-8") + assert decode(encoded, "zstd") == text + + +def test_zstd_dict_v1_round_trips_with_explicit_dictionary() -> None: + dictionary = _test_dictionary() + text = "Python documentation section 7: json dumps loads encoder decoder arguments." + encoded = encode(text, "zstd-dict-v1", dictionary=dictionary) + assert decode(encoded, "zstd-dict-v1", dictionary=dictionary) == text + + +def test_none_decodes_payload_from_prior_server_version() -> None: + prior_payload = b'{"content":"legacy uncompressed json","version":"3.12"}' + assert decode(prior_payload, "none") == prior_payload.decode("utf-8") diff --git a/tests/test_mcp_get_docs_cache_smoke.py b/tests/test_mcp_get_docs_cache_smoke.py index 87ad457..1860a5c 100644 --- a/tests/test_mcp_get_docs_cache_smoke.py +++ b/tests/test_mcp_get_docs_cache_smoke.py @@ -140,11 +140,12 @@ def test_get_docs_cache_restart_and_corrupt_cache_fallback(tmp_path: Path): with sqlite3.connect(cache_path) as conn: rows = conn.execute( - "SELECT version, slug, anchor, max_chars, start_index, length(result_json) " + "SELECT version, slug, anchor, max_chars, start_index, " + "length(result_json), compression " "FROM retrieved_docs_cache" ).fetchall() assert len(rows) == 1 - version, slug, anchor, max_chars, start_index, result_json_length = rows[0] + version, slug, anchor, max_chars, start_index, result_json_length, compression = rows[0] assert (version, slug, anchor, max_chars, start_index) == ( "3.13", "library/json.html", @@ -153,6 +154,7 @@ def test_get_docs_cache_restart_and_corrupt_cache_fallback(tmp_path: Path): 0, ) assert result_json_length > 0 + assert compression == "zstd" restarted_page = _tool_structured_content( _run_server( diff --git a/tests/test_persistent_docs_cache.py b/tests/test_persistent_docs_cache.py index f9fd5dd..e835a03 100644 --- a/tests/test_persistent_docs_cache.py +++ b/tests/test_persistent_docs_cache.py @@ -10,7 +10,7 @@ from mcp_server_python_docs.models import GetDocsResult from mcp_server_python_docs.services.content import ContentService -from mcp_server_python_docs.services.persistent_cache import PersistentDocsCache +from mcp_server_python_docs.services.persistent_cache import _NO_ANCHOR_KEY, PersistentDocsCache def _doc(db, version: str, content: str, default: int = 0) -> None: @@ -68,6 +68,76 @@ def test_cache_survives_restart_and_miss_falls_back(populated_db, tmp_path: Path assert restarted.stats().hits == 1 +def test_current_default_codec_reads_identically_after_restart(tmp_path: Path): + index_path, cache = _cache(tmp_path) + expected = _result("compressed docs payload") + cache.put(result=expected, max_chars=500, start_index=0) + + with sqlite3.connect(cache.cache_path) as conn: + compression = conn.execute("SELECT compression FROM retrieved_docs_cache").fetchone()[0] + assert compression == "zstd" + + restarted = PersistentDocsCache(tmp_path / "retrieved.sqlite3", index_path) + assert ( + restarted.get( + version="3.12", + slug="library/json.html", + anchor=None, + max_chars=500, + start_index=0, + ) + == expected + ) + assert restarted.stats().hits == 1 + + +def test_legacy_uncompressed_cache_row_migrates_and_reads(tmp_path: Path): + index_path = tmp_path / "index.db" + index_path.write_bytes(b"index-1") + fingerprint = PersistentDocsCache._fingerprint_index(index_path) + cache_path = tmp_path / "retrieved.sqlite3" + expected = _result("legacy docs payload") + with sqlite3.connect(cache_path) as conn: + conn.execute( + "CREATE TABLE retrieved_docs_cache (" + "index_fingerprint TEXT NOT NULL, version TEXT NOT NULL, slug TEXT NOT NULL, " + "anchor TEXT NOT NULL, max_chars INTEGER NOT NULL, start_index INTEGER NOT NULL, " + "result_json TEXT NOT NULL, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, " + "PRIMARY KEY (index_fingerprint, version, slug, anchor, max_chars, start_index))" + ) + conn.execute( + "INSERT INTO retrieved_docs_cache " + "(index_fingerprint, version, slug, anchor, max_chars, start_index, result_json) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", + ( + fingerprint, + expected.version, + expected.slug, + _NO_ANCHOR_KEY, + 500, + 0, + expected.model_dump_json(), + ), + ) + + migrated = PersistentDocsCache(cache_path, index_path) + assert ( + migrated.get( + version="3.12", + slug="library/json.html", + anchor=None, + max_chars=500, + start_index=0, + ) + == expected + ) + with sqlite3.connect(cache_path) as conn: + columns = {row[1] for row in conn.execute("PRAGMA table_info(retrieved_docs_cache)")} + compression = conn.execute("SELECT compression FROM retrieved_docs_cache").fetchone()[0] + assert "compression" in columns + assert compression == "none" + + def test_cache_key_includes_python_version(populated_db, tmp_path: Path): _doc(populated_db, "3.12", "docs for 3.12") _doc(populated_db, "3.13", "docs for 3.13", 1)