diff --git a/README.md b/README.md index 4f36c537..5f2664e1 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # ⏳ tiktoken +# **Tiktoken Fork for Test and Bug Fixes, CPU Performance Enhancement and More.** + tiktoken is a fast [BPE](https://en.wikipedia.org/wiki/Byte_pair_encoding) tokeniser for use with OpenAI's models. diff --git a/src/lib.rs b/src/lib.rs index ea54eac8..b6b8ea28 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -341,7 +341,7 @@ impl CoreBPE { /// Decodes tokens into a list of bytes. /// - /// The bytes are not gauranteed to be a valid utf-8 string. + /// The bytes are not guaranteed to be a valid utf-8 string. fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> { let mut ret = Vec::with_capacity(tokens.len() * 2); for &token in tokens { diff --git a/tests/test_misc.py b/tests/test_misc.py index 0832c8ee..fca6b3fc 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -28,3 +28,33 @@ def test_optional_blobfile_dependency(): assert "blobfile" not in sys.modules """ subprocess.check_call([sys.executable, "-c", prog]) + + +def test_is_special_token(): + enc = tiktoken.get_encoding("gpt2") + eot_token = enc.eot_token + # The eot_token should be identified as a special token + assert enc.is_special_token(eot_token) is True + # Token 0 is a regular mergeable token, not special + assert enc.is_special_token(0) is False + + +def test_max_threads_default(): + import os + + from tiktoken.core import _MAX_THREADS + + cpu_count = os.cpu_count() or 8 + assert _MAX_THREADS == min(cpu_count, 32) + assert _MAX_THREADS >= 1 + + +def test_list_encoding_names_optimized(): + """Test that list_encoding_names works even with assertions disabled (python -O).""" + prog = """ +import tiktoken +names = tiktoken.list_encoding_names() +assert len(names) > 0 +assert "gpt2" in names +""" + subprocess.check_call([sys.executable, "-O", "-c", prog]) diff --git a/tiktoken/core.py b/tiktoken/core.py index 225fffb8..05103099 100644 --- a/tiktoken/core.py +++ 
b/tiktoken/core.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +import os from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence @@ -12,6 +13,9 @@ import numpy as np import numpy.typing as npt +# Default thread count for batch operations, based on available CPU cores +_MAX_THREADS = min(os.cpu_count() or 8, 32) + class Encoding: def __init__( @@ -155,10 +159,14 @@ def encode_to_numpy( import numpy as np - buffer = self._core_bpe.encode_to_tiktoken_buffer(text, allowed_special) + try: + buffer = self._core_bpe.encode_to_tiktoken_buffer(text, allowed_special) + except UnicodeEncodeError: + text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace") + buffer = self._core_bpe.encode_to_tiktoken_buffer(text, allowed_special) return np.frombuffer(buffer, dtype=np.uint32) - def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]: + def encode_ordinary_batch(self, text: list[str], *, num_threads: int = _MAX_THREADS) -> list[list[int]]: """Encodes a list of strings into tokens, in parallel, ignoring special tokens. This is equivalent to `encode_batch(text, disallowed_special=())` (but slightly faster). 
@@ -176,7 +184,7 @@ def encode_batch( self, text: list[str], *, - num_threads: int = 8, + num_threads: int = _MAX_THREADS, allowed_special: Literal["all"] | AbstractSet[str] = set(), # noqa: B006 disallowed_special: Literal["all"] | Collection[str] = "all", ) -> list[list[int]]: @@ -193,8 +201,9 @@ def encode_batch( allowed_special = self.special_tokens_set if disallowed_special == "all": disallowed_special = self.special_tokens_set - allowed_special - if not isinstance(disallowed_special, frozenset): - disallowed_special = frozenset(disallowed_special) + if disallowed_special: + if not isinstance(disallowed_special, frozenset): + disallowed_special = frozenset(disallowed_special) encoder = functools.partial( self.encode, allowed_special=allowed_special, disallowed_special=disallowed_special @@ -332,7 +341,7 @@ def decode_with_offsets(self, tokens: Sequence[int]) -> tuple[str, list[int]]: return text, offsets def decode_batch( - self, batch: Sequence[Sequence[int]], *, errors: str = "replace", num_threads: int = 8 + self, batch: Sequence[Sequence[int]], *, errors: str = "replace", num_threads: int = _MAX_THREADS ) -> list[str]: """Decodes a batch (list of lists of tokens) into a list of strings.""" decoder = functools.partial(self.decode, errors=errors) @@ -340,7 +349,7 @@ def decode_batch( return list(e.map(decoder, batch)) def decode_bytes_batch( - self, batch: Sequence[Sequence[int]], *, num_threads: int = 8 + self, batch: Sequence[Sequence[int]], *, num_threads: int = _MAX_THREADS ) -> list[bytes]: """Decodes a batch (list of lists of tokens) into a list of bytes.""" with ThreadPoolExecutor(num_threads) as e: @@ -364,7 +373,7 @@ def special_tokens_set(self) -> set[str]: def is_special_token(self, token: int) -> bool: assert isinstance(token, int) - return token in self._special_token_values + return token in self._special_tokens.values() @property def n_vocab(self) -> int: @@ -394,9 +403,9 @@ def _encode_only_native_bpe(self, text: str) -> list[int]: # We need 
specifically `regex` in order to compile pat_str due to e.g. \p import regex - _unused_pat = regex.compile(self._pat_str) + pat = regex.compile(self._pat_str) ret = [] - for piece in regex.findall(_unused_pat, text): + for piece in regex.findall(pat, text): ret.extend(self._core_bpe.encode_single_piece(piece.encode("utf-8"))) return ret diff --git a/tiktoken/registry.py b/tiktoken/registry.py index 17c4574f..7e30477b 100644 --- a/tiktoken/registry.py +++ b/tiktoken/registry.py @@ -58,6 +58,15 @@ def _find_constructors() -> None: raise +def _ensure_constructors_loaded() -> None: + if ENCODING_CONSTRUCTORS is None: + _find_constructors() + if ENCODING_CONSTRUCTORS is None: + raise RuntimeError( + "Could not find any encoding constructors. " + f"Plugins found: {_available_plugin_modules()}\n" + f"tiktoken version: {tiktoken.__version__}" + ) def get_encoding(encoding_name: str) -> Encoding: @@ -71,9 +80,7 @@ def get_encoding(encoding_name: str) -> Encoding: if encoding_name in ENCODINGS: return ENCODINGS[encoding_name] - if ENCODING_CONSTRUCTORS is None: - _find_constructors() - assert ENCODING_CONSTRUCTORS is not None + _ensure_constructors_loaded() if encoding_name not in ENCODING_CONSTRUCTORS: raise ValueError( @@ -90,7 +97,5 @@ def get_encoding(encoding_name: str) -> Encoding: def list_encoding_names() -> list[str]: with _lock: - if ENCODING_CONSTRUCTORS is None: - _find_constructors() - assert ENCODING_CONSTRUCTORS is not None + _ensure_constructors_loaded() return list(ENCODING_CONSTRUCTORS)