|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import hashlib |
| 4 | +import os |
| 5 | +import uuid |
| 6 | +from pathlib import Path |
| 7 | +from urllib.request import urlopen |
| 8 | + |
| 9 | + |
# Name of the environment variable that is consulted first when no explicit
# cache directory argument is supplied to the resolver below.
DEEPDOC_TIKTOKEN_CACHE_DIR_ENV = "DEEPDOC_TIKTOKEN_CACHE_DIR"

# Location the cl100k_base encoder file is downloaded from, and the SHA-256
# hex digest the downloaded (or cached) bytes are checked against.
CL100K_BASE_BLOB_URL = "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
CL100K_BASE_EXPECTED_HASH = "223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7"
| 13 | + |
| 14 | + |
def resolve_tiktoken_cache_dir(cache_dir: str | None = None, model_home: str | None = None) -> Path:
    """Work out which directory should hold cached tiktoken encoder files.

    The first non-empty candidate wins, in this order:

    1. the ``cache_dir`` argument;
    2. the environment variable named by ``DEEPDOC_TIKTOKEN_CACHE_DIR_ENV``;
    3. the ``TIKTOKEN_CACHE_DIR`` environment variable;
    4. ``<model home>/tiktoken_cache``, where the model home is the
       ``model_home`` argument or, failing that, ``DEEPDOC_MODEL_HOME``;
    5. ``~/.cache/deepdoc/tiktoken_cache``.

    Args:
        cache_dir: Explicit cache directory; overrides everything else.
        model_home: Model root directory under which a ``tiktoken_cache``
            sub-directory is used.

    Returns:
        The chosen directory with ``~`` expanded and the path resolved.
        The directory is not created here.
    """
    # An explicit directory (argument first, then the two env vars) is used as-is.
    explicit = cache_dir or os.getenv(DEEPDOC_TIKTOKEN_CACHE_DIR_ENV) or os.getenv("TIKTOKEN_CACHE_DIR")
    if explicit:
        return Path(explicit).expanduser().resolve()

    # Otherwise nest the cache under the model home when one is configured.
    home_root = model_home or os.getenv("DEEPDOC_MODEL_HOME")
    if home_root:
        return Path(home_root).expanduser().resolve() / "tiktoken_cache"

    # Last resort: a per-user cache directory.
    return (Path.home() / ".cache" / "deepdoc" / "tiktoken_cache").resolve()
| 28 | + |
| 29 | + |
def configure_tiktoken_cache_env(cache_dir: str | None = None, model_home: str | None = None) -> str:
    """Export the resolved cache directory through ``TIKTOKEN_CACHE_DIR``.

    The directory is chosen by ``resolve_tiktoken_cache_dir`` and written into
    this process's ``os.environ`` (presumably so that tiktoken itself picks it
    up -- confirm against the caller).

    Args:
        cache_dir: Explicit cache directory, forwarded to the resolver.
        model_home: Model root directory, forwarded to the resolver.

    Returns:
        The resolved cache directory as a string, i.e. the value just stored
        in ``TIKTOKEN_CACHE_DIR``.
    """
    cache_dir_text = str(resolve_tiktoken_cache_dir(cache_dir=cache_dir, model_home=model_home))
    os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir_text
    return cache_dir_text
| 34 | + |
| 35 | + |
def cl100k_base_cache_key(blob_url: str = CL100K_BASE_BLOB_URL) -> str:
    """Derive the cache file name for an encoder blob URL.

    Args:
        blob_url: URL of the encoder file.

    Returns:
        The hexadecimal SHA-1 digest of the UTF-8 encoded URL. It is used as
        the file name inside the cache directory (presumably mirroring
        tiktoken's own cache naming -- confirm).
    """
    url_digest = hashlib.sha1()
    url_digest.update(blob_url.encode("utf-8"))
    return url_digest.hexdigest()
