Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/mcp_server_python_docs/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Cache support utilities."""
114 changes: 114 additions & 0 deletions src/mcp_server_python_docs/cache/codec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Versioned codecs for cache-at-rest payloads."""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

import zstandard as zstd

# Codec ids in preference order; list_supported() hands out a copy of this list.
_SUPPORTED_CODECS = ["none", "zstd", "zstd-dict-v1"]


@dataclass(frozen=True)
class _Codec:
encode: Callable[[str, object | None], bytes]
decode: Callable[[bytes, object | None], str]


def list_supported() -> list[str]:
    """Return a fresh list of codec ids in stable preference order."""

    # Unpack into a new list so callers cannot mutate the module-level list.
    return [*_SUPPORTED_CODECS]


def encode(text: str, codec: str, *, dictionary: object | None = None) -> bytes:
    """Encode ``text`` with the registered handler for ``codec``.

    Args:
        text: The text payload to encode.
        codec: Codec id; must be a key of the module registry.
        dictionary: Optional dictionary object forwarded to the handler.

    Returns:
        The encoded payload produced by the codec's encode handler.

    Raises:
        ValueError: If ``codec`` is not a registered codec id.
    """

    try:
        codec_impl = _REGISTRY[codec]
    except KeyError as exc:
        raise ValueError(f"Unsupported cache codec: {codec}") from exc
    else:
        return codec_impl.encode(text, dictionary)


def decode(blob: bytes, codec: str, *, dictionary: object | None = None) -> str:
    """Decode ``blob`` with the registered handler for ``codec``.

    Args:
        blob: The stored payload bytes.
        codec: Codec id recorded alongside the payload; must be registered.
        dictionary: Optional dictionary object forwarded to the handler.

    Returns:
        The decoded text produced by the codec's decode handler.

    Raises:
        ValueError: If ``codec`` is not a registered codec id.
    """

    try:
        codec_impl = _REGISTRY[codec]
    except KeyError as exc:
        raise ValueError(f"Unsupported cache codec: {codec}") from exc
    else:
        return codec_impl.decode(blob, dictionary)


def _encode_none(text: str, dictionary: object | None) -> bytes:
    """Encode ``text`` as plain UTF-8 bytes; a dictionary is not accepted."""
    _reject_dictionary("none", dictionary)
    utf8_bytes = text.encode("utf-8")
    return utf8_bytes


def _decode_none(blob: bytes, dictionary: object | None) -> str:
    """Decode plain UTF-8 bytes back to text; a dictionary is not accepted."""
    _reject_dictionary("none", dictionary)
    decoded_text = blob.decode("utf-8")
    return decoded_text


def _encode_zstd(text: str, dictionary: object | None) -> bytes:
    """Compress the UTF-8 bytes of ``text`` with a default ``ZstdCompressor``.

    Raises:
        ValueError: If a dictionary is supplied, or if zstd reports an error
            (the original ``ZstdError`` is chained as the cause).
    """
    _reject_dictionary("zstd", dictionary)
    try:
        compressor = zstd.ZstdCompressor()
        raw = text.encode("utf-8")
        return compressor.compress(raw)
    except zstd.ZstdError as exc:
        raise ValueError(f"zstd encode failed: {exc}") from exc


def _decode_zstd(blob: bytes, dictionary: object | None) -> str:
    """Decompress ``blob`` with a default ``ZstdDecompressor`` and decode it as UTF-8.

    Raises:
        ValueError: If a dictionary is supplied, or if zstd reports an error
            (the original ``ZstdError`` is chained as the cause).
    """
    _reject_dictionary("zstd", dictionary)
    try:
        decompressor = zstd.ZstdDecompressor()
        raw = decompressor.decompress(blob)
        return raw.decode("utf-8")
    except zstd.ZstdError as exc:
        raise ValueError(f"zstd decode failed: {exc}") from exc


def _encode_zstd_dict(text: str, dictionary: object | None) -> bytes:
    """Compress the UTF-8 bytes of ``text`` using an explicit zstd dictionary.

    The dictionary is normalised by ``_coerce_dictionary`` inside the ``try``
    so that a ``ZstdError`` raised while building it is also reported as a
    ``ValueError``.

    Raises:
        ValueError: If zstd reports an error (chained as the cause).
    """
    try:
        dict_data = _coerce_dictionary(dictionary)
        compressor = zstd.ZstdCompressor(dict_data=dict_data)
        raw = text.encode("utf-8")
        return compressor.compress(raw)
    except zstd.ZstdError as exc:
        raise ValueError(f"zstd dictionary encode failed: {exc}") from exc


