Linking the code to the BPE algorithm we already implemented
The main encoding logic is in the encode method:
```python
def encode(self, text):
    bpe_tokens = []
    for token in re.findall(self.pat, text):
        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
        bpe_tokens.extend(self.bpe(token).split(' '))
    return [self.encoder[bpe_token] for bpe_token in bpe_tokens]
```
Here's what's happening:
- The input `text` is split into tokens using the regex pattern `self.pat`. This is the forced splitting we discussed earlier.
- Each token is then converted to bytes and encoded using `self.byte_encoder`.
- The `self.bpe` method is called on each encoded token. This applies the BPE merges to the token. We'll look at this method in a moment.
- The resulting BPE tokens are split on spaces and added to the `bpe_tokens` list.
- Finally, each BPE token is converted to its integer ID using `self.encoder`.
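To make the first two steps concrete, here is a minimal sketch of the forced splitting and byte-to-unicode encoding. The `bytes_to_unicode` helper follows GPT-2's published mapping; the regex is a simplified stand-in for `self.pat`, since the real pattern relies on the third-party `regex` module and Unicode property classes like `\p{L}`:

```python
import re

def bytes_to_unicode():
    # GPT-2's trick: map every byte 0-255 to a printable unicode character,
    # so arbitrary byte sequences look like ordinary strings.
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("\xa1"), ord("\xac") + 1))
          + list(range(ord("\xae"), ord("\xff") + 1)))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)  # shift unmapped bytes into a printable range
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))

byte_encoder = bytes_to_unicode()

# Simplified ASCII-only stand-in for self.pat (illustration only).
pat = re.compile(r"'s|'t|'re|'ve|'m|'ll|'d| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+")

text = "Hello world!"
chunks = pat.findall(text)
encoded = [''.join(byte_encoder[b] for b in c.encode('utf-8')) for c in chunks]
print(chunks)   # ['Hello', ' world', '!']
print(encoded)  # ['Hello', 'Ġworld', '!']
```

Note how the leading space of `' world'` becomes `Ġ`: byte 32 is not printable, so the mapping shifts it to `chr(256 + 32)`. This is why GPT-2 vocabulary entries for words preceded by a space start with `Ġ`.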
Now let's look at the bpe method:
```python
def bpe(self, token):
    if token in self.cache:
        return self.cache[token]
    word = tuple(token)
    pairs = get_pairs(word)

    if not pairs:
        return token

    while True:
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
        if bigram not in self.bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except:
                new_word.extend(word[i:])
                break

            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    word = ' '.join(word)
    self.cache[token] = word
    return word
```
This is very similar to the merge function we wrote earlier, but with a few additional optimizations:
- The result of BPE for each token is cached in `self.cache` to avoid recomputing it.
- The `get_pairs` function (not shown here) returns all adjacent pairs in the current word. It is equivalent to the `get_stats` function we wrote earlier, except that it only returns the pairs, not their frequencies.
- The `while` loop keeps merging pairs until no more merges are possible. This matches the loop in our earlier code, but here it runs on each individual token rather than the entire text.
- The `min` function finds the pair with the lowest rank (i.e., the earliest-learned merge) present in the current word. This is equivalent to what we did earlier with `min(stats, key=lambda p: merges.get(p, float("inf")))`.
- If a merge is found, every occurrence of that pair in the word is replaced by the merged symbol, and the process continues with the updated word.
- Once no more merges are possible (either because there are no pairs left or because none of the remaining pairs appear in `self.bpe_ranks`), the final merged word is joined with spaces, cached, and returned.
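The steps above can be condensed into a standalone sketch. The merge table below is made up for illustration (GPT-2's real `bpe_ranks` holds ~50,000 learned merges), and the occurrence-replacement loop is simplified compared to the `word.index` version quoted above, but the rank-driven merge order is the same:

```python
def get_pairs(word):
    # All adjacent symbol pairs in the word tuple.
    return {(word[i], word[i + 1]) for i in range(len(word) - 1)}

# Toy merge table (assumed, not GPT-2's real one): lower rank = earlier merge.
bpe_ranks = {('l', 'o'): 0, ('lo', 'w'): 1, ('h', 'e'): 2}

def bpe(token):
    word = tuple(token)
    pairs = get_pairs(word)
    while pairs:
        # Pick the pair that was learned earliest; stop if none is known.
        bigram = min(pairs, key=lambda p: bpe_ranks.get(p, float('inf')))
        if bigram not in bpe_ranks:
            break
        first, second = bigram
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)  # merge the pair
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        pairs = get_pairs(word)
    return ' '.join(word)

print(bpe('lower'))   # 'low e r'  -- ('l','o') then ('lo','w') fire
print(bpe('hello'))   # 'he l lo'  -- ('l','o') fires before ('h','e')
```

Note that `'hello'` merges `('l', 'o')` before `('h', 'e')` even though `h` comes first in the string: merge order is determined by rank, not by position.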
The decoding process (not shown here) is essentially the inverse of this: the integer token IDs are converted back to their token strings using self.decoder, the byte encoding is reversed, and the resulting text is returned.
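A minimal sketch of that inverse direction, using the same `bytes_to_unicode` mapping and a tiny hypothetical vocabulary (the real GPT-2 `encoder` is loaded from `encoder.json` with ~50,000 entries):

```python
def bytes_to_unicode():
    # Same byte-to-printable-character mapping GPT-2 uses for encoding.
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("\xa1"), ord("\xac") + 1))
          + list(range(ord("\xae"), ord("\xff") + 1)))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))

byte_decoder = {c: b for b, c in bytes_to_unicode().items()}

# Hypothetical two-entry vocabulary, for illustration only.
encoder = {'Hello': 0, 'Ġworld': 1}
decoder = {v: k for k, v in encoder.items()}

def decode(token_ids):
    # IDs -> token strings -> original bytes -> utf-8 text.
    text = ''.join(decoder[t] for t in token_ids)
    return bytearray(byte_decoder[ch] for ch in text).decode('utf-8')

print(decode([0, 1]))  # 'Hello world'
```

The `Ġ` in `Ġworld` maps back to byte 32, so the space reappears in the decoded text without any special handling.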
And that's it! The GPT-2 tokenizer in a nutshell. As you can see, while there are some additional optimizations and preprocessing steps, the core BPE algorithm is the same as what we implemented earlier.
In the next section, we'll look at how this tokenizer handles special tokens, such as those used to indicate the start and end of a sequence.