Walkthrough of implementing BPE from scratch in Python
Now let's see how this works with a more realistic example. Instead of toy letters, we're going to start with raw bytes. So our initial vocabulary will be the 256 possible bytes. Then we're going to iteratively merge byte pairs to create new tokens, which will expand our vocabulary.
Here's a concrete example. I copied the first paragraph of the blog post "A Programmer's Introduction to Unicode" into a Python string:
text = """A Programmer's Introduction to Unicode March 3, 2017 · Coding · 22 Comments Unicode! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺\u200c🇳\u200c🇮\u200c🇨\u200c🇴\u200c🇩\u200c🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to "support Unicode" in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don't blame programmers for still finding the whole thing mysterious, even 30 years after Unicode's inception."""
# let's take a look at the text
print(text)
print("length:", len(text))
tokens = list(map(int, text.encode('utf-8'))) # convert to a list of integers in range 0..255 for convenience
print('---')
print(tokens)
print("length:", len(tokens))
A Programmer's Introduction to Unicode March 3, 2017 · Coding · 22 Comments Unicode! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺🇳🇮🇨🇴🇩🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to "support Unicode" in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don't blame programmers for still finding the whole thing mysterious, even 30 years after Unicode's inception.
length: 533
---
[65, 32, 80, 114, 111, 103, 114, 97, 109, 109, 101, 114, 39, 115, 32, 73, 110, 116, 114, 111, 100, 117, 99, 116, 105, 111, 110, 32, 116, 111, 32, 85, 110, 105, 99, 111, 100, 101, 32, 77, 97, 114, 99, 104, 32, 51, 44, 32, 50, 48, 49, 55, 32, 194, 183, 32, 67, 111, 100, 105, 110, 103, 32, 194, 183, 32, 50, 50, 32, 67, 111, 109, 109, 101, 110, 116, 115, 32, 32, 208, 146, 208, 189, 208, 184, 208, 186, 208, 190, 208, 180, 208, 181, 33, 32, 240, 159, 153, 164, 240, 159, 153, 157, 240, 159, 153, 152, 240, 159, 153, 146, 240, 159, 153, 158, 240, 159, 153, 148, 240, 159, 153, 144, 226, 128, 141, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 178, 226, 128, 140, 240, 159, 135, 180, 226, 128, 140, 240, 159, 135, 180, 226, 128, 140, 240, 159, 135, 181, 226, 128, 140, 240, 159, 135, 144, 33, 32, 240, 159, 152, 132, 32, 84, 104, 101, 32, 118, 101, 114, 121, 32, 110, 97, 109, 101, 32, 115, 116, 114, 105, 107, 101, 115, 32, 102, 101, 97, 114, 32, 97, 110, 100, 32, 97, 119, 101, 32, 105, 110, 116, 111, 32, 116, 104, 101, 32, 104, 101, 97, 114, 116, 115, 32, 111, 102, 32, 112, 114, 111, 103, 114, 97, 109, 109, 101, 114, 115, 32, 119, 111, 114, 108, 100, 119, 105, 100, 101, 46, 32, 87, 101, 32, 97, 108, 108, 32, 107, 110, 111, 119, 32, 119, 101, 32, 111, 117, 103, 104, 116, 32, 116, 111, 32, 34, 115, 117, 112, 112, 111, 114, 116, 32, 85, 110, 105, 99, 111, 100, 101, 34, 32, 105, 110, 32, 111, 117, 114, 32, 115, 111, 102, 116, 119, 97, 114, 101, 32, 40, 119, 104, 97, 116, 101, 118, 101, 114, 32, 116, 104, 97, 116, 32, 109, 101, 97, 110, 115, 226, 128, 148, 108, 105, 107, 101, 32, 117, 115, 105, 110, 103, 32, 119, 99, 104, 97, 114, 95, 116, 32, 102, 111, 114, 32, 97, 108, 108, 32, 116, 104, 101, 32, 115, 116, 114, 105, 110, 103, 115, 44, 32, 114, 105, 103, 104, 116, 63, 41, 46, 32, 66, 117, 116, 32, 85, 110, 105, 99, 111, 100, 101, 32, 99, 97, 110, 32, 98, 101, 32, 97, 98, 115, 116, 114, 117, 115, 101, 44, 32, 97, 110, 100, 32, 100, 105, 118, 105, 110, 103, 32, 105, 110, 116, 111, 32, 116, 104, 101, 32, 116, 104, 111, 117, 115, 97, 110, 100, 45, 112, 97, 103, 101, 32, 85, 110, 105, 99, 111, 100, 101, 32, 83, 116, 97, 110, 100, 97, 114, 100, 32, 112, 108, 117, 115, 32, 105, 116, 115, 32, 100, 111, 122, 101, 110, 115, 32, 111, 102, 32, 115, 117, 112, 112, 108, 101, 109, 101, 110, 116, 97, 114, 121, 32, 97, 110, 110, 101, 120, 101, 115, 44, 32, 114, 101, 112, 111, 114, 116, 115, 44, 32, 97, 110, 100, 32, 110, 111, 116, 101, 115, 32, 99, 97, 110, 32, 98, 101, 32, 109, 111, 114, 101, 32, 116, 104, 97, 110, 32, 97, 32, 108, 105, 116, 116, 108, 101, 32, 105, 110, 116, 105, 109, 105, 100, 97, 116, 105, 110, 103, 46, 32, 73, 32, 100, 111, 110, 39, 116, 32, 98, 108, 97, 109, 101, 32, 112, 114, 111, 103, 114, 97, 109, 109, 101, 114, 115, 32, 102, 111, 114, 32, 115, 116, 105, 108, 108, 32, 102, 105, 110, 100, 105, 110, 103, 32, 116, 104, 101, 32, 119, 104, 111, 108, 101, 32, 116, 104, 105, 110, 103, 32, 109, 121, 115, 116, 101, 114, 105, 111, 117, 115, 44, 32, 101, 118, 101, 110, 32, 51, 48, 32, 121, 101, 97, 114, 115, 32, 97, 102, 116, 101, 114, 32, 85, 110, 105, 99, 111, 100, 101, 39, 115, 32, 105, 110, 99, 101, 112, 116, 105, 111, 110, 46]
length: 616
Here, text is the raw string, which has a length of 533 Unicode code points. When we encode it into UTF-8 bytes and convert to a list of integers, we get a list tokens of length 616. The byte sequence is longer because some Unicode code points (like plain ASCII) take only a single byte in UTF-8, while others (like various symbols and emoji) take up to 4 bytes each.
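It's worth seeing this variable-width encoding directly. Here's a tiny check, using characters that all appear in the text above, showing how many UTF-8 bytes each one takes:

# characters from the text above, spanning 1 to 4 UTF-8 bytes each
for ch in ["A", "·", "‽", "😄"]:
    print(repr(ch), "->", len(ch.encode("utf-8")), "bytes")
# 'A' -> 1 bytes, '·' -> 2 bytes, '‽' -> 3 bytes, '😄' -> 4 bytes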
Getting token statistics
We'll start with a function get_stats that takes a list of tokens (in our case, integers representing bytes) and returns a dictionary of counts for each pair of consecutive tokens:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]): # iterate over consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts
stats = get_stats(tokens)
Here we use Python's zip function to iterate over consecutive pairs of tokens. The counts dictionary keeps track of how many times we've seen each pair.
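To get a feel for the statistics, we can sort the pairs by count and inspect the top few. The exact winners depend on the text, but they should be common English bigrams; when I run this, the most frequent pair is (101, 32), the byte for 'e' followed by a space:

# sort the pairs by count, descending, and show the top few
top_pairs = sorted(stats.items(), key=lambda kv: kv[1], reverse=True)
for pair, count in top_pairs[:3]:
    print(pair, count, bytes(pair)) # e.g. (101, 32) -> b'e '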
Merging frequent pairs
Next, we need a merge function that takes the list of tokens, a pair to merge, and the new token ID to replace the pair with:
def merge(ids, pair, idx):
    # in the list of ids, replace all consecutive occurrences of pair with the new token idx
    newids = []
    i = 0
    while i < len(ids):
        # if not at the very last position AND the pair matches, replace it
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
This function iterates through the token list, replacing occurrences of the given pair with the new token ID. It's careful not to go out of bounds when checking for the pair.
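A quick sanity check on a toy list confirms the behavior, merging the pair (6, 7) into a new token 99:

print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))
# [5, 6, 99, 9, 1]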
Iteratively merging to build vocabulary
Finally, we can put these pieces together in a loop to build up our vocabulary:
vocab_size = 276 # the desired final vocabulary size
num_merges = vocab_size - 256
ids = list(tokens) # copy so we don't destroy the original list
merges = {} # (int, int) -> int
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get) # the most frequent pair
    idx = 256 + i # mint the next available token id
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx
We decide on a final vocabulary size and do num_merges rounds of merging. In each round, we:
- Get the token pair statistics
- Find the most frequent pair
- Merge that pair into a new token ID
- Update our merges dictionary to remember this merge
After this process, we have our new compressed token list ids and our merges dictionary, which records which pair was merged into which token at each step. We'll need merges later: replaying it lets us encode new text the same way, and inverting it lets us decode token IDs back into bytes.
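As a preview of that inversion, here is a minimal sketch that builds a vocab table mapping every token ID to the bytes it stands for. It relies on merges preserving insertion order (which Python dicts guarantee since 3.7), so each merged token is defined after its two parts:

vocab = {idx: bytes([idx]) for idx in range(256)} # base tokens are single bytes
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1] # a merged token concatenates its parts
print(vocab[256]) # the bytes behind the first merged token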
Let's see how much we've compressed the original text:
print("tokens length:", len(tokens))
print("ids length:", len(ids))
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")
We've gone from 24,597 tokens down to 19,438, a compression ratio of about 1.27X. (These figures come from training on the full blog post text rather than just the first paragraph; our 616-byte paragraph alone compresses less, since 20 merges have less repeated structure to exploit.)