def _decode_zstd_dict(blob: bytes, dictionary: object | None) -> str:
    """Decompress ``blob`` using an explicit zstd dictionary and decode it as UTF-8.

    The dictionary is normalised by ``_coerce_dictionary`` inside the ``try``
    so that a ``ZstdError`` raised while building it is also reported as a
    ``ValueError``.

    Raises:
        ValueError: If zstd reports an error (chained as the cause).
    """
    try:
        dict_data = _coerce_dictionary(dictionary)
        decompressor = zstd.ZstdDecompressor(dict_data=dict_data)
        raw = decompressor.decompress(blob)
        return raw.decode("utf-8")
    except zstd.ZstdError as exc:
        raise ValueError(f"zstd dictionary decode failed: {exc}") from exc


def _reject_dictionary(codec: str, dictionary: object | None) -> None:
if dictionary is not None:
raise ValueError(f"Codec {codec!r} does not use a dictionary")


def _coerce_dictionary(dictionary: object | None) -> zstd.ZstdCompressionDict:
    """Normalise a caller-supplied dictionary into a ``ZstdCompressionDict``.

    Accepts an existing ``ZstdCompressionDict`` (returned unchanged) or raw
    dictionary content as ``bytes``, ``bytearray``, or ``memoryview``.

    Raises:
        ValueError: If ``dictionary`` is ``None``.
        TypeError: If ``dictionary`` is of any other type.
    """
    if dictionary is None:
        raise ValueError("Codec 'zstd-dict-v1' requires an explicit dictionary")
    if isinstance(dictionary, zstd.ZstdCompressionDict):
        return dictionary

    raw = dictionary
    if isinstance(raw, (bytearray, memoryview)):
        # Copy mutable/view buffers into immutable bytes before wrapping.
        raw = bytes(raw)
    if isinstance(raw, bytes):
        return zstd.ZstdCompressionDict(raw)
    raise TypeError(f"Unsupported zstd dictionary object: {type(dictionary).__name__}")


# Codec id -> handler pair; encode() and decode() resolve codecs through this table.
_REGISTRY: dict[str, _Codec] = {
    "none": _Codec(encode=_encode_none, decode=_decode_none),
    "zstd": _Codec(encode=_encode_zstd, decode=_decode_zstd),
    "zstd-dict-v1": _Codec(encode=_encode_zstd_dict, decode=_decode_zstd_dict),
}

# Names exported by ``from ... import *``; everything else in the module is private.
__all__ = ["decode", "encode", "list_supported"]
42 changes: 35 additions & 7 deletions src/mcp_server_python_docs/services/persistent_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@

from pydantic import ValidationError

from mcp_server_python_docs.cache.codec import decode as decode_cache_payload
from mcp_server_python_docs.cache.codec import encode as encode_cache_payload
from mcp_server_python_docs.models import GetDocsResult

logger = logging.getLogger(__name__)
_NO_ANCHOR_KEY = "\x00mcp-python-docs:no-anchor\x00"
DEFAULT_RETRIEVED_DOCS_CACHE_CODEC = "zstd"


class CacheStats(NamedTuple):
Expand All @@ -25,8 +28,15 @@ class CacheStats(NamedTuple):
class PersistentDocsCache:
"""Persist get_docs results by index fingerprint, version, and request identity."""

