Linking the GPT-2 code to the BPE algorithm we already implemented

The main encoding logic is in the encode method:

def encode(self, text):
    bpe_tokens = []
    for token in re.findall(self.pat, text):
        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
        bpe_tokens.extend(self.bpe(token).split(' '))
    return [self.encoder[bpe_token] for bpe_token in bpe_tokens]

Here's what's happening:

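  1. The input text is split into chunks using the regex pattern self.pat. This is the forced splitting we discussed earlier: BPE merges are never applied across chunk boundaries.
  2. Each chunk is encoded to UTF-8 bytes, and each byte is mapped to a printable Unicode character through self.byte_encoder (for example, the space byte becomes 'Ġ').
  3. The self.bpe method is called on each encoded chunk to apply the BPE merges. We'll look at this method in a moment.
  4. self.bpe returns the resulting BPE tokens as one space-separated string. That string is split on spaces and the pieces are appended to the bpe_tokens list. Splitting on spaces is safe because real spaces were already mapped to 'Ġ' in step 2.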
  5. Finally, each BPE token is converted to its integer ID using self.encoder.
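Here is a small, self-contained sketch of steps 1 and 2. bytes_to_unicode follows the helper of the same name in OpenAI's GPT-2 encoder.py, which is what builds self.byte_encoder. PAT and pre_tokenize are hypothetical names introduced only for this illustration.

The regex is an approximation, not GPT-2's exact pattern. The real pattern uses \p{L} and \p{N} from the third-party regex package, and the standard-library re module does not support them, so this sketch substitutes [^\W\d_] and \d.

```python
import re

# ASSUMPTION: an approximation of GPT-2's split pattern for the stdlib `re`
# module. The original uses \p{L} / \p{N} (third-party `regex` package);
# [^\W\d_] (letters) and \d (digits) stand in for them here.
PAT = re.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?[^\W\d_]+| ?\d+| ?(?:[^\w\s]|_)+|\s+(?!\S)|\s+"""
)


def bytes_to_unicode():
    """Map every byte 0..255 to a printable Unicode character.

    Printable bytes map to themselves; the remaining 68 bytes are shifted to
    code points starting at 256, so e.g. the space byte (32) becomes 'Ġ'.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return {b: chr(c) for b, c in zip(bs, cs)}


BYTE_ENCODER = bytes_to_unicode()


def pre_tokenize(text):
    """Steps 1-2 of encode: regex split, then UTF-8 bytes -> printable chars."""
    chunks = PAT.findall(text)
    return [''.join(BYTE_ENCODER[b] for b in chunk.encode('utf-8')) for chunk in chunks]


print(PAT.findall("Hello world!!  it's"))  # ['Hello', ' world', '!!', ' ', ' it', "'s"]
print(pre_tokenize(" world"))              # ['Ġworld']
```

Because the space byte maps to 'Ġ', vocabulary entries for words that follow a space start with that character.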

Now let's look at the bpe method:

def bpe(self, token):
    if token in self.cache:
        return self.cache[token]
    word = tuple(token)
    pairs = get_pairs(word)

    if not pairs:
        return token

    while True:
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
        if bigram not in self.bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except ValueError:  # first no longer occurs at or after position i
                new_word.extend(word[i:])
                break

            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    word = ' '.join(word)
    self.cache[token] = word
    return word

This is very similar to the merge function we wrote earlier, with a few differences:

  1. The result of BPE for each token is cached in self.cache to avoid recomputing it.
  2. The get_pairs function (not shown here) returns the set of adjacent symbol pairs in the current word. It plays the role of the get_stats function we wrote earlier, but it returns only the pairs, not their counts.
  3. The outer while loop keeps merging until no more merges are possible. This mirrors the loop in our earlier code, except that it runs on each chunk individually rather than on the entire text.
  4. The min function finds the pair in the current word with the lowest rank (the earliest-learned merge). Pairs missing from self.bpe_ranks get rank infinity. This is equivalent to what we did earlier with min(stats, key=lambda p: merges.get(p, float("inf"))).
  5. If that pair is a known merge, the inner loop scans the word from left to right and replaces every occurrence of the pair with a single merged symbol. The process then repeats on the updated word.
  6. Merging stops when the word has collapsed into a single symbol or when none of its pairs are in self.bpe_ranks. The symbols are then joined with spaces, stored in the cache, and returned.
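To see this logic run outside the class, here is a standalone sketch. It makes a few simplifications:

- get_pairs is written out, since it is not shown in the code above.
- The merge table and cache are passed in explicitly instead of living on self.
- The inner loop checks each position directly instead of using word.index.
- The merge ranks are a hypothetical toy table.

```python
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a tuple of symbols."""
    return {(word[i], word[i + 1]) for i in range(len(word) - 1)}


def bpe(token, bpe_ranks, cache):
    """Apply ranked merges to one chunk; return its symbols joined by spaces."""
    if token in cache:
        return cache[token]
    word = tuple(token)  # start from individual characters
    while len(word) > 1:
        pairs = get_pairs(word)
        # Lowest rank = merge learned earliest; unknown pairs rank as infinity.
        bigram = min(pairs, key=lambda p: bpe_ranks.get(p, float('inf')))
        if bigram not in bpe_ranks:
            break  # no remaining pair is a known merge
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            # Merge every occurrence of (first, second), scanning left to right.
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
    result = ' '.join(word)
    cache[token] = result
    return result


# Hypothetical toy merge table: lower rank = learned earlier.
ranks = {('l', 'o'): 0, ('lo', 'w'): 1, ('e', 'r'): 2}
cache = {}
print(bpe('lower', ranks, cache))  # low er
```

With this toy table, 'lower' merges l+o, then lo+w, then e+r. It stops at 'low er' because ('low', 'er') has no rank.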

The decoding process (not shown here) is essentially the inverse of this. The integer token IDs are mapped back to their token strings using self.decoder, and the strings are concatenated. Each character is then mapped back to its raw byte (reversing self.byte_encoder), and the bytes are decoded as UTF-8 to produce the text.
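Here is a minimal sketch of that decoding path. It uses a tiny hypothetical vocabulary in place of GPT-2's real one. bytes_to_unicode is repeated from the encoding side so the block runs on its own.

The errors='replace' setting mirrors the default in GPT-2's Encoder. With it, an ID sequence that cuts a multi-byte character in half yields a replacement character instead of raising an exception.

```python
def bytes_to_unicode():
    """Byte -> printable-character table used on the encoding side."""
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return {b: chr(c) for b, c in zip(bs, cs)}


byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}  # printable char -> byte

# Hypothetical toy vocabulary (token string -> id) and its inverse.
encoder = {'Hello': 0, 'Ġwor': 1, 'ld': 2, '!': 3, 'Ġcaf': 4, 'Ã©': 5}
decoder = {v: k for k, v in encoder.items()}


def decode(ids):
    # 1. ids -> token strings, concatenated
    text = ''.join(decoder[i] for i in ids)
    # 2. printable chars -> raw bytes (undoes byte_encoder)
    raw = bytes(byte_decoder[c] for c in text)
    # 3. bytes -> str; invalid or truncated UTF-8 becomes U+FFFD
    return raw.decode('utf-8', errors='replace')


print(decode([0, 1, 2, 3]))  # Hello world!
print(decode([4, 5]))        # " café" (with a leading space)
```

The last example shows why the byte step matters. 'Ã©' is not an accented letter in the vocabulary: it is the byte-level spelling of the two UTF-8 bytes of 'é'.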

And that's it: the GPT-2 tokenizer in a nutshell. Apart from caching and the regex and byte-level preprocessing steps, the core BPE algorithm is the same as the one we implemented earlier.

In the next section, we'll look at how this tokenizer handles special tokens, such as those used to indicate the start and end of a sequence.


Last update: 2024-08-21