Table of Contents
- Let's visualize it!
- Byte Representation and Character Encoding
- BPE: Byte Pair Encoding
- Implement with PyTorch
This post collects the knowledge nuggets I gathered while learning minbpe. A big thanks to Andrej Karpathy for his super clean and elegant code; sometimes a single line of code is worth a thousand words.
I also received a lot of help from both ChatGPT and Claude in figuring things out, and they deserve credit for their invaluable assistance.
Let's visualize it!
I've created a D3.js visualization (the cover image) to showcase the vocabulary and merges trained by the BPE algorithm here.
Byte Representation and Character Encoding
Before diving into minbpe, let's first go through the basics of bytes and character encoding. Bytes are the fundamental units of digital information, consisting of 8 bits and capable of representing 256 distinct values (ranging from 0 to 255). In character encoding, each byte can represent a single character, particularly within ASCII or extended ASCII encoding frameworks.
ASCII and Extended ASCII
ASCII (American Standard Code for Information Interchange) is a character encoding standard that uses 7 bits to represent characters. This means it can represent 128 characters (from 0 to 127), including:
Category | Example(s) |
---|---|
Control Characters | null (\x00), newline (\x0A) |
Digits | 48 for '0' |
Uppercase Letters | 65 for 'A' |
Lowercase Letters | 97 for 'a' |
Punctuation Marks | 33 for '!', 46 for '.', 63 for '?' |
Extended ASCII uses 8 bits (1 byte), allowing for 256 characters (from 0 to 255). Characters from 128 to 255 include various additional characters, such as accented letters, symbols, and graphical characters. There is no single "Extended ASCII" standard: the upper 128 values differ from one code page to another. The examples below follow code page 437.
Decimal | Hex | Character | Description |
---|---|---|---|
128 | 0x80 | Ç | Latin capital letter C with cedilla |
129 | 0x81 | ü | Latin small letter u with diaeresis |
130 | 0x82 | é | Latin small letter e with acute |
131 | 0x83 | â | Latin small letter a with circumflex |
Bytes and Unicode in Python
In Python, the `bytes` type is used to represent a sequence of bytes. When we create a bytes object with `bytes([i])`, we're creating a bytes object whose single byte has the integer value `i`. For example:
- `bytes([65])` creates the bytes object `b'A'`, where `65` is the ASCII code for the character `'A'`.
- `bytes([65, 66])` creates the bytes object `b'AB'`, where `66` is the ASCII code for the character `'B'`.
When we iterate over a bytes object in Python, each element we get is actually an integer representing the byte value, not a single-byte bytes object. This behavior might seem counterintuitive at first, but it's by design. Let's break it down:
- A bytes object is an immutable sequence of integers in the range 0 to 255.
- When we iterate over a bytes object, Python yields these integer values one by one.
- These integers represent the ASCII value (or more generally, the byte value) of each byte in the sequence.
Here's an example to illustrate this:
text_bytes = b"Hello"
for b in text_bytes:
    print(f"{b} - {type(b)}")
This will output:
72 - <class 'int'>
101 - <class 'int'>
108 - <class 'int'>
108 - <class 'int'>
111 - <class 'int'>
However, this behavior is specific to bytes objects. When we iterate over a string (str object), we get individual characters (which are also strings of length 1), not their ASCII values.
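The difference between indexing, slicing, and iterating can be checked directly. The snippet below is a small sketch using only built-ins; the variable names are illustrative.

```python
data = b"Hi"

# Indexing or iterating a bytes object yields ints...
first = data[0]
as_ints = list(data)

# ...while slicing keeps the bytes type.
first_slice = data[0:1]

# bytes([i]) turns a single int back into a one-byte bytes object.
rebuilt = bytes([first])

# Iterating a str yields one-character strings; ord() gives the code point.
text = "Hi"
chars = list(text)
code_points = [ord(ch) for ch in text]

print(first, as_ints, first_slice, rebuilt, chars, code_points)
```

This is why minbpe-style code writes `bytes([b])` when it needs a one-byte token instead of using `b` directly.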
UTF-8 Encoding and Decoding
When we use `str.encode("utf-8")` in Python, we're converting a string (which is a sequence of characters) into a bytes object (which is a sequence of bytes) using the `"utf-8"` encoding. Each character in the string is translated into one or more bytes.
UTF-8 (Unicode Transformation Format - 8-bit) is a variable-length character encoding for Unicode. It can encode each of the 1,112,064 valid character code points in Unicode using one to four 8-bit bytes.
Representation Type | Unicode Range | Example Character | Unicode Code Point | Byte Representation |
---|---|---|---|---|
Single-byte (ASCII) | U+0000 to U+007F | 'A' | U+0041 | 01000001 |
Two-byte | U+0080 to U+07FF | 'é' | U+00E9 | 11000011 10101001 |
Three-byte | U+0800 to U+FFFF | 'ह' | U+0939 | 11100000 10100100 10111001 |
Four-byte | U+10000 to U+10FFFF | '𐍈' | U+10348 | 11110000 10010000 10001101 10001000 |
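The byte patterns for these four example characters can be reproduced with a few lines of Python. This is a minimal sketch; `utf8_bits` is a helper name made up for this illustration.

```python
def utf8_bits(ch: str) -> str:
    """Return the UTF-8 bytes of `ch` as space-separated 8-bit binary strings."""
    return " ".join(f"{byte:08b}" for byte in ch.encode("utf-8"))

# One example character per UTF-8 length (1 to 4 bytes), written as escapes.
examples = ["A", "\u00e9", "\u0939", "\U00010348"]

for ch in examples:
    encoded = ch.encode("utf-8")
    print(f"U+{ord(ch):04X}  {len(encoded)} byte(s)  {utf8_bits(ch)}")
```

The leading bits of the first byte (`0`, `110`, `1110`, `11110`) announce how many bytes the character uses, and every continuation byte starts with `10`.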
When we use bytes.decode("utf-8", errors="replace")
in Python, we attempt to decode bytes
as UTF-8
. If there are any invalid byte sequences (which could occur if the bytes do not form valid UTF-8
sequences), these invalid sequences are replaced with the Unicode replacement character (ļæ½
).
# Define a string
s = "Hello, world! Ć©"
# Encode the string to UTF-8 bytes
encoded_s = s.encode("utf-8")
# encoded_s: b'Hello, world! \xc3\xa9'
decoded_s = encoded_s.decode("utf-8")
# decoded_s: Hello, world! Ć©
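The round trip above never hits the `errors="replace"` path, because the bytes are valid. The sketch below feeds in deliberately invalid bytes to show the difference between strict decoding and replacement; the variable names are illustrative.

```python
# 0xFF can never appear in valid UTF-8, so strict decoding fails on it.
bad = b"Hello \xff"
try:
    bad.decode("utf-8")  # the default is errors="strict"
    strict_failed = False
except UnicodeDecodeError:
    strict_failed = True

# With errors="replace", the invalid byte becomes U+FFFD instead.
replaced = bad.decode("utf-8", errors="replace")

# A multi-byte character cut in half is also invalid:
# b'\xc3' is the first byte of 'é' without its continuation byte.
truncated = "\u00e9".encode("utf-8")[:1]
truncated_text = truncated.decode("utf-8", errors="replace")

print(strict_failed, repr(replaced), repr(truncated_text))
```

This matters for BPE because a single token can hold only part of a multi-byte character, so decoding an arbitrary token sequence has to tolerate incomplete sequences.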
Unicode
Unicode is a universal character encoding standard designed to support the digital representation of the world's writing systems. It aims to provide a unique code point (a number) for every character, regardless of platform, program, or language, thus enabling consistent encoding, representation, and handling of text.
Key Features of Unicode include:
- Universal: Unicode covers almost all written languages, symbols, and scripts used around the world.
- Unique Code Points: Each character in Unicode is assigned a unique code point, which is a number written in the form `U+XXXX` (e.g., `U+0041` for the letter `'A'`).
- Scalability: Unicode can accommodate over a million unique characters, making it future-proof for adding new characters and symbols as needed.
Why 1,112,064 Valid Character Code Points?
The Unicode standard defines a range of code points from `U+0000` to `U+10FFFF`. This range allows for 1,114,112 possible code points:
- `U+0000` in decimal is `0`.
- `U+10FFFF` in decimal is `16 * 65536 + 65535 = 1114111`:
  - `10` in hexadecimal is `16` in decimal.
  - `FFFF` in hexadecimal is `65535` in decimal: `15 * 16^0 + 15 * 16^1 + 15 * 16^2 + 15 * 16^3 = 65535`.
- Total number of code points: `1114111 - 0 + 1 = 1114112`.
However, not all of these code points can be encoded as characters. The 2,048 surrogate code points from `U+D800` to `U+DFFF` are reserved for UTF-16, where they are used in pairs to represent code points above `U+FFFF`. On their own they do not represent characters, and they are not valid in UTF-8.
Other special ranges do not reduce the count:
- Private Use Areas: Code points that are reserved for private use and not assigned to any standard character. They are still valid code points that UTF-8 can encode.
- Non-characters: 66 code points that are reserved for internal use and will never be assigned characters. They are also still encodable.
After excluding only the surrogates, `1,114,112 - 2,048 = 1,112,064` valid code points (Unicode scalar values) remain.
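The count can be checked with plain arithmetic. The sketch below also probes which code points Python's UTF-8 encoder accepts; `can_encode` is a helper made up for this check, and only the surrogate range is subtracted from the total.

```python
total_code_points = 0x10FFFF + 1       # U+0000 .. U+10FFFF
surrogates = 0xDFFF - 0xD800 + 1       # U+D800 .. U+DFFF
scalar_values = total_code_points - surrogates

def can_encode(code_point: int) -> bool:
    """True if Python's strict UTF-8 encoder accepts this code point."""
    try:
        chr(code_point).encode("utf-8")
        return True
    except UnicodeEncodeError:
        return False

print(total_code_points, surrogates, scalar_values)
print(can_encode(0xD800))   # a surrogate
print(can_encode(0xE000))   # a private-use code point
print(can_encode(0xFFFF))   # a non-character
```

Only the surrogate is rejected. Private-use code points and non-characters encode without error, which is why they are not subtracted.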
BPE: Byte Pair Encoding
Now let's look at some key aspects of BPE, which I have examined closely to understand the algorithm better. Other details are already covered in the original code or Karpathy's video tutorial, so I won't repeat them here.
Training and Encoding
The training and encoding phases of BPE are mirrored processes, much like training and inference in machine learning models. This mirroring guarantees that the encoding process adheres strictly to the training process, ensuring consistent tokenization.
Aspect | Training | Encoding |
---|---|---|
Starting Point | Start with raw bytes | Start with raw bytes |
Pair Frequencies | Count pair frequencies | Collect adjacent pairs (the counts themselves are not used) |
Merge Strategy | Merge the most frequent pair | Merge the pair with the lowest rank in the learned merges |
Repetition | Stop when the target vocab size is reached | Stop when no more merges are possible |
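The mirrored structure in this table can be sketched in plain Python. This is a condensed illustration in the spirit of minbpe's `BasicTokenizer`, not the library code itself; the function names and the toy text are choices made for this example.

```python
def get_pairs(ids):
    """Count adjacent pairs in a list of token ids."""
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    """Replace every non-overlapping occurrence of `pair` with `idx`, left to right."""
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(idx)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def train(text, vocab_size):
    ids = list(text.encode("utf-8"))            # start with raw bytes
    merges = {}
    for idx in range(256, vocab_size):          # stop at the target vocab size
        counts = get_pairs(ids)
        if not counts:
            break
        pair = max(counts, key=counts.get)      # most frequent pair
        ids = merge(ids, pair, idx)
        merges[pair] = idx
    return merges

def encode(text, merges):
    ids = list(text.encode("utf-8"))            # start with raw bytes
    while len(ids) >= 2:
        pairs = get_pairs(ids)                  # counts are ignored here
        pair = min(pairs, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:                  # no more merges possible
            break
        ids = merge(ids, pair, merges[pair])
    return ids

merges = train("aaabdaaabac", vocab_size=259)
ids = encode("aaabdaaabac", merges)
print(merges)
print(ids)
```

Encoding replays the learned merges in the order they were learned, which is why it picks the pair with the lowest merge index and ignores frequencies.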
BasicTokenizer and RegexTokenizer
They are quite similar. The major difference is that `RegexTokenizer` first splits the text into chunks using a regex pattern, then applies the tokenization process within each chunk independently.
It also handles special tokens. After splitting the text into chunks and special tokens, the rest of the process is the same as in `BasicTokenizer`.
Process | BasicTokenizer | RegexTokenizer |
---|---|---|
Text Processing | Treats entire text as one sequence | Splits text into chunks using regex pattern as well as special tokens |
Pair Counting | Counts pairs across whole text | Counts within chunks, then aggregates |
Merge Application | Applies to whole text | Applies within each chunk independently |
Encoding Ordinary Text | Encodes entire text sequence | Encodes text chunks independently |
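The practical effect of chunking is that pairs straddling a chunk boundary are never counted, so they can never be merged. The sketch below shows this with the standard `re` module and a deliberately simplified, hypothetical split pattern (`TOY_PATTERN`), which is far cruder than the GPT-4 pattern discussed later.

```python
import re
from collections import Counter

# Hypothetical toy pattern: optional space + letters, optional space + symbols, or whitespace.
TOY_PATTERN = re.compile(r" ?[A-Za-z]+| ?[^A-Za-z\s]+|\s+")

def pair_counts(ids, counts=None):
    """Add the adjacent pairs of `ids` to `counts` (a Counter) and return it."""
    counts = Counter() if counts is None else counts
    counts.update(zip(ids, ids[1:]))
    return counts

text = "hi hi!"

# BasicTokenizer style: one sequence for the whole text.
whole = pair_counts(list(text.encode("utf-8")))

# RegexTokenizer style: count inside each chunk, aggregate into one Counter.
chunks = TOY_PATTERN.findall(text)
chunked = Counter()
for chunk in chunks:
    pair_counts(list(chunk.encode("utf-8")), chunked)

print(chunks)
print(whole)
print(chunked)
```

Here the pair `(105, 32)` ("i" followed by a space) exists in the whole-text counts but not in the chunked counts, because the split places the space at the start of the next chunk.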
Encoding with Special Tokens
def encode(self, text, allowed_special="none_raise"):
    # ... ignore ...
    special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
    special_chunks = re.split(special_pattern, text)
    ids = []
    for part in special_chunks:
        if part in special:
            ids.append(special[part])
        else:
            ids.extend(self.encode_ordinary(part))
    return ids
- Escape Special Tokens: `re.escape(k)` ensures that any special characters within the special tokens are properly escaped so they are treated as literal strings in the regex.
- Join Tokens: The join function concatenates all the escaped special tokens with the `|` (or) operator, allowing the regex to match any of the special tokens.
- Capture Group: Wrapping the entire pattern in `()` creates a capturing group, ensuring that the special tokens themselves are included in the split results.
- Example:
  special_chunks = re.split(special_pattern, "Hello, [SPECIAL] world! [EXTRA]")
  # special_chunks = ["Hello, ", "[SPECIAL]", " world! ", "[EXTRA]", ""]
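The splitting logic can be tried in isolation. In the sketch below, `encode_ordinary` is a stand-in that simply returns raw UTF-8 byte values (the real method runs BPE), and the token id is used only for illustration.

```python
import re

def encode_ordinary(text):
    # Stand-in for the real BPE encoder: just return the raw UTF-8 byte values.
    return list(text.encode("utf-8"))

def encode_with_special(text, special):
    # Escape each token, join with "|", and wrap in a capturing group
    # so that re.split keeps the special tokens in its output.
    pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
    ids = []
    for part in re.split(pattern, text):
        if part in special:
            ids.append(special[part])
        else:
            ids.extend(encode_ordinary(part))
    return ids

special = {"<|endoftext|>": 100257}  # illustrative token-to-id mapping
parts = re.split("(" + re.escape("<|endoftext|>") + ")", "hi<|endoftext|>yo")
ids = encode_with_special("hi<|endoftext|>yo", special)
print(parts)
print(ids)
```

Escaping is essential for this token: an unescaped `|` would be read as regex alternation instead of a literal character.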
Decoding with Special Tokens
def decode(self, ids):
    # given ids (list of integers), return Python string
    part_bytes = []
    for idx in ids:
        if idx in self.vocab:
            part_bytes.append(self.vocab[idx])
        elif idx in self.inverse_special_tokens:
            part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
        else:
            raise ValueError(f"invalid token id: {idx}")
    text_bytes = b"".join(part_bytes)
    text = text_bytes.decode("utf-8", errors="replace")
    return text
- Handling Special Tokens: Checks if a token ID corresponds to a special token using `self.inverse_special_tokens`, and converts special token IDs back into their string representations.
- Standard Decoding: For ordinary tokens, the process remains the same as in the `BasicTokenizer`.
GPT4Tokenizer: GPT4_SPLIT_PATTERN
In the `GPT4Tokenizer` class there is a constant called `GPT4_SPLIT_PATTERN`. It's a complex regular expression used to chop text into chunks before any further processing. Let's break down the components of this pattern and understand how it works.
import regex as re  # third-party `regex` module; the standard `re` does not support \p{...} classes

GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
compiled_pattern = re.compile(GPT4_SPLIT_PATTERN)
text_chunks = re.findall(compiled_pattern, text)
Firstly, `GPT4_SPLIT_PATTERN` is a raw string that defines a complex regular expression. Raw strings treat backslashes as literal characters, which simplifies regex patterns by avoiding the need to double-escape them. For instance, to match a literal backslash, the regex itself must be `\\` (two characters):
- Normal string: `"\\\\"`, because each of the two backslashes has to be escaped for Python as well
- Raw string: `r"\\"`, much cleaner!
Note that a raw string cannot end in a single backslash, so `r"\"` is a syntax error.
String | Meaning |
---|---|
\n | Newline |
\t | Tab |
\\ | Single backslash |
r"\n" | Literally a backslash followed by 'n', not a newline |
r"\t" | Literally a backslash followed by 't', not a tab |
r"\\" | Two backslashes, not one |
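Each row of this table can be verified by checking string lengths. The sketch below uses the standard `re` module only to show that a regex needs the two-character pattern `\\` to match one literal backslash; the variable names are illustrative.

```python
import re

newline = "\n"        # one character: a newline
raw_newline = r"\n"   # two characters: a backslash and 'n'
one_backslash = "\\"  # one character: a backslash
raw_two = r"\\"       # two characters: two backslashes

print(len(newline), len(raw_newline), len(one_backslash), len(raw_two))

# The regex that matches a literal backslash is two characters long.
# Written as a raw string it is r"\\"; as a normal string it is "\\\\".
path = "C:\\temp"     # the 7 characters C:\temp
found_raw = re.findall(r"\\", path)
found_normal = re.findall("\\\\", path)
print(found_raw, found_normal)
```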
Now let's break down the components of `GPT4_SPLIT_PATTERN`:
`'(?i:[sdmt]|ll|ve|re)`:
- Matches a single quote followed by certain contractions (case-insensitive due to `(?i:)`).
- Examples: `'s`, `'d`, `'m`, `'t`, `'ll`, `'ve`, `'re`.
`[^\r\n\p{L}\p{N}]?+\p{L}+`:
- Matches a sequence of one or more Unicode letters (`\p{L}`), optionally preceded by one character that is not a newline (`\r`, `\n`) or a Unicode letter/number (`\p{L}`, `\p{N}`).
- Examples: `Hello`, `@World`, `-Café`, `_Über`.
`\p{N}{1,3}`:
- Matches a sequence of 1 to 3 Unicode numeric characters (`\p{N}`).
- Examples: `5`, `42`, `999`. A longer number such as `1000` is split into `100` and `0`.
` ?[^\s\p{L}\p{N}]++[\r\n]*`:
- Matches an optional space followed by one or more characters that are not whitespace or Unicode letters/numbers, followed by zero or more newline characters.
- Examples: `@#$%`, `!!!`, `;;;`.
`\s*[\r\n]`:
- Matches zero or more whitespace characters followed by a newline character.
`\s+(?!\S)`:
- Matches one or more whitespace characters that are not followed by a non-whitespace character (essentially trailing whitespace).
`\s+`:
- Matches one or more whitespace characters.
Then the `re.findall` function finds all non-overlapping matches of the pattern in the string `text` and returns them as a list.
sample_text = 'Copy paste of the Wikipedia article on Taylor Swift, as of Feb 16, 2024.\n---\n\nMain menu\n\nWikipediaTh'
res = re.findall(compiled_pattern, sample_text)
res = [
'Copy', ' paste', ' of', ' the', ' Wikipedia', ' article', ' on',
' Taylor', ' Swift', ',', ' as', ' of', ' Feb', ' ', '16', ',', ' ',
'202', '4', '.\n', '---\n\n', 'Main', ' menu', '\n\n', 'WikipediaTh'
]
GPT4Tokenizer
Here we are! We have finally arrived at `GPT4Tokenizer`. However, we don't actually perform the training here. All we need to do is recover the merges from the official tokenizer used by GPT-4 and make a couple of tricky adjustments.
Recover Merges from tiktoken
class GPT4Tokenizer(RegexTokenizer):
    """Lightweight wrapper on RegexTokenizer that matches GPT-4's tokenizer."""

    def __init__(self):
        super().__init__(pattern=GPT4_SPLIT_PATTERN)
        # get the official tokenizer and its merges
        enc = tiktoken.get_encoding("cl100k_base")
        mergeable_ranks = enc._mergeable_ranks
        # the merges are those of gpt4, but we have to recover them
        self.merges = recover_merges(mergeable_ranks)
        # reconstruct the vocab from the merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        self.vocab = vocab
        # ... ignore ...

    # this is a pretrained tokenizer, it is not intended to be trained
    def train(self, text, vocab_size, verbose=False):
        raise NotImplementedError
function: recover_merges
`recover_merges` is essentially about un-merging the byte sequences that have been combined: for every merged token, it recovers the original pair of tokens that produced it.
def recover_merges(mergeable_ranks):
    merges = {}
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue  # skip raw bytes
        pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
        assert len(pair) == 2
        # recover the integer ranks of the pair
        ix0 = mergeable_ranks[pair[0]]
        ix1 = mergeable_ranks[pair[1]]
        merges[(ix0, ix1)] = rank
    return merges
- Iterate through each token and its rank in `mergeable_ranks`:
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue  # skip raw bytes
  This skips raw bytes (single-byte tokens) since they don't need merging.
- Reconstruct the pairings using the `bpe` function:
    pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
    assert len(pair) == 2
  The `bpe` function is called to get the pair that was merged into the given token, replaying only merges with a rank below the token's own rank.
- Recover the integer ranks of the pair:
    ix0 = mergeable_ranks[pair[0]]
    ix1 = mergeable_ranks[pair[1]]
    merges[(ix0, ix1)] = rank
function: bpe
The `bpe` function takes a byte sequence (token) and a dictionary of mergeable ranks and iteratively merges pairs of bytes based on their ranks.
def bpe(mergeable_ranks, token, max_rank):
    # helper function used in get_gpt4_merges() to reconstruct the merge forest
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts
- Convert the token to a list of single-byte bytes objects:
    parts = [bytes([b]) for b in token]
- Iteratively merge pairs:
  - The while loop runs until no more valid merges are found.
  - In each iteration, the code looks for the adjacent pair with the minimum rank:
      min_idx = None
      min_rank = None
      for i, pair in enumerate(zip(parts[:-1], parts[1:])):
          rank = mergeable_ranks.get(pair[0] + pair[1])
          if rank is not None and (min_rank is None or rank < min_rank):
              min_idx = i
              min_rank = rank
    This finds the pair with the smallest rank, i.e. the merge that was learned earliest.
  - Perform the merge:
      if min_rank is None or (max_rank is not None and min_rank >= max_rank):
          break
      assert min_idx is not None
      parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    If no valid merge is found or the rank reaches `max_rank`, the loop breaks. Otherwise, the best pair is merged, and the `parts` list is updated.
Example
Suppose we have:
mergeable_ranks = {
    b'a': 0,
    b'b': 1,
    b'c': 2,
    b'ab': 3,
    b'abc': 4
}
Here's what happens in `recover_merges`:
- It loops through each token in `mergeable_ranks`.
- For `b'a'`, `b'b'`, and `b'c'`, it skips them because they're single bytes.
- For `b'ab'`:
  - It calls `bpe(mergeable_ranks, b'ab', max_rank=3)`.
  - The `bpe` function returns `[b'a', b'b']`, because the only candidate merge (`ab`, rank 3) is not below `max_rank`.
  - It gets the ranks of `b'a'` (0) and `b'b'` (1).
  - It adds `(0, 1): 3` to the `merges` dictionary.
- For `b'abc'`:
  - It calls `bpe(mergeable_ranks, b'abc', max_rank=4)`.
  - The `bpe` function first merges `a` and `b` (because `ab` has the lowest rank), resulting in `[b'ab', b'c']`.
  - It gets the ranks of `b'ab'` (3) and `b'c'` (2).
  - It adds `(3, 2): 4` to the `merges` dictionary.
- Finally, it returns the `merges` dictionary: `{(0, 1): 3, (3, 2): 4}`
This result tells us:
- The first merge operation was to combine the tokens with ranks 0 and 1 (`a` and `b`), creating a new token with rank 3.
- The second merge operation was to combine the tokens with ranks 3 and 2 (`ab` and `c`), creating a new token with rank 4.
This process essentially reverse-engineers the BPE merge operations that were used to create the vocabulary in `mergeable_ranks`. It's super useful for understanding how the tokenizer was trained and how it splits up text.
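The walkthrough above can be reproduced end to end. The sketch below repeats `bpe` and `recover_merges` in condensed form so that it runs on its own, then applies them to the toy ranks; `toy_ranks` and the vocab rebuild at the end are additions for this illustration.

```python
def bpe(mergeable_ranks, token, max_rank):
    """Replay merges on `token`, using only merges ranked below `max_rank`."""
    parts = [bytes([b]) for b in token]
    while True:
        min_idx, min_rank = None, None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx, min_rank = i, rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts

def recover_merges(mergeable_ranks):
    merges = {}
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue  # raw bytes were never produced by a merge
        pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
        assert len(pair) == 2
        merges[(mergeable_ranks[pair[0]], mergeable_ranks[pair[1]])] = rank
    return merges

toy_ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}
merges = recover_merges(toy_ranks)

# Rebuild every token's bytes from the merges, starting from the single bytes.
vocab = {rank: tok for tok, rank in toy_ranks.items() if len(tok) == 1}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]

print(merges)
print(vocab)
```

Rebuilding `vocab` from the recovered merges gives back exactly the byte strings in `toy_ranks`, which is a handy sanity check for the recovery step.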
Handling Byte Permutations
Another key aspect of `GPT4Tokenizer` is handling the permutation of byte tokens, which requires special attention.
We create two dictionaries: `self.byte_shuffle` and `self.inverse_byte_shuffle`. These dictionaries are used to permute the bytes before encoding and un-permute them after decoding.
class GPT4Tokenizer(RegexTokenizer):
    """Lightweight wrapper on RegexTokenizer that matches GPT-4's tokenizer."""

    def __init__(self):
        super().__init__(pattern=GPT4_SPLIT_PATTERN)
        # get the official tokenizer and its merges
        enc = tiktoken.get_encoding("cl100k_base")
        mergeable_ranks = enc._mergeable_ranks
        # the merges are those of gpt4, but we have to recover them
        self.merges = recover_merges(mergeable_ranks)
        # reconstruct the vocab from the merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        self.vocab = vocab
        # now here is another tricky part.
        # for some reason, the tokens corresponding to individual bytes
        # are permuted in a different order. This is completely non-sensical
        # and probably historical, but therefore we have to deal with it here.
        self.byte_shuffle = {i: mergeable_ranks[bytes([i])] for i in range(256)}
        self.inverse_byte_shuffle = {v: k for k, v in self.byte_shuffle.items()}
        # finally register the special tokens
        self.register_special_tokens(GPT4_SPECIAL_TOKENS)

    def _encode_chunk(self, text_bytes):
        # before we start processing bytes, we have to permute them
        text_bytes = bytes(self.byte_shuffle[b] for b in text_bytes)
        ids = super()._encode_chunk(text_bytes)
        return ids

    def decode(self, ids):
        # we have to un-permute the bytes before we decode
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text_bytes = bytes(self.inverse_byte_shuffle[b] for b in text_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text
Let's use concrete examples to illustrate the permutation and un-permutation process.
As we can see from the previous section, `self.merges` is created from `mergeable_ranks`, where the byte tokens are permuted. However, we still construct `self.vocab` based on the original byte order.
Consequently, the entries of our vocabulary from index 256 onward (where the merged tokens start) look scrambled. Let's examine this:
# Official vocabulary
list(enc._mergeable_ranks.items())[256:261]
# [(b'  ', 256),        # 2 spaces
#  (b'    ', 257),      # 4 spaces
#  (b'in', 258),
#  (b' t', 259),
#  (b'        ', 260)]  # 8 spaces
# Our vocabulary
list(tokenizer.vocab.items())[256:261]
# [(256, b'\xdc\xdc'),
#  (257, b'\xdc\xdc\xdc\xdc'),
#  (258, b'HM'),
#  (259, b'\xdcS'),
#  (260, b'\xdc\xdc\xdc\xdc\xdc\xdc\xdc\xdc')]
Thus, when we encode a text chunk, we first permute the bytes according to `self.byte_shuffle` and then proceed with the encoding process. This step maps the normal byte order to the permuted byte order, allowing us to properly utilize `self.merges`.
For example, letās examine the top 5 merges:
list(tokenizer.merges.items())[0:5]
# [((220, 220), 256),
# ((256, 256), 257),
# ((72, 77), 258),
# ((220, 83), 259),
# ((257, 257), 260)]
From `enc._mergeable_ranks`, we know that `vocab[256]` (two spaces) is actually a combination of two `vocab[220]` (one space). When we use `text.encode("utf-8")`, a space is represented by `32`:
text = " "
for i in text.encode("utf-8"):
    print(i)
# 32
Therefore, what `self.byte_shuffle` does is map `32` to `220`. Let's verify:
tokenizer.byte_shuffle[32]
# 220
With this step, we can ensure our encoding function produces the same ids as tiktoken:
enc.encode(" ")   # one space
# [220]
enc.encode("  ")  # two spaces
# [256]
tokenizer.encode(" ")
# [220]
tokenizer.encode("  ")
# [256]
For decoding, we first obtain `text_bytes` in the permuted byte order, and then map it back to the original byte order using `self.inverse_byte_shuffle`. This approach allows us to decode the text properly.
tokenizer.vocab[220]
# b'\xdc': not a space
tokenizer.inverse_byte_shuffle[220]
# 32: is a space
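The shuffle and its inverse form a simple round trip. The sketch below uses a made-up permutation (rotating every byte value by one) purely for illustration; the real table is read from `mergeable_ranks`, as shown above.

```python
# Hypothetical permutation for illustration only: byte value i maps to (i + 1) % 256.
byte_shuffle = {i: (i + 1) % 256 for i in range(256)}
inverse_byte_shuffle = {v: k for k, v in byte_shuffle.items()}

def shuffle(raw: bytes) -> bytes:
    """Map normal byte values to the permuted order (done before encoding)."""
    return bytes(byte_shuffle[b] for b in raw)

def unshuffle(permuted: bytes) -> bytes:
    """Map permuted byte values back to normal order (done after decoding)."""
    return bytes(inverse_byte_shuffle[b] for b in permuted)

raw = "h\u00e9llo".encode("utf-8")
permuted = shuffle(raw)
restored = unshuffle(permuted)
print(raw, permuted, restored)
```

Because the mapping is a bijection on 0..255, the inverse dictionary always exists and the round trip is lossless, whatever the permutation is.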
Implement with PyTorch
Reference: minbpe-pytorch
The original `minbpe` library is written in pure Python, mainly for clarity, and is not optimized for speed. Let's see how we can optimize it using PyTorch to leverage vectorization and GPU acceleration.
As before, I've only zoomed in on the tricky bits that aren't super straightforward. For the complete implementation, please refer to the repository.
Optimizing Pair Merging
The original `merge` function iterates through the entire list of tokens to merge pairs.
def merge_original(ids, pair, idx):
    newids = []
    i = 0
    while i < len(ids):
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
We can optimize this process using vectorized operations in PyTorch for improved speed.
def merge(ids: torch.Tensor, pair: torch.Tensor, idx: int):
    # create a mask for the first element i of every matching pair (i, j)
    pairs = torch.stack((ids[:-1], ids[1:]), dim=1)
    is_pair = (pairs == pair).all(axis=1)
    false_tensor = torch.tensor([False], dtype=torch.bool, device=ids.device)
    is_pair_i = torch.cat((is_pair, false_tensor))
    # create a mask for the second element j of every matching pair (i, j)
    is_pair_j = is_pair_i.roll(1)
    # handle overlapping pairs for repeated tokens
    while True:
        is_overlap = (is_pair_i & is_pair_j).any()
        if not is_overlap:
            break  # no overlapping pairs
        # remove first overlapping pairs in repeated sequences
        is_first = (is_pair_i & is_pair_j).int().diff() == 1
        is_first = torch.cat((false_tensor, is_first))
        is_pair_i &= ~is_first
        is_pair_j = is_pair_i.roll(1)
    # change the first element i of every matching pair (i, j) to the new token
    ids[is_pair_i] = idx
    # remove the second element j of every matching pair (i, j)
    ids = ids[~is_pair_j]
    return ids
Let's use an example to illustrate:
- `ids`: [1, 2, 3, 1, 2, 1, 2, 3]
- `pair`: (1, 2)
- `idx`: 4

- Convert `ids` to a PyTorch tensor.
- Create pairs of consecutive elements from `ids`:
  pairs = tensor([[1, 2], [2, 3], [3, 1], [1, 2], [2, 1], [1, 2], [2, 3]])
- Check where pairs match the given pair (one entry per pair, so 7 entries):
  is_pair = tensor([True, False, False, True, False, True, False])
- Create masks for the first and second elements of matching pairs (padded to the length of `ids`, so 8 entries):
  is_pair_i = tensor([True, False, False, True, False, True, False, False])
  is_pair_j = tensor([False, True, False, False, True, False, True, False])
- Handle overlapping pairs: none are detected in this example, so `is_pair_i` and `is_pair_j` stay unchanged.
- Replace first elements of matching pairs with `idx`:
  ids[is_pair_i] = 4
  ids = tensor([4, 2, 3, 4, 2, 4, 2, 3])
- Remove second elements of matching pairs:
  ids = ids[~is_pair_j]
  ids = tensor([4, 3, 4, 4, 3])
Output: tensor([4, 3, 4, 4, 3])
Handling Overlapping Pairs
Consider this tricky case:
- `ids`: [1, 1, 1, 1, 1, 2, 3]
- `pair`: (1, 1)
- `idx`: 4
Initially, `is_pair_i` and `is_pair_j` would look like:
is_pair_i | is_pair_j |
---|---|
True | False |
True | True |
True | True |
True | True |
False | True |
False | False |
False | False |
Without special handling, this would incorrectly merge to [4, 2, 3]. The correct result should be [4, 4, 1, 2, 3]. To handle this, we:
- Detect overlapping pairs where a position is both the start and end of a pair
- Identify the first occurrence in each overlapping sequence
- Remove these from `is_pair_i`
- Repeat until no overlaps remain
while True:
    is_overlap = (is_pair_i & is_pair_j).any()
    if not is_overlap:
        break  # no overlapping pairs
    # remove first overlapping pairs in repeated sequences
    is_first = (is_pair_i & is_pair_j).int().diff() == 1
    is_first = torch.cat((false_tensor, is_first))
    is_pair_i &= ~is_first
    is_pair_j = is_pair_i.roll(1)
The `.int().diff()` operation converts a boolean tensor to an integer tensor and then computes the difference between consecutive elements. A difference of `1` marks a position where the overlap mask switches from False to True, which identifies the first occurrence in each sequence of overlapping pairs.
The `.roll()` operation rolls the elements of a tensor along a specified dimension. `roll(1)` shifts all elements by one position, and the element at the end of the tensor wraps around to the beginning. Here the wrapped element is always False (the padding appended to `is_pair_i`), so the wrap-around is harmless.
This ensures correct merging even with repeated elements, preserving the intended tokenization structure.
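To see the mask logic without PyTorch, the same steps can be mirrored with plain Python lists and checked against the sequential merge. This is a sketch for building intuition, not the repository's implementation; `roll1`, `merge_masked` and `merge_reference` are names chosen for this example.

```python
def merge_reference(ids, pair, idx):
    """Sequential left-to-right merge, equivalent to the original loop."""
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            out.append(idx)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def roll1(mask):
    """List analogue of tensor.roll(1): the last element wraps to the front."""
    return mask[-1:] + mask[:-1]

def merge_masked(ids, pair, idx):
    n = len(ids)
    if n < 2:
        return list(ids)
    # mask for the first element of every matching pair, padded with False
    is_pair_i = [ids[k] == pair[0] and ids[k + 1] == pair[1] for k in range(n - 1)] + [False]
    is_pair_j = roll1(is_pair_i)
    while True:
        overlap = [a and b for a, b in zip(is_pair_i, is_pair_j)]
        if not any(overlap):
            break
        # analogue of `.int().diff() == 1` with a leading False:
        # positions where the overlap mask rises from False to True
        is_first = [False] + [overlap[k + 1] and not overlap[k] for k in range(n - 1)]
        is_pair_i = [a and not f for a, f in zip(is_pair_i, is_first)]
        is_pair_j = roll1(is_pair_i)
    merged = [idx if is_pair_i[k] else ids[k] for k in range(n)]
    return [tok for tok, drop in zip(merged, is_pair_j) if not drop]

print(merge_masked([1, 2, 3, 1, 2, 1, 2, 3], (1, 2), 4))
print(merge_masked([1, 1, 1, 1, 1, 2, 3], (1, 1), 4))
```

Each pass of the loop drops the first overlapping pair of every run, so a run of repeated tokens needs roughly half its length in passes before the masks stop overlapping.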
Optimizing BasicTokenizer
BasicTokenizer.train
The original Python implementation of `BasicTokenizer.train` counts pairs across the entire text by iterating through each pair of consecutive elements. While functional, this approach can be slow for large datasets.
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)

def get_stats(ids, counts=None):
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):  # iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts
The PyTorch version creates a tensor of all consecutive pairs using `torch.stack` and uses `torch.unique` to find unique pairs and their counts simultaneously.
for i in range(num_merges):
    pairs = torch.stack((ids[:-1], ids[1:]), dim=1)
    unique, counts = torch.unique(pairs, return_counts=True, dim=0)
    pair_index = torch.argmax(counts)
    pair, count = unique[pair_index], counts[pair_index]
For large datasets, this approach can be significantly faster than the original method. And this version can leverage GPU acceleration for even greater speedups.
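One subtlety worth checking when comparing the two versions is tie-breaking. `max(stats, key=stats.get)` returns the first pair inserted among those with the highest count. The sketch below assumes that `torch.unique(..., dim=0)` returns rows in sorted order and that `argmax` returns the first maximum; under that assumption, which should be verified against the installed PyTorch version, the tensor version would pick the lexicographically smallest pair on a tie. Both rules are emulated here in plain Python with made-up function names.

```python
from collections import Counter

def most_frequent_first_seen(ids):
    """Tie-break like the pure-Python version: first pair encountered wins."""
    stats = {}
    for pair in zip(ids, ids[1:]):
        stats[pair] = stats.get(pair, 0) + 1
    return max(stats, key=stats.get)

def most_frequent_sorted(ids):
    """Tie-break under the stated assumption: smallest pair in sorted order wins."""
    counts = Counter(zip(ids, ids[1:]))
    unique = sorted(counts)  # lexicographic order of the pairs
    best = max(range(len(unique)), key=lambda k: counts[unique[k]])  # first maximum
    return unique[best]

tied = [3, 4, 1, 2]          # every pair occurs exactly once
clear = [1, 2, 1, 2, 3]      # (1, 2) is the unique most frequent pair

print(most_frequent_first_seen(tied), most_frequent_sorted(tied))
print(most_frequent_first_seen(clear), most_frequent_sorted(clear))
```

When there is a unique most frequent pair, both rules agree. They can only diverge on exact ties, which would show up as differently ordered merges between the two implementations.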
BasicTokenizer.encode
The process in `BasicTokenizer.encode` can be optimized in a similar manner.
while len(ids) >= 2:
    pairs = torch.stack((ids[:-1], ids[1:]), dim=1)
    unique = torch.unique(pairs, dim=0)
    is_present = (merges[:, None] == unique[None]).all(-1).any(-1)
    if not is_present.any():
        break  # no more mergeable pairs
    pair_index = is_present.nonzero()[0]
    pair = merges[pair_index]
- Identify Unique Pairs: Use `torch.stack` and `torch.unique` to extract and deduplicate consecutive pairs from the input sequence.
- Check Pair Presence: Use broadcasting to compare the unique pairs against the entire merges tensor at once, identifying all potential merge candidates.
- Select Optimal Pair: Determine the pair with the lowest merge index by finding the first match in the merges tensor.
Let's illustrate this with an example: assume `merges` has shape `[3840, 2]` and `unique` has shape `[11, 2]`.
- Expand Dimensions: The use of `None` adds a new dimension to the tensors, making them compatible for broadcasting in the next step.
  - `merges[:, None]` changes the shape of `merges` from `[3840, 2]` to `[3840, 1, 2]`.
  - `unique[None]` changes the shape of `unique` from `[11, 2]` to `[1, 11, 2]`.
- Broadcast and Compare:
  - The expression `(merges[:, None] == unique[None])` compares each pair in `merges` with each pair in `unique`.
  - After broadcasting, the shape of the comparison tensor is `[3840, 11, 2]`.
  - The comparison is element-wise, so the result is a boolean tensor indicating, for each pair in `merges` and each pair in `unique`, which of the two elements match.
- Check All Elements in the Pair:
  - `.all(-1)` checks if both elements in the pair match for each comparison.
  - This reduces the last dimension (2), resulting in a boolean tensor of shape `[3840, 11]`.
- Check Any Match Across Unique Pairs:
  - `.any(-1)` checks if each `merges` pair matches any `unique` pair.
  - This reduces the remaining last dimension (11), resulting in a boolean tensor of shape `[3840]`.
Then `pair_index = is_present.nonzero()[0]` finds the index of the first occurrence of `True` in the boolean tensor `is_present`. Because the rows of `merges` are stored in merge order, this first match is the present pair with the lowest merge index, i.e. the best pair to merge.
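Stripped of the tensor machinery, the selection rule is "the earliest learned merge that occurs in the sequence". A plain-Python equivalent is sketched below, assuming `merges` is a list of pairs ordered by merge index; `first_mergeable` is a name made up for this illustration.

```python
def first_mergeable(merges, ids):
    """Return (rank, pair) of the earliest merge present in `ids`, or None.

    `merges` is assumed to be a list of pairs ordered by merge index.
    """
    present = set(zip(ids, ids[1:]))        # plays the role of `unique`
    for rank, pair in enumerate(merges):    # plays the role of `is_present`
        if pair in present:
            return rank, pair               # first True, like nonzero()[0]
    return None

merges = [(1, 2), (3, 4), (5, 6)]
print(first_mergeable(merges, [9, 3, 4, 1, 2]))
print(first_mergeable(merges, [3, 4, 5, 6]))
print(first_mergeable(merges, [7, 8]))
```

The broadcasting version computes the whole presence mask in one shot, trading extra memory (a `[num_merges, num_unique, 2]` comparison) for the absence of a Python-level loop.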