def __init__(self, cache_path: Path, index_path: Path) -> None:
def __init__(
self,
cache_path: Path,
index_path: Path,
*,
default_codec: str = DEFAULT_RETRIEVED_DOCS_CACHE_CODEC,
) -> None:
self._cache_path = Path(cache_path)
self._default_codec = default_codec
# Set after fingerprint stat succeeds; stays "" if init fails so the
# cache disables cleanly without leaking partial state.
self._fingerprint = ""
Expand All @@ -47,9 +57,11 @@ def __init__(self, cache_path: Path, index_path: Path) -> None:
"CREATE TABLE IF NOT EXISTS retrieved_docs_cache ("
"index_fingerprint TEXT NOT NULL, version TEXT NOT NULL, slug TEXT NOT NULL, "
"anchor TEXT NOT NULL, max_chars INTEGER NOT NULL, start_index INTEGER NOT NULL, "
"result_json TEXT NOT NULL, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, "
"result_json TEXT NOT NULL, compression TEXT NOT NULL DEFAULT 'none', "
"created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, "
"PRIMARY KEY (index_fingerprint, version, slug, anchor, max_chars, start_index))"
)
self._ensure_compression_column()
self._conn.execute(
"DELETE FROM retrieved_docs_cache WHERE index_fingerprint != ?",
(self._fingerprint,),
Expand All @@ -74,6 +86,18 @@ def _fingerprint_index(index_path: Path) -> str:
def _anchor_key(anchor: str | None) -> str:
return _NO_ANCHOR_KEY if anchor is None else anchor

def _ensure_compression_column(self) -> None:
if self._conn is None:
return
columns = {
row[1] for row in self._conn.execute("PRAGMA table_info(retrieved_docs_cache)")
}
if "compression" not in columns:
self._conn.execute(
"ALTER TABLE retrieved_docs_cache "
"ADD COLUMN compression TEXT NOT NULL DEFAULT 'none'"
)

def stats(self) -> CacheStats:
    """Return the current hit, miss, and write counters as a ``CacheStats``."""
    hits, misses, writes = self._hits, self._misses, self._writes
    return CacheStats(hits, misses, writes)

Expand All @@ -87,7 +111,8 @@ def get(
with self._lock:
try:
row = self._conn.execute(
"SELECT result_json FROM retrieved_docs_cache WHERE index_fingerprint = ? "
"SELECT result_json, compression FROM retrieved_docs_cache "
"WHERE index_fingerprint = ? "
"AND version = ? AND slug = ? AND anchor = ? AND max_chars = ? "
"AND start_index = ?",
(
Expand All @@ -107,8 +132,10 @@ def get(
self._misses += 1
return None
try:
result = GetDocsResult.model_validate_json(row[0])
except (ValidationError, ValueError) as e:
payload = row[0].encode("utf-8") if isinstance(row[0], str) else bytes(row[0])
result_json = decode_cache_payload(payload, row[1])
result = GetDocsResult.model_validate_json(result_json)
except (ValidationError, ValueError, TypeError) as e:
self._misses += 1
logger.warning("Persistent docs cache entry ignored: %s", e)
return None
Expand All @@ -123,15 +150,16 @@ def put(self, *, result: GetDocsResult, max_chars: int, start_index: int) -> Non
self._conn.execute(
"INSERT OR REPLACE INTO retrieved_docs_cache "
"(index_fingerprint, version, slug, anchor, max_chars, start_index, "
"result_json) VALUES (?, ?, ?, ?, ?, ?, ?)",
"result_json, compression) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
(
self._fingerprint,
result.version,
result.slug,
self._anchor_key(result.anchor),
max_chars,
start_index,
result.model_dump_json(),
encode_cache_payload(result.model_dump_json(), self._default_codec),
self._default_codec,
),
)
self._conn.commit()
Expand Down
1 change: 1 addition & 0 deletions tests/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Cache tests."""
49 changes: 49 additions & 0 deletions tests/cache/test_codec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Cache codec coverage."""

from __future__ import annotations

import zstandard as zstd

from mcp_server_python_docs.cache.codec import decode, encode, list_supported


def _test_dictionary() -> zstd.ZstdCompressionDict:
    """Train a 512-byte zstd dictionary from 64 synthetic documentation samples."""
    samples = []
    for i in range(64):
        sentence = (
            f"Python documentation section {i}: json dumps loads encoder decoder "
            "arguments return values exceptions examples. "
        )
        # Each sample is its sentence repeated eight times.
        samples.append(sentence.encode("utf-8") * 8)
    return zstd.train_dictionary(512, samples)


def test_list_supported_is_stable() -> None:
    """The advertised codec ids and their order must not change."""
    expected_order = ["none", "zstd", "zstd-dict-v1"]
    assert list_supported() == expected_order


def test_none_round_trips_text() -> None:
    """The ``none`` codec stores plain UTF-8 and decodes back to the same text."""
    original = '{"content":"plain json payload","version":"3.13"}'
    stored = encode(original, "none")
    assert stored == original.encode("utf-8")
    restored = decode(stored, "none")
    assert restored == original


def test_zstd_round_trips_text() -> None:
    """The ``zstd`` codec changes the stored bytes and decodes back to the same text."""
    original = '{"content":"compressed json payload","version":"3.13"}'
    stored = encode(original, "zstd")
    plain = original.encode("utf-8")
    assert stored != plain
    restored = decode(stored, "zstd")
    assert restored == original


def test_zstd_dict_v1_round_trips_with_explicit_dictionary() -> None:
    """``zstd-dict-v1`` round-trips when the same dictionary is given to both sides."""
    trained = _test_dictionary()
    original = "Python documentation section 7: json dumps loads encoder decoder arguments."
    stored = encode(original, "zstd-dict-v1", dictionary=trained)
    restored = decode(stored, "zstd-dict-v1", dictionary=trained)
    assert restored == original


def test_none_decodes_payload_from_prior_server_version() -> None:
    """Uncompressed JSON bytes decode unchanged under the ``none`` codec."""
    legacy_bytes = b'{"content":"legacy uncompressed json","version":"3.12"}'
    expected_text = legacy_bytes.decode("utf-8")
    assert decode(legacy_bytes, "none") == expected_text
6 changes: 4 additions & 2 deletions tests/test_mcp_get_docs_cache_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,12 @@ def test_get_docs_cache_restart_and_corrupt_cache_fallback(tmp_path: Path):

with sqlite3.connect(cache_path) as conn:
rows = conn.execute(
"SELECT version, slug, anchor, max_chars, start_index, length(result_json) "
"SELECT version, slug, anchor, max_chars, start_index, "
"length(result_json), compression "
"FROM retrieved_docs_cache"
).fetchall()
assert len(rows) == 1
version, slug, anchor, max_chars, start_index, result_json_length = rows[0]
version, slug, anchor, max_chars, start_index, result_json_length, compression = rows[0]
assert (version, slug, anchor, max_chars, start_index) == (
"3.13",
"library/json.html",
Expand All @@ -153,6 +154,7 @@ def test_get_docs_cache_restart_and_corrupt_cache_fallback(tmp_path: Path):
0,
)
assert result_json_length > 0
assert compression == "zstd"

restarted_page = _tool_structured_content(
_run_server(
Expand Down
72 changes: 71 additions & 1 deletion tests/test_persistent_docs_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from mcp_server_python_docs.models import GetDocsResult
from mcp_server_python_docs.services.content import ContentService
from mcp_server_python_docs.services.persistent_cache import PersistentDocsCache
from mcp_server_python_docs.services.persistent_cache import _NO_ANCHOR_KEY, PersistentDocsCache


def _doc(db, version: str, content: str, default: int = 0) -> None:
Expand Down Expand Up @@ -68,6 +68,76 @@ def test_cache_survives_restart_and_miss_falls_back(populated_db, tmp_path: Path
assert restarted.stats().hits == 1


def test_current_default_codec_reads_identically_after_restart(tmp_path: Path):
    """A row written with the default codec is tagged ``zstd`` and survives a restart."""
    index_path, cache = _cache(tmp_path)
    expected = _result("compressed docs payload")
    cache.put(result=expected, max_chars=500, start_index=0)

    # Inspect the row directly to confirm which codec id was recorded.
    with sqlite3.connect(cache.cache_path) as conn:
        stored_row = conn.execute("SELECT compression FROM retrieved_docs_cache").fetchone()
        compression = stored_row[0]
    assert compression == "zstd"

    # A fresh instance over the same cache file must return the identical result.
    restarted = PersistentDocsCache(tmp_path / "retrieved.sqlite3", index_path)
    fetched = restarted.get(
        version="3.12",
        slug="library/json.html",
        anchor=None,
        max_chars=500,
        start_index=0,
    )
    assert fetched == expected
    assert restarted.stats().hits == 1


def test_legacy_uncompressed_cache_row_migrates_and_reads(tmp_path: Path):
    """A cache file created without ``compression`` is migrated and stays readable."""
    index_path = tmp_path / "index.db"
    index_path.write_bytes(b"index-1")
    fingerprint = PersistentDocsCache._fingerprint_index(index_path)
    cache_path = tmp_path / "retrieved.sqlite3"
    expected = _result("legacy docs payload")

    # Table layout without the compression column.
    legacy_schema = (
        "CREATE TABLE retrieved_docs_cache ("
        "index_fingerprint TEXT NOT NULL, version TEXT NOT NULL, slug TEXT NOT NULL, "
        "anchor TEXT NOT NULL, max_chars INTEGER NOT NULL, start_index INTEGER NOT NULL, "
        "result_json TEXT NOT NULL, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, "
        "PRIMARY KEY (index_fingerprint, version, slug, anchor, max_chars, start_index))"
    )
    legacy_insert = (
        "INSERT INTO retrieved_docs_cache "
        "(index_fingerprint, version, slug, anchor, max_chars, start_index, result_json) "
        "VALUES (?, ?, ?, ?, ?, ?, ?)"
    )
    legacy_row = (
        fingerprint,
        expected.version,
        expected.slug,
        _NO_ANCHOR_KEY,
        500,
        0,
        expected.model_dump_json(),
    )
    with sqlite3.connect(cache_path) as conn:
        conn.execute(legacy_schema)
        conn.execute(legacy_insert, legacy_row)

    # Opening the cache must migrate the table and still serve the legacy row.
    migrated = PersistentDocsCache(cache_path, index_path)
    fetched = migrated.get(
        version="3.12",
        slug="library/json.html",
        anchor=None,
        max_chars=500,
        start_index=0,
    )
    assert fetched == expected

    with sqlite3.connect(cache_path) as conn:
        table_info = conn.execute("PRAGMA table_info(retrieved_docs_cache)")
        columns = {column[1] for column in table_info}
        compression = conn.execute("SELECT compression FROM retrieved_docs_cache").fetchone()[0]
    assert "compression" in columns
    assert compression == "none"


def test_cache_key_includes_python_version(populated_db, tmp_path: Path):
_doc(populated_db, "3.12", "docs for 3.12")
_doc(populated_db, "3.13", "docs for 3.13", 1)
Expand Down