Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# ⏳ tiktoken

# **Tiktoken fork with tests, bug fixes, CPU performance enhancements, and more.**

tiktoken is a fast [BPE](https://en.wikipedia.org/wiki/Byte_pair_encoding) tokeniser for use with
OpenAI's models.

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ impl CoreBPE {

/// Decodes tokens into a list of bytes.
///
/// The bytes are not guaranteed to be a valid utf-8 string.
fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
let mut ret = Vec::with_capacity(tokens.len() * 2);
for &token in tokens {
Expand Down
30 changes: 30 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,33 @@ def test_optional_blobfile_dependency():
assert "blobfile" not in sys.modules
"""
subprocess.check_call([sys.executable, "-c", prog])


def test_is_special_token():
    """End-of-text is reported as special; an ordinary mergeable token is not."""
    encoding = tiktoken.get_encoding("gpt2")
    # The end-of-text token is registered as a special token.
    assert encoding.is_special_token(encoding.eot_token) is True
    # Token id 0 comes from the mergeable ranks, so it must not be special.
    assert encoding.is_special_token(0) is False


def test_max_threads_default():
    """_MAX_THREADS defaults to the CPU count (fallback 8), capped at 32."""
    import os

    from tiktoken.core import _MAX_THREADS

    expected = min(os.cpu_count() or 8, 32)
    assert _MAX_THREADS == expected
    # Sanity: the default thread count is always at least one worker.
    assert _MAX_THREADS >= 1


def test_list_encoding_names_optimized():
    """Test that list_encoding_names works even with assertions disabled (python -O)."""
    # Fix: a GitHub review marker ("Comment on lines +57 to +58") had been
    # pasted into the middle of this program string; the child interpreter
    # would have raised a SyntaxError on it. The program must contain only
    # valid Python.
    prog = """
import tiktoken
names = tiktoken.list_encoding_names()
assert len(names) > 0
assert "gpt2" in names
"""
    # -O strips `assert` statements; this guards against list_encoding_names
    # relying on an assert for its control flow.
    subprocess.check_call([sys.executable, "-O", "-c", prog])
29 changes: 19 additions & 10 deletions tiktoken/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
import os
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence

Expand All @@ -12,6 +13,9 @@
import numpy as np
import numpy.typing as npt

# Default thread count for batch operations, based on available CPU cores
_MAX_THREADS = min(os.cpu_count() or 8, 32)


class Encoding:
def __init__(
Expand Down Expand Up @@ -155,10 +159,14 @@ def encode_to_numpy(

import numpy as np

buffer = self._core_bpe.encode_to_tiktoken_buffer(text, allowed_special)
try:
buffer = self._core_bpe.encode_to_tiktoken_buffer(text, allowed_special)
except UnicodeEncodeError:
text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
buffer = self._core_bpe.encode_to_tiktoken_buffer(text, allowed_special)
return np.frombuffer(buffer, dtype=np.uint32)
Comment on lines +162 to 167

def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]:
def encode_ordinary_batch(self, text: list[str], *, num_threads: int = _MAX_THREADS) -> list[list[int]]:
"""Encodes a list of strings into tokens, in parallel, ignoring special tokens.

This is equivalent to `encode_batch(text, disallowed_special=())` (but slightly faster).
Expand All @@ -176,7 +184,7 @@ def encode_batch(
self,
text: list[str],
*,
num_threads: int = 8,
num_threads: int = _MAX_THREADS,
allowed_special: Literal["all"] | AbstractSet[str] = set(), # noqa: B006
disallowed_special: Literal["all"] | Collection[str] = "all",
) -> list[list[int]]:
Expand All @@ -193,8 +201,9 @@ def encode_batch(
allowed_special = self.special_tokens_set
if disallowed_special == "all":
disallowed_special = self.special_tokens_set - allowed_special
if not isinstance(disallowed_special, frozenset):
disallowed_special = frozenset(disallowed_special)
if disallowed_special:
if not isinstance(disallowed_special, frozenset):
disallowed_special = frozenset(disallowed_special)

encoder = functools.partial(
self.encode, allowed_special=allowed_special, disallowed_special=disallowed_special
Expand Down Expand Up @@ -332,15 +341,15 @@ def decode_with_offsets(self, tokens: Sequence[int]) -> tuple[str, list[int]]:
return text, offsets

def decode_batch(
    self, batch: Sequence[Sequence[int]], *, errors: str = "replace", num_threads: int = _MAX_THREADS
) -> list[str]:
    """Decodes a batch (list of lists of tokens) into a list of strings.

    Fix: the pasted diff left both the old (``num_threads: int = 8``) and new
    signature lines in place; only the new ``_MAX_THREADS``-defaulted
    signature is kept. ``errors`` is forwarded to ``self.decode`` for
    handling invalid utf-8.
    """
    decoder = functools.partial(self.decode, errors=errors)
    # Results come back in input order because Executor.map preserves ordering.
    with ThreadPoolExecutor(num_threads) as e:
        return list(e.map(decoder, batch))

def decode_bytes_batch(
self, batch: Sequence[Sequence[int]], *, num_threads: int = 8
self, batch: Sequence[Sequence[int]], *, num_threads: int = _MAX_THREADS
) -> list[bytes]:
"""Decodes a batch (list of lists of tokens) into a list of bytes."""
with ThreadPoolExecutor(num_threads) as e:
Expand All @@ -364,7 +373,7 @@ def special_tokens_set(self) -> set[str]:

def is_special_token(self, token: int) -> bool:
    """Return True if *token* is the id of one of this encoding's special tokens.

    Fix: the pasted diff left both the old return line (referencing a
    ``_special_token_values`` attribute) and the new one in place; only the
    new version, which checks against ``self._special_tokens.values()``
    (surface form -> token id mapping), is kept. A stray review marker line
    was also removed.
    """
    assert isinstance(token, int)
    return token in self._special_tokens.values()

@property
def n_vocab(self) -> int:
def _encode_only_native_bpe(self, text: str) -> list[int]:
    """Encode *text* using only the regex split + native BPE (no special tokens).

    Fix: the pasted diff left both the old lines (binding the pattern to a
    variable named ``_unused_pat`` while still using it) and the renamed
    ``pat`` lines in place; only the new version is kept.
    """
    # We need specifically `regex` in order to compile pat_str due to e.g. \p
    import regex

    pat = regex.compile(self._pat_str)
    ret = []
    for piece in regex.findall(pat, text):
        ret.extend(self._core_bpe.encode_single_piece(piece.encode("utf-8")))
    return ret
Expand Down
17 changes: 11 additions & 6 deletions tiktoken/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def _find_constructors() -> None:
raise


def _ensure_constructors_loaded() -> None:
    """Run plugin discovery once, raising if no encoding constructors turn up.

    Raises RuntimeError (not AssertionError) so the check survives ``python -O``.
    """
    if ENCODING_CONSTRUCTORS is not None:
        return
    _find_constructors()
    if ENCODING_CONSTRUCTORS is None:
        raise RuntimeError(
            "Could not find any encoding constructors. "
            f"Plugins found: {_available_plugin_modules()}\n"
            f"tiktoken version: {tiktoken.__version__}"
        )


def get_encoding(encoding_name: str) -> Encoding:
Expand All @@ -71,9 +80,7 @@ def get_encoding(encoding_name: str) -> Encoding:
if encoding_name in ENCODINGS:
return ENCODINGS[encoding_name]

if ENCODING_CONSTRUCTORS is None:
_find_constructors()
assert ENCODING_CONSTRUCTORS is not None
_ensure_constructors_loaded()

if encoding_name not in ENCODING_CONSTRUCTORS:
raise ValueError(
Expand All @@ -90,7 +97,5 @@ def get_encoding(encoding_name: str) -> Encoding:

def list_encoding_names() -> list[str]:
    """Return the names of all registered encodings.

    Fix: the pasted diff left both the old assert-based loading lines and
    the new ``_ensure_constructors_loaded()`` call in place; only the new
    version is kept, which raises RuntimeError instead of relying on an
    ``assert`` (so it still works under ``python -O``).
    """
    with _lock:
        _ensure_constructors_loaded()
        return list(ENCODING_CONSTRUCTORS)