Knowledge Nuggets for minbpe

This post contains the knowledge nuggets I collected while learning minbpe. A big thanks to Andrej Karpathy for his super clean and elegant code; sometimes, a single line of code is worth a thousand words.

I also received a lot of help from both ChatGPT and Claude in figuring things out, and they deserve credit for their invaluable assistance.

😼 Let's visualize it!

I've created a D3.js visualization (the cover image) to showcase the vocabulary and merges trained by the BPE algorithm here.

Byte Representation and Character Encoding

Before diving into minbpe, let's first go through the basics of bytes and character encoding. Bytes are the fundamental units of digital information, consisting of 8 bits and capable of representing 256 distinct values (ranging from 0 to 255). In character encoding, each byte can represent a single character, particularly within ASCII or extended ASCII encoding frameworks.

ASCII and Extended ASCII

ASCII (American Standard Code for Information Interchange) is a character encoding standard that uses 7 bits to represent characters. This means it can represent 128 characters (from 0 to 127), including:

| Category | Example(s) |
|---|---|
| Control Characters | null (\x00), newline (\x0A) |
| Digits | 48 for '0' |
| Uppercase Letters | 65 for 'A' |
| Lowercase Letters | 97 for 'a' |
| Punctuation Marks | 33 for '!', 46 for '.', 63 for '?' |
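
As a quick sanity check (a small snippet of my own, not from minbpe), Python's ord and chr map characters to and from these code values:

print(ord('0'), ord('A'), ord('a'), ord('!'))
# 48 65 97 33
print(chr(63), chr(46))
# ? .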

Extended ASCII uses 8 bits (1 byte), allowing for 256 characters (from 0 to 255). Characters from 128 to 255 include various additional characters, such as accented letters, symbols, and graphical characters.

| Decimal | Hex | Character | Description |
|---|---|---|---|
| 128 | 0x80 | Ç | Latin capital letter C with cedilla |
| 129 | 0x81 | ü | Latin small letter u with diaeresis |
| 130 | 0x82 | é | Latin small letter e with acute |
| 131 | 0x83 | â | Latin small letter a with circumflex |
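
These particular assignments match the classic IBM PC code page CP437; as a small illustration (my own snippet, under that assumption), Python can decode those byte values with the cp437 codec:

print(bytes([128, 129, 130, 131]).decode("cp437"))
# Çüéâ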

Bytes and Unicode in Python

In Python, the bytes type is used to represent a sequence of bytes. When we create a bytes object with bytes([i]), we get a single-byte bytes object whose byte value is the integer i. For example:

  • bytes([65]) creates a bytes object b'A', where 65 is the ASCII code for the character 'A'.
  • bytes([65, 66]) creates a bytes object b'AB', where 66 is the ASCII code for the character 'B'.

When we iterate over a bytes object in Python, each element we get is actually an integer representing the byte value, not a single-byte bytes object. This behavior might seem counterintuitive at first, but it's by design. Let's break it down:

  1. A bytes object is an immutable sequence of integers in the range 0 to 255.
  2. When we iterate over a bytes object, Python yields these integer values one by one.
  3. These integers represent the ASCII value (or more generally, the byte value) of each byte in the sequence.

Here's an example to illustrate this:

text_bytes = b"Hello"
for b in text_bytes:
    print(f"{b} - {type(b)}")

This will output:

72 - <class 'int'>
101 - <class 'int'>
108 - <class 'int'>
108 - <class 'int'>
111 - <class 'int'>

However, this behavior is specific to bytes objects. When we iterate over a string (str object), we get individual characters (which are also strings of length 1), not their ASCII values.
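
For contrast, here is a small snippet of my own showing that iterating over a str yields one-character strings, not integers:

text = "Hello"
for ch in text:
    print(f"{ch} - {type(ch)}")
# H - <class 'str'>
# e - <class 'str'>
# l - <class 'str'>
# l - <class 'str'>
# o - <class 'str'>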

UTF-8 Encoding and Decoding

When we use str.encode("utf-8") in Python, we're converting a string (which is a sequence of characters) into a bytes object (which is a sequence of bytes) using the "utf-8" encoding. Each character in the string is translated into one or more bytes.

UTF-8 (Unicode Transformation Format - 8-bit) is a variable-length character encoding for Unicode. It can encode each of the 1,112,064 valid character code points in Unicode using one to four 8-bit bytes.

| Representation Type | Unicode Range | Example Character | Unicode Code Point | Byte Representation |
|---|---|---|---|---|
| Single-byte (ASCII) | U+0000 to U+007F | 'A' | U+0041 | 01000001 |
| Two-byte | U+0080 to U+07FF | 'é' | U+00E9 | 11000011 10101001 |
| Three-byte | U+0800 to U+FFFF | 'ह' | U+0939 | 11100000 10100100 10111001 |
| Four-byte | U+10000 to U+10FFFF | '𐍈' | U+10348 | 11110000 10010000 10001101 10001000 |
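
The byte representations in the table can be reproduced with a small snippet (my own, for illustration):

for ch in ["A", "é", "ह", "𐍈"]:
    bits = " ".join(f"{b:08b}" for b in ch.encode("utf-8"))
    print(f"{ch!r} U+{ord(ch):04X} -> {bits}")
# 'A' U+0041 -> 01000001
# 'é' U+00E9 -> 11000011 10101001
# 'ह' U+0939 -> 11100000 10100100 10111001
# '𐍈' U+10348 -> 11110000 10010000 10001101 10001000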

When we use bytes.decode("utf-8", errors="replace") in Python, we attempt to decode bytes as UTF-8. If there are any invalid byte sequences (which could occur if the bytes do not form valid UTF-8 sequences), these invalid sequences are replaced with the Unicode replacement character (�).

# Define a string
s = "Hello, world! Ć©"

# Encode the string to UTF-8 bytes
encoded_s = s.encode("utf-8")
# encoded_s: b'Hello, world! \xc3\xa9'

decoded_s = encoded_s.decode("utf-8")
# decoded_s: Hello, world! é
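
To see errors="replace" in action, here is a small snippet of my own that truncates the bytes of 'é' so the sequence is no longer valid UTF-8:

broken = encoded_s[:-1]
# broken: b'Hello, world! \xc3' (the second byte of 'é' is missing)

broken.decode("utf-8", errors="replace")
# 'Hello, world! �'

broken.decode("utf-8")
# raises UnicodeDecodeError: unexpected end of data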

Unicode

Unicode is a universal character encoding standard designed to support the digital representation of the world's writing systems. It aims to provide a unique code point (a number) for every character, regardless of platform, program, or language, thus enabling consistent encoding, representation, and handling of text.

Key Features of Unicode include:

  • Universal: Unicode covers almost all written languages, symbols, and scripts used around the world.
  • Unique Code Points: Each character in Unicode is assigned a unique code point, which is a number written in the form U+XXXX (e.g., U+0041 for the letter 'A').
  • Scalability: Unicode can accommodate over a million unique characters, making it future-proof for adding new characters and symbols as needed.

Why 1,112,064 Valid Character Code Points?

The Unicode standard defines a range of code points from U+0000 to U+10FFFF. This range allows for 1,114,112 possible code points:

  • U+0000 in decimal is 0.
  • U+10FFFF in decimal is: 16 * 65536 + 65535 = 1114111
    • 10 in hexadecimal is 16 in decimal.
    • FFFF in hexadecimal is 65535 in decimal: 15 * 16^0 + 15 * 16^1 + 15 * 16^2 + 15 * 16^3 = 65535
  • Total number of code points: 1114111 - 0 + 1 = 1114112

However, not all of these code points are usable as characters:

  • Surrogates: Code points from U+D800 to U+DFFF (2,048 code points) are reserved for the UTF-16 surrogate mechanism and never represent characters on their own.
  • Private Use Areas: Code points reserved for private use; they are valid code points but are not assigned any standard character.
  • Non-characters: Code points reserved for internal use that will never be assigned characters.

The 1,112,064 figure comes from excluding just the 2,048 surrogates: 1,114,112 - 2,048 = 1,112,064 valid code points (Unicode scalar values) that UTF-8 can encode. Private-use and non-character code points still count as valid here; they simply carry no standard character assignment.
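
The arithmetic is easy to verify (a snippet of my own):

total = 0x10FFFF + 1              # 1,114,112 code points in U+0000..U+10FFFF
surrogates = 0xDFFF - 0xD800 + 1  # 2,048 surrogate code points
print(total, surrogates, total - surrogates)
# 1114112 2048 1112064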

BPE: Byte Pair Encoding

Now let's look at some key aspects of BPE, which I have examined closely to understand the algorithm better. Other details are already covered in the original code or Karpathy's video tutorial, so I won't repeat them here.

Training and Encoding

The training and encoding phases of BPE are mirrored processes, much like training and inference in machine learning models. This mirroring guarantees that encoding applies exactly the merge rules learned during training, ensuring consistent tokenization (a minimal sketch follows the table below).

| Aspect | Training | Encoding |
|---|---|---|
| Starting Point | Start with raw bytes | Start with raw bytes |
| Pair Frequencies | Count pair frequencies | Count pair frequencies (but only the pairs themselves are used) |
| Merge Strategy | Merge the most frequent pair | Merge the pair with the lowest rank in the learned merges |
| Stopping Condition | Stop when the target vocab size is reached | Stop when no more merges are possible |
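
To make the mirroring concrete, here is a minimal sketch in the spirit of minbpe's BasicTokenizer (simplified, not the library's exact code): training merges the most frequent pair, while encoding merges the learned pair with the lowest rank.

def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)   # replace the pair with the new token
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

def train(text, vocab_size):
    ids = list(text.encode("utf-8"))       # start from raw bytes
    merges = {}
    for idx in range(256, vocab_size):
        stats = get_stats(ids)
        if not stats:
            break
        pair = max(stats, key=stats.get)   # most frequent pair
        ids = merge(ids, pair, idx)
        merges[pair] = idx                 # record the learned merge (its "rank")
    return merges

def encode(text, merges):
    ids = list(text.encode("utf-8"))       # start from raw bytes, just like training
    while len(ids) >= 2:
        stats = get_stats(ids)             # only the pairs matter here, not the counts
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break                          # nothing left to merge
        ids = merge(ids, pair, merges[pair])
    return ids

Training consumes the frequency counts; encoding only uses the set of pairs and defers entirely to the ranks stored in merges.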

BasicTokenizer and RegexTokenizer

The two are quite similar; the major difference is that RegexTokenizer first splits the text into chunks using a regex pattern, then applies the tokenization process within each chunk independently.

It also handles special tokens. After the text has been split into chunks and special tokens, the rest of the process is the same as in BasicTokenizer. (A sketch of the chunk-wise encoding follows the table below.)

| Process | BasicTokenizer | RegexTokenizer |
|---|---|---|
| Text Processing | Treats the entire text as one sequence | Splits the text into chunks using a regex pattern, as well as special tokens |
| Pair Counting | Counts pairs across the whole text | Counts within chunks, then aggregates |
| Merge Application | Applies to the whole text | Applies within each chunk independently |
| Encoding Ordinary Text | Encodes the entire text sequence | Encodes text chunks independently |
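
As a rough sketch of the chunk-wise behavior (simplified from minbpe; encode_chunk stands in for the tokenizer's per-chunk BPE loop, so treat the names as illustrative):

import regex as re  # minbpe uses the regex module for \p{...} support

def encode_ordinary_sketch(text, compiled_pattern, encode_chunk):
    # split the text into chunks, then BPE-encode each chunk independently
    ids = []
    for chunk in re.findall(compiled_pattern, text):
        ids.extend(encode_chunk(chunk.encode("utf-8")))
    return ids

Because merges are applied within chunks only, a token can never span two chunks (for example, a word and the newline that follows it).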

Encoding with Special Tokens

def encode(self, text, allowed_special="none_raise"):
    # ... ignore ...
    special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
    special_chunks = re.split(special_pattern, text)
    ids = []
    for part in special_chunks:
        if part in special:
            ids.append(special[part])
        else:
            ids.extend(self.encode_ordinary(part))
    return ids
  • Escape Special Tokens: re.escape(k) ensures that any special characters within the special tokens are properly escaped so they are treated as literal strings in the regex.
  • Join Tokens: The join function concatenates all the escaped special tokens with the | (or) operator, allowing the regex to match any of the special tokens.
  • Capture Group: Wrapping the entire pattern in () creates a capturing group, ensuring that the special tokens themselves are included in the split results.
  • Example:
      special_chunks = re.split(special_pattern, "Hello, [SPECIAL] world! [EXTRA]")
      # special_chunks = ["Hello, ", "[SPECIAL]", " world! ", "[EXTRA]", ""]
    

Decoding with Special Tokens

def decode(self, ids):
    # given ids (list of integers), return Python string
    part_bytes = []
    for idx in ids:
        if idx in self.vocab:
            part_bytes.append(self.vocab[idx])
        elif idx in self.inverse_special_tokens:
            part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
        else:
            raise ValueError(f"invalid token id: {idx}")
    text_bytes = b"".join(part_bytes)
    text = text_bytes.decode("utf-8", errors="replace")
    return text
  • Handling Special Tokens: Checks if a token ID corresponds to a special token using self.inverse_special_tokens. Converts special token IDs back into their string representations.
  • Standard Decoding: For ordinary tokens, the process remains the same as in the BasicTokenizer.

GPT4Tokenizer: GPT4_SPLIT_PATTERN

In the GPT4Tokenizer class, we've got this thing called GPT4_SPLIT_PATTERN. It's a complex regular expression pattern used to chop up text into chunks before we do any more processing for training. Let's break down the components of this pattern and understand how it works.

# note: minbpe imports the third-party regex module as re (import regex as re),
# since the standard library re does not support the \p{...} classes used below
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
compiled_pattern = re.compile(GPT4_SPLIT_PATTERN)
text_chunks = re.findall(compiled_pattern, text)

Firstly, GPT4_SPLIT_PATTERN is a raw string that defines a complex regular expression. Raw strings treat backslashes as literal characters, simplifying regex patterns by avoiding the need to double escape them. For instance, the regex pattern that matches a literal backslash consists of two characters, \\:

  • Normal string: "\\\\", because each backslash needs to be escaped
  • Raw string: r"\\", much cleaner!

| String | Meaning |
|---|---|
| \n | Newline |
| \t | Tab |
| \\ | A single backslash |
| r"\n" | Literally a backslash followed by 'n', not a newline |
| r"\t" | Literally a backslash followed by 't', not a tab |
| r"\\" | Two backslashes, not one |

Now let's break down the components of GPT4_SPLIT_PATTERN:

'(?i:[sdmt]|ll|ve|re):

  • Matches a single quote followed by certain contractions (case-insensitive due to (?i:)).
  • Examples: 's, 'd, 'm, 't, 'll, 've, 're.

[^\r\n\p{L}\p{N}]?+\p{L}+:

  • Matches a sequence of one or more Unicode letters (\p{L}), optionally preceded by a character that is not a newline (\r, \n) or a Unicode letter/number (\p{L}, \p{N}).
  • Examples: Hello, @World, -Café, _Über.

\p{N}{1,3}:

  • Matches a sequence of 1 to 3 Unicode digits (\p{N}).
  • Examples: 5, 42, 999, but not 1000.

 ?[^\s\p{L}\p{N}]++[\r\n]*:

  • Matches an optional space followed by one or more characters that are not whitespace or Unicode letters/numbers, followed by zero or more newline characters.
  • Examples: @#$%, !!!, ;;;.

\s*[\r\n]:

  • Matches zero or more whitespace characters followed by a newline character.

\s+(?!\S):

  • Matches one or more whitespace characters that are not followed by a non-whitespace character (essentially trailing whitespace).

\s+:

  • Matches one or more whitespace characters.

Then the re.findall function finds all non-overlapping matches of the pattern in the string text and returns a list of all matches.

sample_text = 'Copy paste of the Wikipedia article on Taylor Swift, as of Feb 16, 2024.\n---\n\nMain menu\n\nWikipediaTh'
res = re.findall(compiled_pattern, sample_text)
# res = [
#     'Copy', ' paste', ' of', ' the', ' Wikipedia', ' article', ' on',
#     ' Taylor', ' Swift', ',', ' as', ' of', ' Feb', ' ', '16', ',', ' ',
#     '202', '4', '.\n', '---\n\n', 'Main', ' menu', '\n\n', 'WikipediaTh'
# ]

GPT4Tokenizer

Here we are! We have finally arrived at GPT4Tokenizer. However, we don't actually perform the training here. All we need to do is recover the merges from the official tokenizer used by GPT-4 and make some extra tricky adjustments.

Recover Merges from tiktoken

class GPT4Tokenizer(RegexTokenizer):
    """Lightweight wrapper on RegexTokenizer that matches GPT-4's tokenizer."""

    def __init__(self):
        super().__init__(pattern=GPT4_SPLIT_PATTERN)
        # get the official tokenizer and its merges

        enc = tiktoken.get_encoding("cl100k_base")
        mergeable_ranks = enc._mergeable_ranks
        # the merges are those of gpt4, but we have to recover them
        self.merges = recover_merges(mergeable_ranks)
        # reconstruct the vocab from the merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        self.vocab = vocab
        # ... ignore ...

    # this is a pretrained tokenizer, it is not intended to be trained
    def train(self, text, vocab_size, verbose=False):
        raise NotImplementedError

function: recover_merges

Recovering merges is essentially about un-merging the byte sequences that have been combined; it reverts each merged token back to the original pair of byte sequences it was built from.

def recover_merges(mergeable_ranks):
    merges = {}
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue # skip raw bytes
        pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
        assert len(pair) == 2
        # recover the integer ranks of the pair
        ix0 = mergeable_ranks[pair[0]]
        ix1 = mergeable_ranks[pair[1]]
        merges[(ix0, ix1)] = rank

    return merges
  1. Iterate through each token and its rank in mergeable_ranks:

     for token, rank in mergeable_ranks.items():
         if len(token) == 1:
             continue # skip raw bytes
    

    This skips raw bytes (single byte tokens) since they don't need merging.

  2. Reconstruct the pairings using the bpe function:

     pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
     assert len(pair) == 2
    

    The bpe function is called to get the merged pair for the given token up to its rank.

  3. Recover the integer ranks of the pair:

     ix0 = mergeable_ranks[pair[0]]
     ix1 = mergeable_ranks[pair[1]]
     merges[(ix0, ix1)] = rank

    This looks up the rank of each half of the pair in mergeable_ranks and records the merge as (ix0, ix1) -> rank, the same pair-to-new-token mapping produced during training.

function: bpe

The bpe function takes a byte sequence (token) and a dictionary of mergeable ranks and iteratively merges pairs of bytes based on their ranks.

def bpe(mergeable_ranks, token, max_rank):
    # helper function used in get_gpt4_merges() to reconstruct the merge forest
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts
  1. Convert the token to a list of bytes:

     parts = [bytes([b]) for b in token]
    
  2. Iteratively merge pairs:

    • The while loop runs until no more valid merges are found.
    • In each iteration, the code looks for the pair of bytes with the minimum rank:

        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
      

      This finds the pair with the smallest rank, i.e., the earliest-learned merge (which corresponds to the most frequent pair at training time).

    • Perform the merge:

        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
      

      If no valid merge is found or the rank exceeds max_rank, the loop breaks. Otherwise, the best pair is merged, and the parts list is updated.

Example

Suppose we have:

mergeable_ranks = {
    b'a': 0,
    b'b': 1,
    b'c': 2,
    b'ab': 3,
    b'abc': 4
}

Hereā€™s what happens in recover_merges:

  1. It loops through each token in mergeable_ranks.
  2. For b'a', b'b', and b'c', it skips them because they're single bytes.
  3. For b'ab':
    • It calls bpe(mergeable_ranks, b'ab', max_rank=3)
    • The bpe function returns [b'a', b'b']
    • It gets the ranks of b'a' (0) and b'b' (1)
    • It adds (0, 1): 3 to the merges dictionary
  4. For b'abc':
    • It calls bpe(mergeable_ranks, b'abc', max_rank=4)
    • The bpe function first merges a and b (because ab has the lowest rank), resulting in [b'ab', b'c']
    • It gets the ranks of b'ab' (3) and b'c' (2)
    • It adds (3, 2): 4 to the merges dictionary
  5. Finally, it returns the merges dictionary: {(0, 1): 3, (3, 2): 4}

This result tells us:

  • The first merge operation was to combine the tokens with ranks 0 and 1 (a and b), creating a new token with rank 3.
  • The second merge operation was to combine the tokens with ranks 3 and 2 (ab and c), creating a new token with rank 4.

This process essentially reverse-engineers the BPE merge operations that were used to create the vocabulary in mergeable_ranks. It's super useful for understanding how the tokenizer was trained and how it splits up text.

Handling Byte Permutations

Another key aspect of GPT4Tokenizer is handling the permutation of byte tokens, which requires special attention.

We create two dictionaries: self.byte_shuffle and self.inverse_byte_shuffle. These dictionaries are used to permute the bytes before encoding and un-permute them after decoding.

class GPT4Tokenizer(RegexTokenizer):
    """Lightweight wrapper on RegexTokenizer that matches GPT-4's tokenizer."""

    def __init__(self):
        super().__init__(pattern=GPT4_SPLIT_PATTERN)
        # get the official tokenizer and its merges

        enc = tiktoken.get_encoding("cl100k_base")
        mergeable_ranks = enc._mergeable_ranks
        # the merges are those of gpt4, but we have to recover them
        self.merges = recover_merges(mergeable_ranks)
        # reconstruct the vocab from the merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        self.vocab = vocab
        # now here is another tricky part.
        # for some reason, the tokens corresponding to individual bytes
        # are permuted in a different order. This is completely non-sensical
        # and probably historical, but therefore we have to deal with it here.
        self.byte_shuffle = {i: mergeable_ranks[bytes([i])] for i in range(256)}
        self.inverse_byte_shuffle = {v: k for k, v in self.byte_shuffle.items()}
        # finally register the special tokens
        self.register_special_tokens(GPT4_SPECIAL_TOKENS)

    def _encode_chunk(self, text_bytes):
        # before we start processing bytes, we have to permute them
        text_bytes = bytes(self.byte_shuffle[b] for b in text_bytes)
        ids = super()._encode_chunk(text_bytes)
        return ids
    
    def decode(self, ids):
        # we have to un-permute the bytes before we decode
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text_bytes = bytes(self.inverse_byte_shuffle[b] for b in text_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

Let's use concrete examples to illustrate the permutation and un-permutation process.

As we can see from the previous section, self.merges is created from mergeable_ranks, where the byte tokens are permuted. However, we still construct self.vocab based on the original byte order.

Consequently, our vocabulary after index 256 (where the merged tokens start) is actually a mess. Let's examine this:

# Official vocabulary
list(enc._mergeable_ranks.items())[256:261]
# [(b'  ', 256),
#  (b'    ', 257),
#  (b'in', 258),
#  (b' t', 259),
#  (b'        ', 260)]

# Our vocabulary
list(tokenizer.vocab.items())[256:261]
# [(256, b'\xdc\xdc'),
#  (257, b'\xdc\xdc\xdc\xdc'),
#  (258, b'HM'),
#  (259, b'\xdcS'),
#  (260, b'\xdc\xdc\xdc\xdc\xdc\xdc\xdc\xdc')]

Thus, when we encode a text chunk, we first permute the bytes according to self.byte_shuffle and then proceed with the encoding process. This step maps the normal byte order to the permuted byte order, allowing us to properly utilize self.merges.

For example, let's examine the top 5 merges:

list(tokenizer.merges.items())[0:5]
# [((220, 220), 256),
#  ((256, 256), 257),
#  ((72, 77), 258),
#  ((220, 83), 259),
#  ((257, 257), 260)]

From enc._mergeable_ranks, we understand that vocab[256] (two spaces) is actually a combination of two vocab[220] (one space). When we use text.encode("utf-8"), a space will be represented by 32:

text = " "
for i in text.encode("utf-8"):
    print(i)
# 32

Therefore, what self.byte_shuffle does is map 32 to 220. Let's verify:

tokenizer.byte_shuffle[32]
# 220

With this step, we can ensure our encoding function produces the same ids as tiktoken:

enc.encode(" ")
# [220]
enc.encode("  ")
# [256]
tokenizer.encode(" ")
# [220]
tokenizer.encode("  ")
# [256]

For decoding, we first obtain a differently ordered text_bytes, and then map it back to the original byte order using self.inverse_byte_shuffle. This approach allows us to decode the text properly.

tokenizer.vocab[220]
# b'\xdc': not a space
tokenizer.inverse_byte_shuffle[220]
# 32: is a space
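
With both the recovered merges and the byte shuffle in place, a quick sanity check (my own snippet, mirroring the spirit of minbpe's tests; it assumes GPT4Tokenizer is importable from minbpe) is that the tokenizer now agrees with tiktoken and round-trips ordinary text:

import tiktoken
from minbpe import GPT4Tokenizer

enc = tiktoken.get_encoding("cl100k_base")
tokenizer = GPT4Tokenizer()

sample = "hello world!!!? (hola!) lol123 😉"
assert tokenizer.encode(sample) == enc.encode(sample)
assert tokenizer.decode(tokenizer.encode(sample)) == sample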

Implement with PyTorch

Reference: minbpe-pytorch

The original minbpe library is written in pure Python, mainly for clarity, and is not optimized for speed. Let's see how we can optimize it using PyTorch to leverage vectorization and GPU acceleration.

Similarly, I've only zoomed in on the tricky bits that aren't super straightforward. For the complete implementation, please refer to the repository.

Optimizing Pair Merging

The original merge function iterates through the entire list of tokens to merge pairs.

def merge_original(ids, pair, idx):
    newids = []
    i = 0
    while i < len(ids):
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

We can optimize this process using vectorized operations in PyTorch for improved speed.

def merge(ids: torch.Tensor, pair: torch.Tensor, idx: int):
    # create a mask for the first element i of every matching pair (i, j)
    pairs = torch.stack((ids[:-1], ids[1:]), dim=1)
    is_pair = (pairs == pair).all(axis=1)
    false_tensor = torch.tensor([False], dtype=torch.bool, device=ids.device)
    is_pair_i = torch.cat((is_pair, false_tensor))

    # create a mask for the second element j of every matching pair (i, j)
    is_pair_j = is_pair_i.roll(1)

    # handle overlapping pairs for repeated tokens
    while True:
        is_overlap = (is_pair_i & is_pair_j).any()
        if not is_overlap:
            break # no overlapping pairs

        # remove first overlapping pairs in repeated sequences
        is_first = (is_pair_i & is_pair_j).int().diff() == 1
        is_first = torch.cat((false_tensor, is_first))
        is_pair_i &= ~is_first
        is_pair_j = is_pair_i.roll(1)

    # change the first element i of every matching pair (i, j) to the new token
    ids[is_pair_i] = idx

    # remove the second element j of every matching pair (i, j)
    ids = ids[~is_pair_j]
    return ids

Let's use an example to illustrate:

  • ids: [1, 2, 3, 1, 2, 1, 2, 3]
  • pair: (1, 2)
  • idx: 4
  1. Convert ids to a PyTorch tensor.
  2. Create pairs of consecutive elements from ids:
    • pairs = tensor([[1, 2], [2, 3], [3, 1], [1, 2], [2, 1], [1, 2], [2, 3]])
  3. Check where pairs match the given pair:
    • is_pair = tensor([True, False, False, True, False, True, False])
  4. Create masks for the first and second elements of matching pairs:
    • is_pair_i = tensor([True, False, False, True, False, True, False, False])
    • is_pair_j = tensor([False, True, False, False, True, False, True, False])
  5. Handle overlapping pairs (none in this example):
    • No overlapping pairs detected, so no change in is_pair_i and is_pair_j.
  6. Replace first elements of matching pairs with idx:
    • ids[is_pair_i] = 4
    • ids = tensor([4, 2, 3, 4, 2, 4, 2, 3])
  7. Remove second elements of matching pairs:
    • ids = ids[~is_pair_j]
    • ids = tensor([4, 3, 4, 4, 3])

Output: tensor([4, 3, 4, 4, 3])

Handling Overlapping Pairs

Consider this tricky case:

  • ids: [1, 1, 1, 1, 1, 2, 3]
  • pair: (1, 1)
  • idx: 4

Initially, is_pair_i and is_pair_j would look like:

| is_pair_i | is_pair_j |
|---|---|
| True | False |
| True | True |
| True | True |
| True | True |
| False | True |
| False | False |
| False | False |

Without special handling, this would incorrectly merge to [4, 2, 3]. The correct result should be [4, 4, 1, 2, 3]. To handle this, we:

  1. Detect overlapping pairs where a position is both the start and end of a pair
  2. Identify the first occurrence in each overlapping sequence
  3. Remove these from is_pair_i
  4. Repeat until no overlaps remain
while True:
    is_overlap = (is_pair_i & is_pair_j).any()
    if not is_overlap:
        break # no overlapping pairs

    # remove first overlapping pairs in repeated sequences
    is_first = (is_pair_i & is_pair_j).int().diff() == 1
    is_first = torch.cat((false_tensor, is_first))
    is_pair_i &= ~is_first
    is_pair_j = is_pair_i.roll(1)

The .int().diff() operation is used to convert a boolean tensor to an integer tensor and then compute the discrete difference between consecutive elements of the tensor. This is useful in the context of identifying the first occurrence in a sequence of overlapping pairs.

The .roll() operation is used to roll the elements of a tensor along a specified dimension. When you use roll(1), it shifts all elements of the tensor by one position along the specified dimension, and the elements at the end of the tensor are wrapped around to the beginning.

This ensures correct merging even with repeated elements, preserving the intended tokenization structure.
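
Here is a toy run of the .diff()/.roll() trick on the overlapping example above (my own snippet):

import torch

is_pair_i = torch.tensor([True, True, True, True, False, False, False])
is_pair_j = is_pair_i.roll(1)
# is_pair_j: [False, True, True, True, True, False, False]
overlap = (is_pair_i & is_pair_j).int()
# overlap: [0, 1, 1, 1, 0, 0, 0]
is_first = overlap.diff() == 1
# is_first: [True, False, False, False, False, False]
is_first = torch.cat((torch.tensor([False]), is_first))
# is_first: [False, True, False, False, False, False, False]
# position 1 is the first overlapping pair; it gets cleared from is_pair_i and the loop repeats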

Optimizing BasicTokenizer

BasicTokenizer.train

The original Python implementation of BasicTokenizer.train counts pairs across the entire text by iterating through each pair of consecutive elements. While functional, this approach can be slow for large datasets.

for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)

def get_stats(ids, counts=None):
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):  # iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts

The PyTorch version creates a tensor of all consecutive pairs using torch.stack and uses torch.unique to find unique pairs and their counts simultaneously.

for i in range(num_merges):
    pairs = torch.stack((ids[:-1], ids[1:]), dim=1)
    unique, counts = torch.unique(pairs, return_counts=True, dim=0)
    pair_index = torch.argmax(counts)
    pair, count = unique[pair_index], counts[pair_index]

For large datasets, this approach can be significantly faster than the original method. And this version can leverage GPU acceleration for even greater speedups.
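
A small concrete example of the vectorized counting (my own snippet, not from the repo):

import torch

ids = torch.tensor([1, 2, 3, 1, 2, 1, 2, 3])
pairs = torch.stack((ids[:-1], ids[1:]), dim=1)
unique, counts = torch.unique(pairs, return_counts=True, dim=0)
pair = unique[torch.argmax(counts)]
# pair: tensor([1, 2]) -- (1, 2) appears 3 times, more than any other pair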

BasicTokenizer.encode

The process in BasicTokenizer.encode can be optimized in a similar manner.

while len(ids) >= 2:
    pairs = torch.stack((ids[:-1], ids[1:]), dim=1)
    unique = torch.unique(pairs, dim=0)

    is_present = (merges[:, None] == unique[None]).all(-1).any(-1)
    if not is_present.any():
        break  # no more mergeable pairs

    pair_index = is_present.nonzero()[0]
    pair = merges[pair_index]
  1. Identify Unique Pairs: Utilize torch.stack and torch.unique to efficiently extract and deduplicate consecutive pairs from the input sequence.
  2. Check Pair Presence: Employ broadcasting and advanced tensor operations to rapidly compare the unique pairs against the entire merges tensor, identifying all potential merge candidates.
  3. Select Optimal Pair: Determine the pair with the lowest merge index by finding the first match in the merges tensor.

Let's illustrate this with an example: assume merges has shape [3840, 2] and unique has shape [11, 2].

  1. Expand Dimensions: The use of None adds a new dimension to the tensors, making them compatible for broadcasting in the next step.
    • merges[:, None] changes the shape of merges from [3840, 2] to [3840, 1, 2].
    • unique[None] changes the shape of unique from [11, 2] to [1, 11, 2].
  2. Broadcast and Compare:
    • The expression (merges[:, None] == unique[None]) compares each pair in merges with each pair in unique.
    • After broadcasting, the shape of the comparison tensor is [3840, 11, 2].
    • The comparison is element-wise, so for each pair in merges, it is compared to each pair in unique, resulting in a boolean tensor indicating where pairs match.
  3. Check All Elements in the Pair:
    • .all(-1) checks if both elements in the pair match for each comparison.
    • This reduces the last dimension (2), resulting in a boolean tensor of shape [3840, 11].
  4. Check Any Match Across Unique Pairs:
    • .any(-1) checks if any of the comparisons for each merges pair matches any unique pair.
    • This reduces the second-to-last dimension (11), resulting in a boolean tensor of shape [3840].

Then pair_index = is_present.nonzero()[0] means we find the index of the first occurrence of True in the boolean tensor is_present. This index corresponds to the best pair to merge.
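
Finally, here is a tiny end-to-end illustration of that broadcasting logic with toy numbers (my own, just to make the shapes concrete):

import torch

merges = torch.tensor([[10, 11], [1, 2], [5, 6]])  # learned pairs, ordered by rank
unique = torch.tensor([[5, 6], [7, 8]])            # unique consecutive pairs in the text

is_present = (merges[:, None] == unique[None]).all(-1).any(-1)
# is_present: tensor([False, False, True])

pair_index = is_present.nonzero()[0]  # first True -> lowest-rank mergeable pair
pair = merges[pair_index]
# pair: tensor([[5, 6]])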