Skip to content

Commit 356098f

Browse files
authored
feat: implement download and configure tiktoken cache (#15)
1 parent 88f420a commit 356098f

File tree

4 files changed

+181
-26
lines changed

4 files changed

+181
-26
lines changed

deepdoc/common/tiktoken_cache.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__ import annotations
2+
3+
import hashlib
4+
import os
5+
import uuid
6+
from pathlib import Path
7+
from urllib.request import urlopen
8+
9+
10+
DEEPDOC_TIKTOKEN_CACHE_DIR_ENV = "DEEPDOC_TIKTOKEN_CACHE_DIR"
11+
CL100K_BASE_BLOB_URL = "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
12+
CL100K_BASE_EXPECTED_HASH = "223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7"
13+
14+
15+
def resolve_tiktoken_cache_dir(cache_dir: str | None = None, model_home: str | None = None) -> Path:
16+
if cache_dir:
17+
return Path(cache_dir).expanduser().resolve()
18+
19+
explicit_cache_dir = os.getenv(DEEPDOC_TIKTOKEN_CACHE_DIR_ENV) or os.getenv("TIKTOKEN_CACHE_DIR")
20+
if explicit_cache_dir:
21+
return Path(explicit_cache_dir).expanduser().resolve()
22+
23+
configured_model_home = model_home or os.getenv("DEEPDOC_MODEL_HOME")
24+
if configured_model_home:
25+
return Path(configured_model_home).expanduser().resolve().joinpath("tiktoken_cache")
26+
27+
return Path.home().joinpath(".cache", "deepdoc", "tiktoken_cache").resolve()
28+
29+
30+
def configure_tiktoken_cache_env(cache_dir: str | None = None, model_home: str | None = None) -> str:
31+
resolved_cache_dir = resolve_tiktoken_cache_dir(cache_dir=cache_dir, model_home=model_home)
32+
os.environ["TIKTOKEN_CACHE_DIR"] = str(resolved_cache_dir)
33+
return str(resolved_cache_dir)
34+
35+
36+
def cl100k_base_cache_key(blob_url: str = CL100K_BASE_BLOB_URL) -> str:
37+
return hashlib.sha1(blob_url.encode()).hexdigest()
38+
39+
40+
def _matches_expected_hash(data: bytes, expected_hash: str | None) -> bool:
41+
if not expected_hash:
42+
return True
43+
return hashlib.sha256(data).hexdigest() == expected_hash
44+
45+
46+
def download_cl100k_base(
47+
*,
48+
cache_dir: str | None = None,
49+
model_home: str | None = None,
50+
offline: bool = False,
51+
blob_url: str = CL100K_BASE_BLOB_URL,
52+
expected_hash: str = CL100K_BASE_EXPECTED_HASH,
53+
timeout: int = 60,
54+
) -> Path:
55+
resolved_cache_dir = resolve_tiktoken_cache_dir(cache_dir=cache_dir, model_home=model_home)
56+
cache_key = cl100k_base_cache_key(blob_url)
57+
target_path = resolved_cache_dir.joinpath(cache_key)
58+
59+
if target_path.exists():
60+
data = target_path.read_bytes()
61+
if _matches_expected_hash(data, expected_hash):
62+
return target_path
63+
target_path.unlink()
64+
65+
if offline:
66+
raise FileNotFoundError(
67+
"Missing cached tiktoken encoder '{}'. Expected file at {}. Run the download command without --offline first."
68+
.format(cache_key, target_path)
69+
)
70+
71+
with urlopen(blob_url, timeout=timeout) as response:
72+
data = response.read()
73+
74+
if not _matches_expected_hash(data, expected_hash):
75+
raise ValueError(
76+
"Hash mismatch for tiktoken encoder downloaded from {}. Expected SHA256 {}."
77+
.format(blob_url, expected_hash)
78+
)
79+
80+
resolved_cache_dir.mkdir(parents=True, exist_ok=True)
81+
tmp_path = target_path.with_name("{}.{}.tmp".format(target_path.name, uuid.uuid4().hex))
82+
tmp_path.write_bytes(data)
83+
tmp_path.replace(target_path)
84+
return target_path

deepdoc/common/token_utils.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,26 @@
1414
# limitations under the License.
1515
#
1616

17-
import os
1817
import tiktoken
1918

20-
from .file_utils import get_project_base_directory
19+
from .tiktoken_cache import configure_tiktoken_cache_env
2120

22-
tiktoken_cache_dir = get_project_base_directory()
23-
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
24-
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
25-
encoder = tiktoken.get_encoding("cl100k_base")
21+
22+
_encoder = None
23+
24+
25+
def _get_encoder():
26+
global _encoder
27+
if _encoder is None:
28+
configure_tiktoken_cache_env()
29+
_encoder = tiktoken.get_encoding("cl100k_base")
30+
return _encoder
2631

2732

2833
def num_tokens_from_string(string: str) -> int:
2934
"""Returns the number of tokens in a text string."""
3035
try:
31-
code_list = encoder.encode(string)
36+
code_list = _get_encoder().encode(string)
3237
return len(code_list)
3338
except Exception:
3439
return 0
@@ -84,4 +89,8 @@ def total_token_count_from_response(resp):
8489

8590
def truncate(string: str, max_len: int) -> str:
8691
"""Returns truncated text if the length of text exceed max_len."""
87-
return encoder.decode(encoder.encode(string)[:max_len])
92+
try:
93+
encoder = _get_encoder()
94+
return encoder.decode(encoder.encode(string)[:max_len])
95+
except Exception:
96+
return string[:max_len]

deepdoc/depend/num_tokens_from_string.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

tests/test_tiktoken_cache.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import hashlib
2+
import os
3+
import tempfile
4+
import unittest
5+
from pathlib import Path
6+
from unittest.mock import patch
7+
8+
9+
from deepdoc.common import tiktoken_cache as tc
10+
11+
12+
class _FakeResponse:
13+
def __init__(self, data: bytes) -> None:
14+
self._data = data
15+
16+
def read(self) -> bytes:
17+
return self._data
18+
19+
def __enter__(self) -> "_FakeResponse":
20+
return self
21+
22+
def __exit__(self, exc_type, exc, tb) -> None:
23+
return None
24+
25+
26+
class TestTiktokenCache(unittest.TestCase):
27+
def setUp(self) -> None:
28+
self._old_env = os.environ.copy()
29+
30+
def tearDown(self) -> None:
31+
os.environ.clear()
32+
os.environ.update(self._old_env)
33+
34+
def test_resolve_tiktoken_cache_dir_uses_model_home(self) -> None:
35+
with tempfile.TemporaryDirectory() as tmp:
36+
os.environ.pop("TIKTOKEN_CACHE_DIR", None)
37+
os.environ.pop(tc.DEEPDOC_TIKTOKEN_CACHE_DIR_ENV, None)
38+
os.environ["DEEPDOC_MODEL_HOME"] = tmp
39+
40+
resolved = tc.resolve_tiktoken_cache_dir()
41+
42+
self.assertEqual(resolved, Path(tmp).resolve().joinpath("tiktoken_cache"))
43+
44+
def test_download_cl100k_base_writes_cache_key_named_file(self) -> None:
45+
payload = b"test-tiktoken-data"
46+
blob_url = "https://example.com/cl100k_base.tiktoken"
47+
expected_hash = hashlib.sha256(payload).hexdigest()
48+
49+
with tempfile.TemporaryDirectory() as tmp:
50+
with patch.object(tc, "urlopen", return_value=_FakeResponse(payload)) as mocked_urlopen:
51+
target = tc.download_cl100k_base(
52+
cache_dir=tmp,
53+
blob_url=blob_url,
54+
expected_hash=expected_hash,
55+
)
56+
57+
self.assertEqual(target, Path(tmp).resolve().joinpath(hashlib.sha1(blob_url.encode()).hexdigest()))
58+
self.assertEqual(target.read_bytes(), payload)
59+
mocked_urlopen.assert_called_once_with(blob_url, timeout=60)
60+
61+
def test_download_cl100k_base_offline_uses_existing_cache(self) -> None:
62+
payload = b"cached-tiktoken-data"
63+
blob_url = "https://example.com/cl100k_base.tiktoken"
64+
expected_hash = hashlib.sha256(payload).hexdigest()
65+
66+
with tempfile.TemporaryDirectory() as tmp:
67+
target = Path(tmp).resolve().joinpath(hashlib.sha1(blob_url.encode()).hexdigest())
68+
target.parent.mkdir(parents=True, exist_ok=True)
69+
target.write_bytes(payload)
70+
71+
with patch.object(tc, "urlopen") as mocked_urlopen:
72+
resolved = tc.download_cl100k_base(
73+
cache_dir=tmp,
74+
blob_url=blob_url,
75+
expected_hash=expected_hash,
76+
offline=True,
77+
)
78+
79+
self.assertEqual(resolved, target)
80+
mocked_urlopen.assert_not_called()

0 commit comments

Comments
 (0)