WordPiece
Intro
In a previous post I talked about Byte Pair Encoding, a popular subword tokenization algorithm that has been used to train models such as GPT, GPT-2 and BART (if you randomly found this post and want a good intro to tokenizers you can click here). In this follow-up post I would like to talk about a very similar algorithm: the WordPiece tokenizer. WordPiece is the tokenization algorithm Google developed to pre-train BERT. It has since been reused in quite a few Transformer models based on BERT, such as DistilBERT, MobileBERT, Funnel Transformers, and MPNet. Although Google never open-sourced its implementation of the WordPiece training algorithm, there is enough public knowledge about it for us to write our own implementation in this article. Just like its cousin Byte Pair Encoding, this algorithm can be divided into two parts: the learner and the segmenter. The learner takes our training data and tokenizes it in order to build our vocabulary, whereas the segmenter uses that vocabulary to split new data that we wish to tokenize.
Learner
The learner of WordPiece is very similar to the learner of the Byte Pair Encoding algorithm; in fact, one could argue that they are exactly the same algorithm, with the only difference being the scoring function that is used. In Byte Pair Encoding the scoring function is simpler: it is purely based on the frequencies of the pairs. WordPiece uses something a little more involved: instead of looking at just the frequencies, the scoring tries to find the pair that maximizes the likelihood of the training data. To understand this better, let’s look at an example (for easier readability I’m using the underscore character to represent spaces):
Initialize the vocabulary with all the unique characters contained in the training data, also referred to as the corpus. For instance, if our corpus was this text: “carpet car”, our initial vocabulary would be:
- “c”, “a”, “r”, “p”, “e”, “t”, “_”
Now that we have our vocabulary, we will perform what’s called a pre-tokenization to avoid merging characters that we specifically want to keep separate. The purpose of the pre-tokenization is to define the groups of characters that are allowed to merge together; without it we would be at risk of ending up with a nonsensical set of tokens like this one, for example: [“carpet_c”, “ar”]. There are different ways of pre-tokenizing our corpus. For simplicity, in this example our pre-tokenization will be based on the spaces:
- [“c a r p e t _”, “c a r”]
Next, we again split the corpus into single characters, but this time we keep the repeated characters (the short sketch after this step shows these setup steps in code):
- “c a r p e t _ c a r”
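To make these setup steps concrete, here is a minimal sketch of how they could look in Python. The corpus and the use of the underscore for spaces come from the example above; the variable names are purely illustrative:

# the corpus from the example, with "_" standing in for the space
corpus = "carpet_car"

# the initial vocabulary is the set of unique characters in the corpus
initial_vocabulary = sorted(set(corpus))
print(initial_vocabulary)  # ['_', 'a', 'c', 'e', 'p', 'r', 't']

# pre-tokenize on spaces, keeping each "_" attached to the word before it
pre_tokens = [" ".join(part) for part in ("carpet_", "car")]
print(pre_tokens)  # ['c a r p e t _', 'c a r']

# split the whole corpus into single characters, repetitions included
split_corpus = " ".join(corpus)
print(split_corpus)  # c a r p e t _ c a r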
Calculate a score for every consecutive pair of characters obtained in the previous step. The formula for the score is as follows (this formula was taken from this website):
- score = freq_of_pair / (freq_of_first_element × freq_of_second_element)
WARNING: Again, I repeat, Google has never published an official implementation of this algorithm, so this might not be 100% accurate, but it’s enough to get the gist of how it works.
Having said that, if our corpus was for instance this text: “carpet car”, then the scores would be (the short sanity check after this list reproduces these numbers in code):
- score(c, a) = 2 / (2 x 2) = 1/2
- score(a, r) = 2 / (2 x 2) = 1/2
- score(r, p) = 1 / (2 x 1) = 1/2
- score(p, e) = 1 / (1 x 1) = 1
- score(e, t) = 1 / (1 x 1) = 1
- score(t, _) = 1 / (1 x 1) = 1
- score(_, c) = 1 / (1 x 2) = 1/2
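If you want to double-check those numbers, here is a quick, self-contained sanity check. It is not the real learner; it just counts characters and consecutive pairs over the split corpus from the example and applies the formula above:

import collections

chars = "c a r p e t _ c a r".split()

# frequency of every single character
char_count = collections.Counter(chars)

# frequency of every consecutive pair of characters
pair_count = collections.Counter(zip(chars, chars[1:]))

# score = freq_of_pair / (freq_of_first_element * freq_of_second_element)
for (first, second), count in pair_count.items():
    score = count / (char_count[first] * char_count[second])
    print(f"score({first}, {second}) = {score}")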
Join all occurrences of the pair with the highest score, and if more than one pair ties for the maximum score, just pick one of them arbitrarily. In this case the maximum score is 1 and these pairs share it: < p e >, < e t > and < t _ >, so I have arbitrarily chosen to concatenate the pair < p e >. After concatenating that pair, our corpus will look like this (a tiny code sketch of this merge follows the example below):
- “c a r pe t _ c a r”
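The merge itself boils down to a string replacement on the split corpus. This is only an illustrative sketch; the full learner in the Code section below uses a regular expression instead, to make sure only whole tokens get merged:

split_corpus = "c a r p e t _ c a r"
best_pair = ("p", "e")

# replace every occurrence of "p e" with the merged token "pe"
merged_corpus = split_corpus.replace(" ".join(best_pair), "".join(best_pair))
print(merged_corpus)  # c a r pe t _ c a r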
Now we simply treat the previously concatenated pair of characters as a character in and of itself and repeat the previous two steps for many iterations. Now you must be wondering “for how many iterations should I run this algorithm?”. Well, the answer is: it depends (one of data science’s favorite phrases). What is usually done in practice is to define the number of tokens that we want and let the algorithm run until that threshold is met, as sketched below.
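To illustrate that stopping criterion, here is a self-contained toy learner that keeps merging the best-scoring pair until a target vocabulary size is reached. It ignores pre-tokenization and is only meant as a sketch of the idea; the implementation in the Code section below runs for a fixed number of iterations instead:

import collections

def learn_vocabulary(corpus: str, vocab_size: int) -> list:
    # the corpus uses "_" for spaces, e.g. "carpet_car"
    split_corpus = list(corpus)
    vocabulary = sorted(set(split_corpus))
    # keep merging the best-scoring pair until the vocabulary is big enough
    while len(vocabulary) < vocab_size:
        char_count = collections.Counter(split_corpus)
        pair_count = collections.Counter(zip(split_corpus, split_corpus[1:]))
        if not pair_count:
            break
        best_pair = max(
            pair_count,
            key=lambda pair: pair_count[pair] / (char_count[pair[0]] * char_count[pair[1]]),
        )
        # rebuild the corpus with every occurrence of the best pair merged
        merged, i = [], 0
        while i < len(split_corpus):
            if i + 1 < len(split_corpus) and (split_corpus[i], split_corpus[i + 1]) == best_pair:
                merged.append("".join(best_pair))
                i += 2
            else:
                merged.append(split_corpus[i])
                i += 1
        split_corpus = merged
        vocabulary.append("".join(best_pair))
    return vocabulary

print(learn_vocabulary("carpet_car", vocab_size=10))
# ['_', 'a', 'c', 'e', 'p', 'r', 't', 'pe', 'pet', 'pet_']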
Segmenter
Okay cool, we have built our beloved vocabulary using the learner, but how do we apply it to new data? Simple: we just have to repeatedly loop through the word that we wish to segment and, whenever we find a match, break the word at that point. Let’s look at an example. Consider the following vocabulary: “new”, “e”, “r”. If the text we wanted to tokenize using that vocabulary was “newer”, we would first check whether the word “newer” is itself a token contained in the vocabulary. In this case it isn’t, so we move on and check whether “newe” is contained in the vocabulary, which again it isn’t, but in the next iteration we check “new”, which is indeed in our vocabulary, so we consider it a token and store it away in a list. We are not done yet, though: although we identified “new” as a token, we still haven’t checked the remainder of the text, “er”, so we repeat the process described so far for that substring and for any subsequent substrings we might produce along the way.
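Before looking at the full implementation, here is a minimal sketch of that longest-match-first idea applied to the “newer” example above. The function name is purely illustrative; the segmenter in the Code section below follows the same logic:

def greedy_segment(word: str, vocabulary: list) -> list:
    tokens = []
    start, end = 0, len(word)
    while start < len(word):
        if end > start and word[start:end] in vocabulary:
            # longest prefix of the remaining text that is in the vocabulary
            tokens.append(word[start:end])
            start, end = end, len(word)
        elif end > start:
            # no match yet: shrink the candidate substring from the right
            end -= 1
        else:
            # no prefix of the remaining text is in the vocabulary at all
            return ["[UNK]"]
    return tokens

print(greedy_segment("newer", ["new", "e", "r"]))  # ['new', 'e', 'r']
print(greedy_segment("bold", ["new", "e", "r"]))   # ['[UNK]']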
Code
import collections
import re


def wordpiece_learner(train_data: dict) -> tuple:
    steps = []
    # initialize the vocabulary with every unique character seen in the training data
    # (alternatively we could start from a fixed alphabet, e.g. list("abcdefghijklmnopqrstuvwxyz"))
    vocabulary = list({char for string in train_data.keys() for char in string.split()})
    vocabulary.append("[UNK]")
    # run a small, fixed number of merges for this demo
    for _ in range(3):
        # count character and pair occurrences, weighted by how often each pre-token appears
        pair_count = collections.defaultdict(int)
        char_count = collections.defaultdict(int)
        for chars, count in train_data.items():
            chars = chars.split()
            for char in chars:
                char_count[char] += count
            for i in range(len(chars) - 1):
                pair_count[chars[i], chars[i + 1]] += count
        # score = freq_of_pair / (freq_of_first_element * freq_of_second_element)
        pair_likelihoods = {
            pair: round(count / (char_count[pair[0]] * char_count[pair[1]]), 4)
            for pair, count in pair_count.items()
        }
        # find the pair with the maximum likelihood
        pair_max = max(pair_likelihoods, key=pair_likelihoods.get)
        # store the merge order (kept for reference; this segmenter does not need it)
        steps.append(pair_max)
        # add the merged pair to the vocabulary
        vocabulary.append("".join(pair_max))
        # update the train_data dict by concatenating every whole-token occurrence of the best pair
        updated_train_data = {}
        bigram = re.escape(" ".join(pair_max))
        p = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
        for word in train_data.keys():
            updated_word = p.sub("".join(pair_max), word)
            updated_train_data[updated_word] = train_data[word]
        train_data = updated_train_data
    return steps, train_data, vocabulary
# this function was taken from the following site:
# https://d2l.ai/chapter_natural-language-processing-pretraining/subword-embedding.html
def segmenter(test_data: list, vocabulary: list) -> list:
    outputs = []
    for token in test_data:
        token = token.replace(" ", "")
        start, end = 0, len(token)
        cur_output = []
        # segment the token with the longest possible subwords from the vocabulary
        while start < len(token) and start < end:
            if token[start:end] in vocabulary:
                cur_output.append(token[start:end])
                start = end
                end = len(token)
            else:
                end -= 1
        if start < len(token):
            cur_output.append("[UNK]")
        outputs.append(" ".join(cur_output))
    return outputs
def main():
    # Our training data is assumed to come in this format: each key is a
    # pre-token and the value is how many times that pre-token appears in the
    # training text. In this case our raw text would look something like:
    # "hug hug hug hug hug pugs pugs pun pun pun pun pun pun bun bun bun hugs hugs "
    # The underscores represent the spaces.
    train_data = {
        "h u g _": 5,
        "p u g s _": 2,
        "p u n _": 6,
        "b u n _": 3,
        "h u g s _": 2,
    }
    steps, final_train_data, vocabulary = wordpiece_learner(train_data)
    print("final state of the training data:")
    print(final_train_data)
    print()
    print("vocabulary")
    print(vocabulary)
    print()
    input1, input2 = ["h u g s"], ["b u g s"]
    print("segmenter")
    print("results: ")
    result1 = segmenter(input1, vocabulary)
    result2 = segmenter(input2, vocabulary)
    print(result1)
    print(result2)


if __name__ == "__main__":
    main()