Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ python = [
]

[dependencies]
pyo3 = { version = "0.26.0", default-features = false, features = [
pyo3 = { version = "0.27.2", default-features = false, features = [
"extension-module",
"macros",
], optional = true }

# tiktoken dependencies
fancy-regex = "0.13.0"
fancy-regex = "0.17.0"
regex = "1.10.3"
rustc-hash = "2"
bstr = "1.5.0"
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
readme = "README.md"
license = { file = "LICENSE" }
authors = [{ name = "Shantanu Jain" }, { email = "shantanu@openai.com" }]
dependencies = ["regex>=2022.1.18", "requests>=2.26.0"]
optional-dependencies = { blobfile = ["blobfile>=2"] }
dependencies = ["regex", "requests"]
optional-dependencies = { blobfile = ["blobfile>=3"] }
requires-python = ">=3.9"

[project.urls]
Expand Down
142 changes: 135 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::borrow::Borrow;
use std::borrow::Cow;
use std::collections::HashSet;
use std::num::NonZeroU64;
use std::thread;
Expand All @@ -14,6 +12,131 @@ mod py;

pub type Rank = u32;

use std::collections::BinaryHeap;

/// A candidate merge queued in the `BinaryHeap` used by `_byte_pair_merge_large`.
///
/// Each entry names the byte offset where the left-hand token of a pair begins and the
/// rank that the rank table assigned to the merged byte sequence at the time the entry
/// was pushed.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
struct Merge {
    /// Byte offset in the piece at which the left-hand token of the pair starts.
    start: usize,
    /// Rank of the byte sequence that results from performing this merge.
    rank: Rank,
}

impl Ord for Merge {
    /// Reversed ordering: the merge with the *smallest* rank compares as the greatest,
    /// and among equal ranks the one with the *smallest* `start` compares as the greatest.
    ///
    /// `std::collections::BinaryHeap` is a max-heap, so with this ordering it pops the
    /// lowest-ranked, left-most merge first.
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // The operands are deliberately swapped (`other` on the left) to invert the
        // natural ascending order of `(rank, start)`.
        let theirs = (other.rank, other.start);
        let ours = (self.rank, self.start);
        theirs.cmp(&ours)
    }
}

impl PartialOrd for Merge {
    /// Always returns `Some`, deferring to the total order defined by the `Ord` impl so
    /// the two comparisons can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Per-byte-offset bookkeeping for `_byte_pair_merge_large`.
///
/// One `State` exists for every byte offset of the piece; the entry at index `i`
/// describes the token that currently starts at offset `i`. Entries whose token has been
/// absorbed into a left neighbour are no longer visited.
#[derive(Debug)]
struct State {
    /// Start offset of the token immediately to the left (`usize::MAX` for the first token).
    prev: usize,
    /// Exclusive end offset of the token starting here; also the start of the next token.
    end: usize,
    /// Exclusive end offset of the token that follows this one.
    next_end: usize,
    /// Rank of the bytes from this offset up to `next_end` (this token merged with the
    /// next one), or `Rank::MAX` when no such merge is currently available.
    next_rank: Rank,
    /// Rank of this token once it has been produced by a merge; `Rank::MAX` while it is
    /// still a single original byte.
    cur_rank: Rank,
}

/// Byte-pair encodes `piece`, choosing the next merge from a priority queue.
///
/// Candidate merges of adjacent tokens are kept in a `BinaryHeap<Merge>` (lowest rank,
/// then left-most, pops first) while a `Vec<State>` indexed by byte offset tracks the
/// current token boundaries. Heap entries are never removed eagerly; an entry is ignored
/// when popped if its rank no longer matches `State::next_rank` for its start offset.
///
/// # Arguments
/// * `ranks` - maps byte sequences to their rank.
/// * `piece` - the bytes to encode.
///
/// # Returns
/// The ranks of the resulting tokens in left-to-right order. An empty `piece` yields an
/// empty vector.
///
/// # Panics
/// Panics if a byte of `piece` that is never merged has no single-byte entry in `ranks`.
fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
    /// Records that the token following the one at `start` now ends at `next_end_item`
    /// and, if the combined bytes have a rank, queues that merge.
    #[inline(always)]
    fn potential_merge(
        ranks: &HashMap<Vec<u8>, Rank>,
        piece: &[u8],
        state: &mut [State],
        heap: &mut BinaryHeap<Merge>,
        start: usize,
        next_end_item: usize,
    ) {
        state[start].next_end = next_end_item;
        state[start].next_rank = Rank::MAX; // Always invalidate the old merge
        // `get` returns `None` when `next_end_item` lies past the end of `piece`, i.e.
        // when the token at `start` has no right neighbour.
        if let Some(&rank) = piece
            .get(start..next_end_item)
            .and_then(|bytes| ranks.get(bytes))
        {
            // We have a valid potential merge!
            heap.push(Merge { start, rank });
            state[start].next_rank = rank;
        }
    }

    // Nothing to encode; this also avoids computing `piece.len() - 1` on an empty slice.
    if piece.is_empty() {
        return Vec::new();
    }

    let mut state = Vec::with_capacity(piece.len());
    state.push(State {
        prev: usize::MAX,
        end: 1,
        next_end: 2,
        next_rank: Rank::MAX,
        cur_rank: Rank::MAX,
    });

    let mut heap = BinaryHeap::with_capacity(piece.len());
    for (i, pair) in piece.windows(2).enumerate() {
        if let Some(&rank) = ranks.get(pair) {
            heap.push(Merge { start: i, rank });
            state[i].next_rank = rank;
        }
        // Offset by one: iteration `i` creates the entry for the token starting at `i + 1`
        // (the entry for offset 0 was pushed before the loop).
        state.push(State {
            prev: i,
            end: i + 2,
            next_end: i + 3,
            next_rank: Rank::MAX,
            cur_rank: Rank::MAX,
        });
    }

    // Repeatedly take the valid merge with the smallest rank. It joins the left token,
    // spanning `start..state[start].end`, with the right token, spanning
    // `state[start].end..state[start].next_end`. Afterwards the candidate merges that
    // involved either old token are stale: the merged token and its left neighbour get
    // their candidates recomputed, and the absorbed right token's candidate is invalidated.
    while let Some(left) = heap.pop() {
        if left.rank == Rank::MAX {
            break;
        }
        if left.rank != state[left.start].next_rank {
            continue; // This merge was invalidated, ignore it
        }

        let left_start = left.start;
        let right_start = state[left_start].end;
        let right_end = state[left_start].next_end;
        debug_assert!(right_end == state[right_start].end);
        let right_next_end = state[right_start].next_end;

        // Merge left and right into a single token
        state[left_start].cur_rank = state[left_start].next_rank;
        state[left_start].end = right_end;
        potential_merge(ranks, piece, &mut state, &mut heap, left_start, right_next_end);
        if right_end < state.len() {
            // The token after the merged one now has the merged token as its left neighbour.
            state[right_end].prev = left_start;
        }
        // Update the merge that ends at left_start
        if left_start > 0 {
            let prev_start = state[left_start].prev;
            potential_merge(ranks, piece, &mut state, &mut heap, prev_start, right_end);
        }
        // Invalidate the merge starting at right_start, so we ignore it when it comes off the heap
        state[right_start].next_rank = Rank::MAX;
    }

    // Walk the surviving tokens left to right by hopping from each start to its end.
    let mut result = Vec::new();
    let mut i = 0;
    while i < state.len() {
        if state[i].cur_rank != Rank::MAX {
            result.push(state[i].cur_rank);
        } else {
            // Never merged: still a single byte, so look its rank up directly.
            result.push(ranks[&piece[i..state[i].end]]);
        }
        i = state[i].end;
    }
    result
}

fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
// This is a vector of (start, rank).
// The rank is of the pair starting at position start.
Expand Down Expand Up @@ -73,13 +196,18 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
}

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
if piece.len() == 1 {
let piece_len = piece.len();

if piece_len == 1 {
return vec![ranks[piece]];
}
_byte_pair_merge(ranks, piece)
.windows(2)
.map(|part| ranks[&piece[part[0].0..part[1].0]])
.collect()
if piece_len < 100 {
return _byte_pair_merge(ranks, piece)
.windows(2)
.map(|part| ranks[&piece[part[0].0..part[1].0]])
.collect();
}
_byte_pair_merge_large(ranks, piece)
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
Expand Down
5 changes: 3 additions & 2 deletions tests/test_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ def test_simple_repeated():
def test_large_repeated():
enc = tiktoken.get_encoding("o200k_base")

with pytest.raises(ValueError):
enc.encode("x" * 1_000_000)
# Large inputs should be handled without raising.
tokens = enc.encode("x" * 1_000_000)
assert tokens


def test_simple_regex():
Expand Down
3 changes: 1 addition & 2 deletions tiktoken/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ def read_file(blobpath: str) -> bytes:
raise ImportError(
"blobfile is not installed. Please install it by running `pip install blobfile`."
) from e
with blobfile.BlobFile(blobpath, "rb") as f:
return f.read()
return blobfile.read_bytes(blobpath)


def check_hash(data: bytes, expected_hash: str) -> bool:
Expand Down
Loading