| 38 | + |
| 39 | + |
def _matches_expected_hash(data: bytes, expected_hash: str | None) -> bool:
    """Return whether ``data`` has the expected SHA-256 digest.

    Args:
        data: Raw bytes to verify.
        expected_hash: Hexadecimal SHA-256 digest to compare against. ``None``
            or an empty string disables verification.

    Returns:
        ``True`` when verification is disabled or the digests agree,
        ``False`` otherwise.
    """
    if not expected_hash:
        return True
    # hexdigest() always yields lowercase hex, so normalise the caller's value:
    # an uppercase digest or one with a trailing newline must not make valid
    # data look corrupt.
    return hashlib.sha256(data).hexdigest() == expected_hash.strip().lower()
| 44 | + |
| 45 | + |
def download_cl100k_base(
    *,
    cache_dir: str | None = None,
    model_home: str | None = None,
    offline: bool = False,
    blob_url: str = CL100K_BASE_BLOB_URL,
    expected_hash: str = CL100K_BASE_EXPECTED_HASH,
    timeout: int = 60,
) -> Path:
    """Ensure the cl100k_base encoder file is present in the cache directory.

    A cached copy whose SHA-256 digest matches ``expected_hash`` is reused.
    A cached copy that fails verification is deleted. Unless ``offline`` is
    set, the file is then downloaded from ``blob_url``, verified, and moved
    into place under the name returned by ``cl100k_base_cache_key``.

    Args:
        cache_dir: Explicit cache directory (see ``resolve_tiktoken_cache_dir``).
        model_home: Model root directory used to derive the cache directory.
        offline: When true, never touch the network; a missing or invalid
            cache entry is an error.
        blob_url: URL to download the encoder file from.
        expected_hash: Hexadecimal SHA-256 digest the file must have. A falsy
            value disables verification.
        timeout: Timeout in seconds passed to ``urlopen``.

    Returns:
        Path of the verified encoder file inside the cache directory.

    Raises:
        FileNotFoundError: ``offline`` is set and no valid cached file exists.
        ValueError: The downloaded bytes do not match ``expected_hash``.

    Errors raised by ``urlopen`` or by the filesystem calls are propagated
    unchanged.
    """
    resolved_cache_dir = resolve_tiktoken_cache_dir(cache_dir=cache_dir, model_home=model_home)
    cache_key = cl100k_base_cache_key(blob_url)
    target_path = resolved_cache_dir.joinpath(cache_key)

    # Read directly instead of exists()-then-read: another process sharing the
    # cache may remove or replace the file between the check and the read.
    try:
        cached = target_path.read_bytes()
    except (FileNotFoundError, NotADirectoryError):
        cached = None

    if cached is not None:
        if _matches_expected_hash(cached, expected_hash):
            return target_path
        # The cached copy is corrupt or stale; drop it. It may already have
        # been removed by a concurrent caller, which is fine.
        try:
            target_path.unlink()
        except FileNotFoundError:
            pass

    if offline:
        raise FileNotFoundError(
            "Missing cached tiktoken encoder '{}'. Expected file at {}. Run the download command without --offline first."
            .format(cache_key, target_path)
        )

    with urlopen(blob_url, timeout=timeout) as response:
        data = response.read()

    # Verify before anything is written so a bad download never reaches disk.
    if not _matches_expected_hash(data, expected_hash):
        raise ValueError(
            "Hash mismatch for tiktoken encoder downloaded from {}. Expected SHA256 {}."
            .format(blob_url, expected_hash)
        )

    resolved_cache_dir.mkdir(parents=True, exist_ok=True)
    # Write under a unique temporary name in the same directory, then rename it
    # over the final name in one step, so concurrent writers do not collide.
    tmp_path = target_path.with_name("{}.{}.tmp".format(target_path.name, uuid.uuid4().hex))
    try:
        tmp_path.write_bytes(data)
        tmp_path.replace(target_path)
    except BaseException:
        # Do not leave a stray temp file behind; cleanup is best-effort so it
        # never masks the original error.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
    return target_path
0 commit comments