UnigramLM: An Attempt at Writing The Missing Manual

TL;DR: This post is my attempt to write down the UnigramLM tokenization algorithm cleanly and explicitly—because the original paper is more of a sketch than a full spec, and the real algorithm is basically the implementation in the SentencePiece library. We’ll formalize the unigram generative model, derive the EM updates, explain why pruning is needed (and how it’s done), and point out the spots where the practical implementation quietly diverges from the tidy math. And hopefully, by the end, you’ll think the UnigramLM algorithm is just as cool as I do.

Origins of this blog post

(feel free to skip this section)

These days, tokenization is basically synonymous with Byte-Pair Encoding (BPE). If you ask someone “do you know how tokenization works?”, there’s a decent chance you’ll get an answer like: “Yeah yeah, I know BPE.”

But there’s this other tokenizer sitting right next to BPE in practice: UnigramLM (the SentencePiece “unigram” model). On paper it looks totally different. Instead of greedily merging pairs, it says: “let’s uncover latent tokens and treat tokenization like inference.” At least to me, that framing feels a lot more linguistically sane (or, at minimum, less like we’re playing subword Tetris). Naturally, I figured I should actually understand the algorithm. So I did what everyone does: I went to the original 2018 paper. That… didn’t get me very far. So then I went to the SentencePiece repo, hoping I could reconstruct the missing pieces from the code. After staring at the C++ implementation gave me a brief flashback to the terror of my undergraduate CS classes, I bailed on that approach too. Then I thought maybe the missing explanation was hiding in the HuggingFace documentation — but let’s just say that rabbit hole ended like this:

The HuggingFace documentation [on UnigramLM] describes a tokeniser that doesn’t exist. It should not be relied on as an explanation for UnigramLM, because it doesn’t even come close.
–Claude

The original paper gives a nice high-level story, and the code clearly works in practice, but I couldn’t find a single place that actually spells out the full generative model, why the algorithm is mathematically sound, or how all the little “engineering details” (like pruning and vocabulary initialization) fit into that picture. At that point, I figured to myself: well, this is a unigram model. How complicated can it be? I can definitely reason through the logic myself. Turns out that was a tad naive. But I’m nothing if not stubborn, so here we are a few months later!

This post is what I wish I’d had at the start of that endeavor: an approachable but rigorous walkthrough of UnigramLM as a probabilistic model, showing why EM is a reasonable tool here, what the posterior over segmentations actually looks like, and how the SentencePiece-style implementation approximates all of this in practice. If you’ve ever felt that UnigramLM is “clear enough to use, but not clear enough to explain on a whiteboard,” my hope is that this takes you the rest of the way to really understanding it, and maybe even extending it. Because, at least in my view, it’s a pretty cool algorithm that deserves some of BPE’s limelight.

Tokenization Background and Notation

So that we’re on the same page, let’s start with a formal definition of tokenization.

Let ${\mathbf{{s}}}=\langle{s}_{1},{s}_2,\dots\rangle$ be a string—a sequence of characters (or bytes) such that ${s}_{t}\in{\Sigma}$ for a base alphabet ${\Sigma}$. Let ${\mathcal{V}}$ be a finite set, where each ${v}\in{\mathcal{V}}$ consists of a sequence of symbols from ${\Sigma}\cup{\Gamma}$, where ${\Gamma}$ denotes a finite set of reserved symbols (e.g., whitespace markers, start/end tokens, etc.); we refer to ${\mathcal{V}}$ as our vocabulary and to each ${v}$ as a piece.1 During tokenization, we wish to convert the sequence of characters/bytes ${\mathbf{{s}}}$ into a sequence of tokens ${\mathbf{{v}}}=\langle {v}_{1},\dots,{v}_{{m}}\rangle$, each of which is a piece in the set ${\mathcal{V}}$. We refer to this token sequence as a segmentation of ${\mathbf{{s}}}$, and it can informally be seen as just a different way of representing the original string.

A tokenization algorithm defines a mapping ${h}: {\Sigma}^* \rightarrow {\mathcal{V}}^*$ and the method for learning the parameters of this mapping. The application of ${h}$ (which we’ll call our tokenization function here) to a string is sometimes referred to as inference, although perhaps more commonly people just call this process “tokenizing a string.” For example, the byte-pair encoding (BPE) algorithm defines an ${h}$ that is parameterized by a list of merge pairs $\boldsymbol{\mu}=\langle({v}_1, {v}_1’),({v}_2, {v}_2’), \dots \rangle$ and the algorithm for learning $\boldsymbol{\mu}$. At inference, starting from the representation of ${\mathbf{{s}}}$ as just a sequence of symbols from the base vocabulary ${\Sigma}$, ${h}_{\boldsymbol{\mu}}$ goes through the text $i=1, \dots, |\boldsymbol{\mu}|$ times. At step $i$, it replaces all occurrences of the pair $({v}_i, {v}_i’)$ with a new merged token (typically, of the form ${v}_i\circ{v}_i’$).2
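To make that inference rule concrete, here’s a minimal (and deliberately unoptimized) sketch of BPE inference; real implementations use pair indices and priority queues, but the logic is the same:

```python
def bpe_tokenize(s, merges):
    """Apply an ordered list of BPE merge pairs to a string."""
    tokens = list(s)  # start from individual base symbols
    for left, right in merges:  # one pass per merge rule, in learned order
        merged, i = [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == (left, right):
                merged.append(left + right)  # replace the pair with v ∘ v'
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens

print(bpe_tokenize("hat", [("h", "a"), ("ha", "t")]))  # ['hat']
```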

Importantly, we assume that ${\mathbf{{s}}}$ can be reconstructed from ${\mathbf{{v}}}$ via a detokenization function ${g}: {\mathcal{V}}^* \rightarrow {\Sigma}^*$; often ${g}$ is a simple mapping like string concatenation with some special symbol handling, e.g., ${g}({\mathbf{{v}}}) = {v}_{1}\circ\dots \circ {v}_{{m}}$. In what follows, we consider ${g}$ fixed and treat it as part of the model specification. All probabilities over strings and segmentations are defined with respect to this fixed choice of ${g}$. Notably, given just the vocabulary ${\mathcal{V}}$, there are often multiple valid ${\mathbf{{v}}}$ for which the application of our simple detokenization function ${g}$ would lead to the same ${\mathbf{{s}}}$. In other words, ${g}$ is generally non-injective. We use ${\mathcal{T}}_{\mathcal{V}}({\mathbf{{s}}}) \mathrel{\stackrel{\textnormal{ def}}{=}}{g}^{-1}({\mathbf{{s}}}) = \{{\mathbf{{v}}}\in{\mathcal{V}}^* : {g}({\mathbf{{v}}}) = {\mathbf{{s}}}\}$ to refer to the set of all valid token sequences that produce ${\mathbf{{s}}}$, i.e., the set-valued inverse of ${g}$.

Example 1 (A concrete example of the non-injectivity of $g$.). Consider a toy string ${\mathbf{{s}}}= \text{hat}$ and a small vocabulary ${\mathcal{V}}= \{\text{h},\text{a},\text{t},\text{ha},\text{at}\}$. Under our fixed detokenization function ${g}$ (simple concatenation of token symbol sequences), the set of all valid segmentations of ${\mathbf{{s}}}$ is $$\begin{aligned} {\mathcal{T}}_{\mathcal{V}}({\mathbf{{s}}}) =\{ \langle \text{h}, \text{a}, \text{t} \rangle, \langle \text{ha}, \text{t} \rangle, \langle \text{h}, \text{at} \rangle\}, \end{aligned} $$ where all three segmentations detokenize to the same string ${\mathbf{{s}}}= \text{hat}$ under ${g}$.
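If you want to play with this, here’s a tiny sketch that enumerates ${\mathcal{T}}_{\mathcal{V}}({\mathbf{{s}}})$ by recursion on prefixes (exponential time, so toy-only):

```python
def all_segmentations(s, vocab):
    """Enumerate T_V(s): all token sequences that concatenate back to s."""
    if not s:
        return [[]]  # the empty string has exactly one (empty) segmentation
    segs = []
    for piece in vocab:
        if s.startswith(piece):  # piece is a feasible first token
            segs += [[piece] + rest
                     for rest in all_segmentations(s[len(piece):], vocab)]
    return segs

vocab = {"h", "a", "t", "ha", "at"}
print(all_segmentations("hat", vocab))
# [['h', 'a', 't'], ['h', 'at'], ['ha', 't']] (order depends on set iteration)
```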

While it might not seem notable, the non-injectivity of ${g}$ is actually an interesting property of most tokenization schemes. For one, it’s motivated several variants of different tokenization algorithms in which the inference rule—the mapping ${h}:{\Sigma}^*\rightarrow{\mathcal{V}}^*$ that selects a particular element of ${\mathcal{T}}_{\mathcal{V}}({\mathbf{{s}}})$—is replaced or redefined, for example by sampling from a posterior over segmentations (Kudo 2018) or by changing the inference objective to something like minimizing token sequence length (Hofmann et al., 2022; Schmidt et al., 2024). It also means that we should distinguish between the canonical tokenization of ${\mathbf{{s}}}$, which is ${h}({\mathbf{{s}}})$, and any other valid segmentation ${\mathbf{{v}}}\in {\mathcal{T}}_{\mathcal{V}}({\mathbf{{s}}})$ with ${\mathbf{{v}}}\neq {h}({\mathbf{{s}}})$, which are typically called non-canonical tokenizations. The existence of non-canonical tokenizations has implications for how one should actually compute the probability of a string under a language model using a given vocabulary. See Cao and Rimell (2021) for a more detailed discussion of non-canonical tokenizations and why they matter in practice.

What you came here for: UnigramLM

The UnigramLM tokenization algorithm (Kudo 2018) takes a probabilistic-modeling approach to string tokenization. It defines an ${h}$, together with an algorithm for learning its parameters, by treating tokenization as inference in a latent-variable generative model over strings—in particular, a unigram generative model.

Generative model

The UnigramLM tokenization algorithm assumes that each observed string ${\mathbf{{s}}}$ arises from a latent sequence of tokens ${\mathbf{{v}}}$, where tokens are drawn independently from a fixed probability distribution, i.e., from a unigram distribution over a fixed vocabulary. The data-generating distribution can thus be defined in terms of the unigram probabilities ${\boldsymbol{\phi}}\in \Delta^{|{\mathcal{V}}| - 1}$. Before we get to the definition of the data-generating distribution though, we have to establish some other definitions.

Warning about notation: To reduce the number of nested subscripts (and other similarly offensive notational choices), I’m going to primarily use random variables to describe this problem. Don’t worry, you’ll still get a nice sprinkling of nested subscripts even with the random variables! Just fewer than without. Sorry... As is standard, uppercase letters will denote random variables (e.g., $X$, $Z$), and bold uppercase letters will denote sequences of them (e.g., $\mathbf X$, $\mathbf Z$).

Formally, let ${V}$ be our token-valued random variable: a categorical random variable on ${\mathcal{V}}$ with $\sum_{{v}\in{{\mathcal{V}}}}{P({V}={v};{\boldsymbol{\phi}})}=1$. Occasionally for shorthand, we’ll use ${\phi_{v}}= {P({V}={v};{\boldsymbol{\phi}})}$ to refer to the unigram probability of the piece ${v}$. Let ${\mathbf{V}}$ be a random variable taking values in the space of token sequences ${\mathbf{{v}}}\in {\mathcal{V}}^*$. For the distribution of ${\mathbf{V}}$ to be a valid probability distribution on ${\mathcal{V}}^*$, we must specify a length prior, i.e., a random variable ${M}$ on $\mathbb{N}$ with $\sum_{{m}=0}^\infty P({M}={m})=1$.3 The UnigramLM algorithm then assumes a token sequence ${\mathbf{{v}}}=\langle{v}_1,\dots,{v}_{m}\rangle$ is generated as $$ {m}\sim {M},\quad {v}_{t}\stackrel{\text{i.i.d.}}{\sim} {\small\mathrm{Categorical}}({\boldsymbol{\phi}}) \quad (t=1,\dots,{m}) \tag{1} $$

We can thus define the distribution of ${\mathbf{V}}$ as $$ P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}}) \mathrel{\stackrel{\textnormal{ def}}{=}} P({M}=|{\mathbf{{v}}}|)\prod_{t=1}^{|{\mathbf{{v}}}|}{P({V}={v}_{{t}};{\boldsymbol{\phi}})} \tag{2} $$ The likelihood of a sequence conditional on a given length ${m}$ is then simply the product of its piece probabilities, i.e., Eq. (2) with the length prior term divided out: $$ P({\mathbf{V}}={\mathbf{{v}}}\mid {M}={m};{\boldsymbol{\phi}}) = \prod_{{t}=1}^{{m}} {P({V}={v}_{t};{\boldsymbol{\phi}})}, \tag{3} $$ One thing to note is that the parameters of $P({V}=\cdot\,;{\boldsymbol{\phi}})$ are completely specified by ${\boldsymbol{\phi}}$. This isn’t the case with $P({\mathbf{V}};{\boldsymbol{\phi}})$, for which the parameters of ${M}$ must also be known to fully specify the distribution. We won’t add any additional notation to $P({\mathbf{V}};{\boldsymbol{\phi}})$ to specify the parameters of ${M}$, though, since ${M}$ is pretty much always ignored. Rather, in yet another moment of ‘engineering convenience’ winning out over ‘theoretical elegance’, most people just compute token sequence probabilities in UnigramLM using Eq. 3.
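As a quick sanity check of Eqs. (2) and (3), here’s a minimal sketch; the toy `phi` values and the optional `log_prior` callable are my own stand-ins, not anything prescribed by the paper:

```python
import math

phi = {"h": 0.2, "a": 0.2, "t": 0.2, "ha": 0.2, "at": 0.2}  # toy unigram params

def log_prob_tokens(v, phi, log_prior=None):
    """log P(V = v; phi) as in Eq. (2); omit log_prior for the Eq. (3) convention."""
    lp = sum(math.log(phi[tok]) for tok in v)
    if log_prior is not None:
        lp += log_prior(len(v))  # length prior term log P(M = |v|)
    return lp

print(log_prob_tokens(["ha", "t"], phi))  # 2 * log(0.2) ≈ -3.22
```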

Given the deterministic mapping ${g}$ from tokens to strings, we can derive the distribution over strings—our data-generating distribution—as a pushforward of the distribution over tokens. Let ${\mathbf{S}}$ be a random variable on ${\Sigma}^*$. The following relationship holds: $$ \begin{aligned} P({\mathbf{S}}={\mathbf{{s}}};{\boldsymbol{\phi}}) \mathrel{\stackrel{\textnormal{ def}}{=}}\sum_{{\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})} P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}}) \end{aligned} \tag{4} $$

Some useful relationships between ${\mathbf{V}}$ and ${\mathbf{S}}$.

We can see from Eq. (4) that the distribution of ${\mathbf{S}}$ is simply the marginal probability distribution over valid segmentations of ${\mathbf{{s}}}$ under ${\mathcal{V}}$. Applying Bayes’ rule then gives us the posterior over segmentations for a fixed ${\mathbf{{s}}}$: $$ P({\mathbf{V}}={\mathbf{{v}}}\mid {\mathbf{S}}={\mathbf{{s}}}; {\boldsymbol{\phi}}) = \begin{cases} \frac{P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})}{P({\mathbf{S}}={\mathbf{{s}}};{\boldsymbol{\phi}})} & \text{if } {\mathbf{{v}}}\in{\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})\\ 0 & \text{otherwise.} \end{cases} \tag{5} $$ By just moving some terms in Eq. (5) around, we also get the definition of the joint distribution over strings and token sequences: $$P({\mathbf{S}}={\mathbf{{s}}}, {\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}}) = P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})\mathbb{1}\{{\mathbf{{v}}}\in{\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})\}. $$
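Reusing `all_segmentations` and `log_prob_tokens` from the earlier sketches (and ignoring the length prior, per the convention above), the marginal in Eq. (4) and the posterior in Eq. (5) can be computed by brute-force enumeration, again toy-only:

```python
def posterior(s, vocab, phi):
    """P(V = v | S = s; phi) for each v in T_V(s), via Eqs. (4) and (5)."""
    segs = all_segmentations(s, vocab)
    weights = [math.exp(log_prob_tokens(seg, phi)) for seg in segs]
    marginal = sum(weights)  # P(S = s; phi), the sum in Eq. (4)
    return {tuple(seg): w / marginal for seg, w in zip(segs, weights)}

print(posterior("hat", vocab, phi))
# {('h','a','t'): ≈0.091, ('ha','t'): ≈0.455, ('h','at'): ≈0.455}
```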

For the moment, let’s assume that we know ${\boldsymbol{\phi}}$, or at least have estimates for these parameters. At inference time (i.e., when segmenting text into tokens), the UnigramLM tokenization algorithm aims to find the most likely segmentation of ${\mathbf{{s}}}$ into tokens ${\mathbf{{v}}}= \langle {v}_1, {v}_2, \dots\rangle$ under the generative model (defined above) with these parameters. To this end, it uses a Viterbi-style algorithm: $$ \begin{aligned} {{h}_{{\boldsymbol{\phi}}}({\mathbf{{s}}})}&= \arg\max_{{\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})} P({\mathbf{V}}={\mathbf{{v}}}\mid {\mathbf{S}}={\mathbf{{s}}}; {\boldsymbol{\phi}})\\ &= \arg\max_{{\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})} P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})\\ &= \arg\max_{{\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})} P({M}=|{\mathbf{{v}}}|)\prod_{t=1}^{|{\mathbf{{v}}}|}{P({V}={v}_{t};{\boldsymbol{\phi}})}\\ &\overset{?}{=} \arg\max_{{\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})} \prod_{t=1}^{|{\mathbf{{v}}}|}{P({V}={v}_{t};{\boldsymbol{\phi}})} \end{aligned} \tag{6} $$ where the second line follows from the relationship in Eq. (5) ($P({\mathbf{S}}={\mathbf{{s}}};{\boldsymbol{\phi}})$ does not depend on ${\mathbf{{v}}}$ and so it doesn’t affect the argmax). As we can see in Eq. 6, the length prior (${M}$) is part of the posterior distribution and should thus affect the Viterbi segmentation; intuitively speaking, it biases the distribution towards token sequences of certain lengths.

Example 2 (Effect of the length prior on Viterbi segmentation). Suppose a string ${\mathbf{{s}}}$ admits two valid segmentations ${\mathbf{{v}}}^{(1)}$ and ${\mathbf{{v}}}^{(2)}$ under ${\mathcal{V}}$, with lengths $|{\mathbf{{v}}}^{(1)}| = 1$ and $|{\mathbf{{v}}}^{(2)}| = 3$. Assume that the unigram probabilities are such that $$ \prod_{{t}=1}^{|{\mathbf{{v}}}^{(1)}|} {P({V}={v}^{(1)}_{t};{\boldsymbol{\phi}})} = \prod_{{t}=1}^{|{\mathbf{{v}}}^{(2)}|} {P({V}={v}^{(2)}_{t};{\boldsymbol{\phi}})} $$ so the two segmentations tie if we ignore the length prior. Now let the length prior favor shorter sequences, e.g., $$ P({M}=1) = 0.9, \qquad P({M}=3) = 0.1 $$ Then the full sequence probabilities become $$ \begin{aligned} P({\mathbf{V}}={\mathbf{{v}}}^{(1)};{\boldsymbol{\phi}}) &= P({M}=1) \prod_{{t}=1}^{|{\mathbf{{v}}}^{(1)}|} {P({V}={v}^{(1)}_{t};{\boldsymbol{\phi}})} = 0.9 \cdot C,\\ P({\mathbf{V}}={\mathbf{{v}}}^{(2)};{\boldsymbol{\phi}}) &= P({M}=3) \prod_{{t}=1}^{|{\mathbf{{v}}}^{(2)}|} {P({V}={v}^{(2)}_{t};{\boldsymbol{\phi}})} = 0.1 \cdot C, \end{aligned} $$ for some common factor $C$. The Viterbi segmentation under the full model (including the length prior) is therefore ${\mathbf{{v}}}^{(1)}$, while under the approximation that drops $P({M}=\cdot)$, the two segmentations are equally probable. This illustrates that the length prior can in principle have a non-trivial effect on the inference result.

For example, if $P({M}=3) \gg P({M}=1)$, the product of three piece probabilities (which should be fairly small) could win out over a single piece’s probability even if that product was originally much smaller. Consequently, a theoretically sound implementation of the method would define and keep this prior for inference, as in Eq. 6. Typically, though, this term is ignored (or at least sequence length probabilities are assumed to be constant for all valid lengths). Unless otherwise specified, when talking about inference, we will assume that no length term is used, for faithfulness to the original algorithm. It could potentially be interesting to look into the effects of this design choice!
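Here’s what that Viterbi-style inference looks like in code, reusing the toy `phi` from above. It’s a sketch under two assumptions: no length prior (per the discussion above), and every single character appearing in the vocabulary so that a valid path always exists:

```python
MAX_PIECE_LEN = 16  # assumption: cap on piece length, as SentencePiece does by default

def viterbi_segment(s, phi):
    """argmax over v in T_V(s) of P(V = v; phi), dropping the length prior."""
    n = len(s)
    best = [-math.inf] * (n + 1)  # best[i]: max log-prob of segmenting s[:i]
    back = [0] * (n + 1)          # back[i]: start index of the final piece
    best[0] = 0.0
    for i in range(1, n + 1):
        for j in range(max(0, i - MAX_PIECE_LEN), i):
            piece = s[j:i]
            if piece in phi and best[j] + math.log(phi[piece]) > best[i]:
                best[i] = best[j] + math.log(phi[piece])
                back[i] = j
    tokens, i = [], n
    while i > 0:  # follow back-pointers to recover the argmax segmentation
        tokens.append(s[back[i]:i])
        i = back[i]
    return tokens[::-1]

print(viterbi_segment("hat", phi))  # ['h', 'at'] (tied with ['ha', 't'] under toy phi)
```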

The true parameters of the generative process ${\boldsymbol{\phi}}$ are unknown, however; this includes both the piece probabilities ${\phi_{v}}$ and the underlying vocabulary ${\mathcal{V}}$ over which they are defined. The UnigramLM tokenization algorithm (described next) proposes a method for coming up with an estimate of these parameters from text data.

Learning Model Parameters

Maximum likelihood estimation (MLE)—a standard approach to estimating model parameters—aims to find the model parameters that maximize the log-likelihood of our data. Under the UnigramLM assumptions about the generative process of strings, our “complete” dataset actually consists of $({\mathbf{{s}}},{\mathbf{{v}}})$ pairs, i.e., strings and the sequences of tokens that produced them. Thus, our complete dataset looks like $\mathcal{X} = \{({\mathbf{{s}}}_i,{\mathbf{{v}}}_i)\}_{i=1}^K$ and the complete-data log-likelihood is defined as: $$ \begin{aligned} {\mathcal{L}}(\mathcal{X}; {\boldsymbol{\phi}}) &\mathrel{\stackrel{\textnormal{ def}}{=}}\log\prod_{i=1}^KP({\mathbf{S}}={\mathbf{{s}}}_i, {\mathbf{V}}={\mathbf{{v}}}_i;{\boldsymbol{\phi}})\\ &= \sum_{i=1}^K\log P({\mathbf{S}}={\mathbf{{s}}}_i, {\mathbf{V}}={\mathbf{{v}}}_i;{\boldsymbol{\phi}}) \end{aligned} \tag{7} $$ If we actually had this complete data (and we knew ${\mathcal{V}}$), we would simply find the ${\boldsymbol{\phi}}$ that maximizes Eq. (7), which would be a fairly clean problem that is easy to solve given our assumptions about the underlying distributions. However, we only see the “post-processed” strings ${\mathbf{{s}}}= {g}({\mathbf{{v}}})$; the exact underlying pieces that form that string are unknown (they can be any in ${\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})$, and we don’t even know ${\mathcal{V}}$!). So, we can instead try to maximize our observed data log-likelihood, i.e., the likelihood of just our strings under our data-generating distribution defined in Eq. (4). Given our relationships from earlier, we can define this likelihood in terms of ${\boldsymbol{\phi}}$: $$ \begin{aligned} {\mathcal{L}}({\mathcal{C}}; {\boldsymbol{\phi}}) &\mathrel{\stackrel{\textnormal{ def}}{=}}\log\prod_{i=1}^KP({\mathbf{S}}={\mathbf{{s}}}_i;{\boldsymbol{\phi}})\\ &= \log\prod_{i=1}^K\sum_{{\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}}_i)} P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})\\ &= \sum_{i=1}^K \log\sum_{{\mathbf{{v}}}\in{\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}}_i)} P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}}) \end{aligned} \tag{8} $$ where ${\mathcal{C}}= \{{\mathbf{{s}}}: ({\mathbf{{s}}}, \cdot) \in \mathcal{X} \}$ is simply our observed set of strings, i.e., our corpus. Unfortunately, Eq. (8) is a difficult quantity to maximize directly due to the log–sum structure. Luckily, the expectation-maximization (EM) algorithm provides us a route for working with this situation.

Expectation-Maximization

EM was designed for exactly this use case: obtaining MLE estimates for a data-generating process in which part of the data is unobserved. Formally, the EM algorithm uses Jensen’s inequality to relate the expected value of the complete data log-likelihood to the observed data log-likelihood, i.e., relating the expected value of Eq. (7) to Eq. (8). This is exactly the connection made by Kudo (2018) (even if not explicitly) when introducing their algorithm for approximating the parameters ${\boldsymbol{\phi}}$.

Expected complete-data log-likelihood under observed data and current parameters.

Let ${{\boldsymbol{\phi}}^{(n)}}$ denote our current belief about what the unigram parameters might be (more discussion on how we can initialize this distribution coming up!). For now, we will assume that the vocabulary is fixed. The distributions (and corresponding random variables) induced by ${{\boldsymbol{\phi}}^{(n)}}$ adhere to our original definitions from earlier. Note that when we use simply ${\boldsymbol{\phi}}$, we are referring to the distributions (and corresponding random variables) induced by a generic ${\boldsymbol{\phi}}$; these are the entities for which our parameters are free variables that we are optimizing.

The expected complete data log-likelihood under ${{\boldsymbol{\phi}}^{(n)}}$—which we denote as ${\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}})$—follows simply from taking the expected value of Eq. (7), given our observed data ${\mathcal{C}}$ and our current model parameters ${{\boldsymbol{\phi}}^{(n)}}$, i.e., the expected value under the posterior ${\mathbf{V}}\mid {\mathbf{S}};{{\boldsymbol{\phi}}^{(n)}}$. $$ \begin{aligned} {\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}}) &\mathrel{\stackrel{\textnormal{ def}}{=}} \mathop{\mathrm{\mathbb{E}}} \big[{\mathcal{L}}(\mathcal{X}; {\boldsymbol{\phi}}) \mid {\mathcal{C}}, {{\boldsymbol{\phi}}^{(n)}}\big]\\ &= \underset{ {\mathbf{V}}\mid {\mathbf{S}};{{\boldsymbol{\phi}}^{(n)}}}{\mathop{\mathrm{\mathbb{E}}}}\big[\sum_{i=1}^K \log P({\mathbf{S}}, {\mathbf{V}};{\boldsymbol{\phi}})\mid {\mathcal{C}}\big]\\ &= \sum_{i=1}^K\underset{{\mathbf{{v}}}\sim{\mathbf{V}}\mid {\mathbf{S}}={\mathbf{{s}}}_i;{{\boldsymbol{\phi}}^{(n)}}}{\mathop{\mathrm{\mathbb{E}}}}\big[\log P({\mathbf{S}}={\mathbf{{s}}}_i, {\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})\big] \end{aligned} $$ In words, we can think of ${\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}})$ as the expected complete data log-likelihood where the (latent) segmentations are induced by the posterior with parameters ${{\boldsymbol{\phi}}^{(n)}}$, while the log-likelihood inside is evaluated using the candidate parameters ${\boldsymbol{\phi}}$.

Now we will show how this quantity relates to the observed data log-likelihood.

Observed data log-likelihood and Jensen’s inequality.

We start with a reminder of Jensen’s inequality, applied to our definition of $P({\mathbf{S}}={\mathbf{{s}}};{\boldsymbol{\phi}})$. For any valid probability distribution $P({\mathbf{V}}={\mathbf{{v}}})$ supported on ${\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})$, Jensen’s inequality tells us $$ \begin{aligned} \log P({\mathbf{S}}={\mathbf{{s}}};{\boldsymbol{\phi}}) &= \log \sum_{{\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})} P({\mathbf{V}}={\mathbf{{v}}})\frac{P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})}{P({\mathbf{V}}={\mathbf{{v}}})}\\ &\ge \sum_{{\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})} P({\mathbf{V}}={\mathbf{{v}}})\log \frac{P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})}{P({\mathbf{V}}={\mathbf{{v}}})} \end{aligned} $$ If we choose $P({\mathbf{V}}={\mathbf{{v}}})$ to be $P({\mathbf{V}}={\mathbf{{v}}}\mid {\mathbf{S}}= {\mathbf{{s}}};{{\boldsymbol{\phi}}^{(n)}})$—the posterior under our current parameter beliefs for a fixed ${\mathbf{{s}}}$—and apply this to our definition of the observed data log-likelihood from Eq. (8), we get $$ \begin{aligned} {\mathcal{L}}&({\mathcal{C}};{\boldsymbol{\phi}})= \sum_{i=1}^K \log\sum_{{\mathbf{{v}}}\in{\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}}_i)} P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})\\ &\ge \sum_{i=1}^K\sum_{{\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}}_i)}P({\mathbf{V}}={\mathbf{{v}}}\mid {\mathbf{S}}={\mathbf{{s}}}_i;{{\boldsymbol{\phi}}^{(n)}}) \big[\log P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})-\log P({\mathbf{V}}={\mathbf{{v}}}\mid {\mathbf{S}}={\mathbf{{s}}}_i;{{\boldsymbol{\phi}}^{(n)}})\big]\\ &= \sum_{i=1}^K\sum_{{\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}}_i)}P({\mathbf{V}}={\mathbf{{v}}}\mid {\mathbf{S}}={\mathbf{{s}}}_i;{{\boldsymbol{\phi}}^{(n)}}) \log P({\mathbf{S}}={\mathbf{{s}}}_i, {\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})\nonumber\\ & \qquad\qquad\qquad\qquad-\sum_{i=1}^K\sum_{{\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}}_i)}P({\mathbf{V}}={\mathbf{{v}}}\mid {\mathbf{S}}={\mathbf{{s}}}_i;{{\boldsymbol{\phi}}^{(n)}})\log P({\mathbf{V}}={\mathbf{{v}}}\mid {\mathbf{S}}={\mathbf{{s}}}_i;{{\boldsymbol{\phi}}^{(n)}})\\ &= \underbrace{\sum_{i=1}^K\underset{{\mathbf{{v}}}\sim {\mathbf{V}}\mid {\mathbf{S}}={\mathbf{{s}}}_i;{{\boldsymbol{\phi}}^{(n)}}}{\mathop{\mathrm{\mathbb{E}}}}\big[\log P({\mathbf{S}}={\mathbf{{s}}}_i, {\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})\big]}_{{\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}})} +\sum_{i=1}^K{\mathrm{H}}\big({\mathbf{V}}\mid {\mathbf{S}}={\mathbf{{s}}}_i;{{\boldsymbol{\phi}}^{(n)}}\big) \\ &\geq {\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}}) \end{aligned} \tag{9} $$ Note that when going from the second to third lines in Eq. (9), we make use of the fact that for any ${\mathbf{{v}}}\in {\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}}_i)$ we have $P({\mathbf{S}}={\mathbf{{s}}}_i, {\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}}) = P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})$ by definition. Then, we’re simply using the equivalence of these values with the definitions of expected values and (Shannon) entropy, respectively.

Eq. (9) is typically referred to as the evidence lower bound (ELBO)—a proxy objective that is often used in machine learning. For example, it’s used for training variational autoencoders, where it provides a tractable lower bound on the intractable log-likelihood of the data under a latent-variable model. In the case of EM, we go one step further and use one of the components of the ELBO as our proxy objective for the observed data log-likelihood: the expected complete data log-likelihood ${\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}})$. And this is the basis of the EM algorithm, which iteratively updates ${\boldsymbol{\phi}}$ by choosing the value that maximizes ${\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}})$, until convergence.

So now that we’ve gone through all the parts, here’s a quick recap of the EM algorithm: it’s an iterative algorithm for approximating MLE estimates. The E step computes the expected complete data log-likelihood under current beliefs about model parameters (in our case, ${{\boldsymbol{\phi}}^{(n)}}$); this quantity is standing in for observed data log-likelihood, which is a much more difficult quantity to compute. The M step then solves for the free parameters (in our case, ${\boldsymbol{\phi}}$) that maximize this quantity, and then updates our current beliefs to the new quantity. I’ll omit the proof of why this should converge (for a fixed ${\mathcal{V}}$) since, well, it’s in a lot of ML textbooks (you know, those ones we all swear we’ll read cover-to-cover someday...)

The UnigramLM Algorithm

The UnigramLM algorithm is typically seen as a “simple” application of EM. This, however, is not exactly the case. Importantly, EM assumes that the support of the distribution whose parameters we’re trying to estimate is known (and fixed), i.e., that we know ${\mathcal{V}}$. But, as discussed earlier, we don’t know ${\mathcal{V}}$! The UnigramLM algorithm addresses this by beginning with an intentionally overcomplete initial vocabulary and progressively reducing it through a heuristic pruning step, which is done after an iteration of the standard E-step and M-step, throughout which ${\mathcal{V}}$ is held fixed. In short, as the algorithm iteratively re-estimates the model parameters, it gradually shrinks ${\mathcal{V}}$ toward the desired final size by removing pieces that are seemingly unimportant for achieving good corpus log-likelihood. You can think of this as putting your vocabulary on a strict likelihood-based diet: pieces that don’t contribute enough to explaining the data get gently but firmly removed.

Where it’s necessary to make this dependence explicit, we will use ${{\mathcal{V}}_{{n}}}$ to denote the current vocabulary. To reduce notational clutter, in defining the algorithm below, we’ll use just ${\mathcal{V}}$; at step $n$ of the algorithm, you can assume ${\mathcal{V}}={{\mathcal{V}}_{{n}}}$ (and that all random variables are defined over ${{\mathcal{V}}_{{n}}}$) unless otherwise stated.

  1. Initialization: Define an initial vocabulary ${\mathcal{V}}_0$. This could be something like all possible substrings of ${\mathcal{C}}$, subject to a maximum length constraint.4 Initialize ${\boldsymbol{\phi}}^{(0)}$ by some heuristic: the simplest would be uniform initialization, i.e., all pieces are assigned probability $1/|{\mathcal{V}}_0|$.

  2. Perform EM for $n=1, \dots, N$ iterations or until piece probability estimates converge:

    i. E-step (Expected data log-likelihood computation): The E-step in EM is for computing the expected complete data log-likelihood under our current parameter beliefs ${\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}})$. It turns out that expected token counts are a sufficient statistic for the M-step objective in this case, and so our problem boils down to computing expected token counts under ${{\boldsymbol{\phi}}^{(n)}}$. To see why, first define the count function on token sequences as $$c_{{v}}({\mathbf{{v}}}) \mathrel{\stackrel{\textnormal{ def}}{=}}\sum_{t=1}^{|{\mathbf{{v}}}|}\mathbb{1}\{{v}_{t}= {v}\} \tag{10}$$ Then, note that for any valid ${\mathbf{{s}}},{\mathbf{{v}}}$ such that ${\mathbf{{v}}}\in{\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})$, we can write $$ \begin{aligned} \log P({\mathbf{S}}={\mathbf{{s}}}, {\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})&=\log P({\mathbf{V}}={\mathbf{{v}}};{\boldsymbol{\phi}})\\ &=\log P({M}=|{\mathbf{{v}}}|)+\sum_{t=1}^{|{\mathbf{{v}}}|}\log {P({V}={v}_{t};{\boldsymbol{\phi}})}\\ &=\log P({M}=|{\mathbf{{v}}}|)+ \sum_{{v}\in{\mathcal{V}}} c_{{v}}({\mathbf{{v}}})\log {P({V}={v};{\boldsymbol{\phi}})} \end{aligned} $$ Substituting these relationships into our definition of ${\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}})$ and using linearity of expectation, we get $$\begin{aligned} {\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}}) = \underbrace{\sum_{i=1}^K\underset{{\mathbf{{v}}}\sim {\mathbf{V}}\mid {\mathbf{S}}={\mathbf{{s}}}_i;{{\boldsymbol{\phi}}^{(n)}}}{\mathop{\mathrm{\mathbb{E}}}}[\log P({M}=|{\mathbf{{v}}}|)]}_{\text{constant in }{\boldsymbol{\phi}}} +\sum_{i=1}^K\sum_{{v}\in{\mathcal{V}}} \underbrace{\underset{{\mathbf{{v}}}\sim {\mathbf{V}}\mid {\mathbf{S}}={\mathbf{{s}}}_i;{{\boldsymbol{\phi}}^{(n)}}}{\mathop{\mathrm{\mathbb{E}}}\left[ c_{{v}}({\mathbf{{v}}})\right]}}_{\mathrel{\stackrel{\textnormal{ def}}{=}}\widetilde{c}_{{v}}({\mathbf{{s}}}_i;{{\boldsymbol{\phi}}^{(n)}})}\log {P({V}={v};{\boldsymbol{\phi}})} \end{aligned} \tag{11} $$ where $\widetilde{c}_{{v}}({\mathbf{{s}}};{{\boldsymbol{\phi}}^{(n)}})$ are simply expected token counts under unigram model parameters ${{\boldsymbol{\phi}}^{(n)}}$, which can be computed as $\widetilde{c}_{{v}}({\mathbf{{s}}};{{\boldsymbol{\phi}}^{(n)}})= \sum_{{\mathbf{{v}}}\in{\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})} c_{{v}}({\mathbf{{v}}}) P({\mathbf{V}}={\mathbf{{v}}}\mid {\mathbf{S}}={\mathbf{{s}}};{{\boldsymbol{\phi}}^{(n)}})$. Lastly, if we define the corpus-level expected counts as $$ \widehat{c}_{{v}}({\mathcal{C}};{\boldsymbol{\phi}}) \mathrel{\stackrel{\textnormal{ def}}{=}} \sum_{{\mathbf{{s}}}\in{\mathcal{C}}} \widetilde{c}_{{v}}({\mathbf{{s}}};{\boldsymbol{\phi}}) \tag{12} $$ and substitute them into our expansion in Eq. 11, then the equality reduces to $$ {\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}}) = \text{const} + \underbrace{\sum_{{v}\in{\mathcal{V}}}\widehat{c}_{{v}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}})\log {P({V}={v};{\boldsymbol{\phi}})}}_{\mathrel{\stackrel{\textnormal{ def}}{=}}\bar{{\mathcal{Q}}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}})} \tag{13} $$ where we have added the definition of $\bar{{\mathcal{Q}}}$ (simply ${\mathcal{Q}}$ without the “$\mathrm{const}$” term) since it will be useful later.
From the above, we can see that the posterior expected counts are sufficient statistics for the M-step objective ${\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}})$. In practice, the per-string expected counts $\widetilde{c}_{{v}}({\mathbf{{s}}};{\boldsymbol{\phi}})$ can be computed efficiently using a forward–backward dynamic program defined over the segmentation lattice induced by ${\mathcal{T}}_{{\mathcal{V}}}({\mathbf{{s}}})$. In words, this lattice forms a directed acyclic graph: nodes correspond to positions in the string, and edges originating from a node correspond to pieces ${v}\in {\mathcal{V}}$ that can begin at that position and end at another (i.e., pieces whose symbol sequences match the substring). Each edge is weighted by the piece’s probability under the current parameters, ${\phi^{(n)}_{v}}$. Valid paths in this graph correspond to valid segmentations of ${\mathbf{{s}}}$. The forward–backward algorithm then marginalizes over all valid paths in this graph to compute the posterior probability of each token’s occurrence, from which the expected counts follow (see the code sketch just after this list for a concrete version). A somewhat interesting observation is that this method of getting counts uses an inference procedure that is different from what is done when actually tokenizing text. In the latter case, only the maximum-probability segmentation is ultimately used. Here, though, we consider all segmentations of a string ${\mathbf{{s}}}$ that have non-zero probability, weighting the token counts from each segmentation (token sequence) by its probability under our current parameters ${{\boldsymbol{\phi}}^{(n)}}$. Also of note is that this is where a length prior could have an effect on the model parameters we learn; in practice, though, the term is typically not included in the model definition.

    ii. M-step (maximize ${\boldsymbol{\phi}}$ and update ${{\boldsymbol{\phi}}^{(n)}}$): In the M-step, we want to maximize ${\mathcal{Q}}({\boldsymbol{\phi}};{{\boldsymbol{\phi}}^{(n)}})$ with respect to ${\boldsymbol{\phi}}$ subject to these parameters giving us a valid probability distribution, i.e., $\sum_{{v}\in{\mathcal{V}}}{\phi_{v}}=1$ and ${\phi_{v}}\ge 0$. Subbing in the relationship established in Eq. 13, this actually boils down to a relatively simple problem: finding the ${\boldsymbol{\phi}}$ that maximizes the probability of having observed the expected counts that we got from segmenting the corpus according to our prior model parameter beliefs: $$ \begin{aligned} \max_{{\boldsymbol{\phi}}}&\sum_{{v}\in{{\mathcal{V}}}}\widehat{c}_{{v}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}})\log {P({V}={v};{\boldsymbol{\phi}})}\\ &\text{s.t.}\quad \sum_{{v}\in{\mathcal{V}}}{\phi_{v}}=1,{\phi_{v}}\ge 0 \end{aligned} $$ The solution (normalized expected counts) is very recognizable, as it is essentially the same as the MLE for a standard multinomial distribution, albeit using expected counts rather than pure counts: $$ {\phi^{(n+1)}_{v}} = \frac{\widehat{c}_{{v}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}})} {\sum_{{v}’\in{{\mathcal{V}}}}\widehat{c}_{{v}’}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}})}. \tag{14} $$ The length-prior term is constant in ${\boldsymbol{\phi}}$ and does not alter the update (for fixed ${\mathcal{V}}$); this update also appears in the code sketch after the list.

    iii. Pruning: After applying the above steps, the vocabulary itself will not have changed (only the per-piece probabilities are updated). We can optionally prune some subset of pieces, creating a new ${\mathcal{V}}_{n+1}$. For example, if the initial vocabulary consists of 100 pieces and our desired final vocabulary size is 30 pieces, one might remove the 10 least “important” pieces after each of the first seven EM iterations. Following pruning, the remaining probabilities in ${\boldsymbol{\phi}}^{(n+1)}$ are renormalized to form a valid distribution over ${\mathcal{V}}_{n+1}$. We elaborate more on this process (and how we determine token “importance”) below.
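Putting steps i and ii together, here’s a minimal sketch of one EM iteration for a fixed vocabulary. It works in probability space for readability; a real implementation (like SentencePiece’s) works in log space to avoid underflow on long strings:

```python
from collections import defaultdict

def expected_counts(s, phi, max_len=16):
    """E-step for one string: posterior expected piece counts via forward-backward."""
    n = len(s)
    fwd = [0.0] * (n + 1)
    fwd[0] = 1.0  # fwd[i]: total probability of all segmentations of s[:i]
    for i in range(1, n + 1):
        for j in range(max(0, i - max_len), i):
            if s[j:i] in phi:
                fwd[i] += fwd[j] * phi[s[j:i]]
    bwd = [0.0] * (n + 1)
    bwd[n] = 1.0  # bwd[i]: total probability of all segmentations of s[i:]
    for i in range(n - 1, -1, -1):
        for k in range(i + 1, min(n, i + max_len) + 1):
            if s[i:k] in phi:
                bwd[i] += phi[s[i:k]] * bwd[k]
    Z = fwd[n]  # = P(S = s; phi), the marginal in Eq. (4)
    counts = defaultdict(float)
    for j in range(n):
        for i in range(j + 1, min(n, j + max_len) + 1):
            piece = s[j:i]
            if piece in phi:
                # posterior marginal of the lattice edge spanning s[j:i]
                counts[piece] += fwd[j] * phi[piece] * bwd[i] / Z
    return counts

def em_step(corpus, phi):
    """One EM iteration: accumulate expected counts, then the MLE update of Eq. (14)."""
    totals = defaultdict(float)
    for s in corpus:
        for piece, c in expected_counts(s, phi).items():
            totals[piece] += c
    norm = sum(totals.values())
    # pieces with zero expected count drop out here (fine for a toy sketch)
    return {piece: c / norm for piece, c in totals.items()}
```

Iterating `phi = em_step(corpus, phi)` until the parameters stop changing is the fixed-vocabulary EM loop of step 2; as discussed later, SentencePiece runs only two such sub-iterations between prunes.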

Strategies for pruning tokens.

Because the initial vocabulary ${\mathcal{V}}_0$ is typically over-complete (often $|{\mathcal{V}}_0| \gg |{\mathcal{V}}|$), UnigramLM applies a pruning step within the EM iterations to gradually reduce vocabulary size. At iteration $n$, after we’ve determined our updated parameters ${\boldsymbol{\phi}}^{(n+1)}$, the algorithm removes tokens whose absence leads to the smallest decrease in (our proxy for) observed data log-likelihood. Intuitively, we prune tokens that contribute least to explaining the data under the current model.

Formally, let $\bar{{\mathcal{Q}}}({\boldsymbol{\phi}}^{(n+1)};{{\boldsymbol{\phi}}^{(n)}})$ be our expected complete data log-likelihood under updated model parameters (albeit still under the segmentations according to ${{\boldsymbol{\phi}}^{(n)}}$). We define the contribution (or “loss”) associated with token ${v}$ as the change (typically a decrease) in the corpus log-likelihood when ${v}$ is removed from the model: $$ {L}({v}) \mathrel{\stackrel{\textnormal{ def}}{=}} \bar{{\mathcal{Q}}}({\boldsymbol{\phi}}^{(n+1)};{{\boldsymbol{\phi}}^{(n)}}) - \bar{{\mathcal{Q}}}({\boldsymbol{\phi}}^{(n+1)}_{-{v}};{{\boldsymbol{\phi}}^{(n)}_{-{v}}}). \tag{15} $$

The notation ${{\boldsymbol{\phi}}^{(n)}_{-{v}}}$ in Eq. (15) refers to the unigram distribution obtained from ${{\boldsymbol{\phi}}^{(n)}}$ by removing ${v}$ from its support and renormalizing the remaining probabilities. The corresponding string-level distribution is thus identical to the one induced by ${{\boldsymbol{\phi}}^{(n)}}$, except that all segmentations containing ${v}$ are assigned zero probability and individual piece probabilities are renormalized over ${\mathcal{V}}\setminus {{v}}$ (this logic also applies to ${\boldsymbol{\phi}}^{(n+1)}_{-{v}}$).

Computing $\bar{{\mathcal{Q}}}({\boldsymbol{\phi}}^{(n+1)}_{-{v}};{{\boldsymbol{\phi}}^{(n)}_{-{v}}})$ in Eq. 15 for a given ${v}$ generally requires a separate forward–backward pass over the corpus. This is because disallowing the use of ${v}$ in segmentations changes both the set of valid paths and the total probability of those paths.5 The new per-string marginal probabilities (and expected token counts) under ${\boldsymbol{\phi}}^{(n+1)}_{-{v}}$ cannot, in general, be recovered from forward/backward marginals computed under ${{\boldsymbol{\phi}}^{(n)}}$. Hence, we would need a fresh forward–backward evaluation on the pruned lattice to obtain the exact $\bar{{\mathcal{Q}}}({\boldsymbol{\phi}}^{(n+1)}_{-{v}};{{\boldsymbol{\phi}}^{(n)}_{-{v}}})$.

Performing a separate forward–backward pass for each piece in the vocabulary whenever we want to prune is impractical for vocabularies of any reasonable size. For example, if our initial vocabulary holds a mere 100k pieces, then computing per-piece losses would require 100k forward–backward passes over the corpus. In practice, one instead uses approximations that reuse the statistics already computed during the current EM iteration. We discuss those next.

After computing ${L}({v})$ for all ${v}\in {\mathcal{V}}_n$, we remove the $k_n$ tokens with the smallest losses. Intuitively, this can be seen as removing the tokens whose removal incurs the least penalty on the corpus log-likelihood. The pruning rate $k_n$ is a hyperparameter of the algorithm. It is often chosen such that the target vocabulary size $|{\mathcal{V}}|$ is reached after a fixed number of pruning iterations $J$, e.g., $$ k_n = \min\left( \frac{1}{J}\bigl(|{\mathcal{V}}_0| - |{\mathcal{V}}|\bigr), \max\bigl(0, |{\mathcal{V}}_n| - |{\mathcal{V}}|\bigr) \right). \tag{16} $$
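In code, one possible reading of this schedule might look like the first function below; the second is the multiplicative variant that SentencePiece actually uses (per the implementation section later), with its default shrinking factor:

```python
def num_to_prune(cur_size, seed_size, target_size, J):
    """k_n from Eq. (16): fixed per-round rate, clipped so we never undershoot."""
    return min((seed_size - target_size) // J, max(0, cur_size - target_size))

def num_to_prune_sp(cur_size, target_size, shrinking_factor=0.75):
    """SentencePiece-style alternative: keep ~75% of pieces each round."""
    return max(0, cur_size - max(target_size, int(cur_size * shrinking_factor)))

print(num_to_prune(100, 100, 30, 7))  # 10 pieces pruned per round, as in step iii
```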

Approximations of ${L}$.

To avoid resegmenting the corpus for every candidate ${v}$, several approximations to the per-piece loss can be used. A simple one is to use the log-likelihood contribution of a token as its loss, i.e., $\widehat{L}({v}) \approx \widehat{c}_{{v}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}})\log {P({V}={v};{{\boldsymbol{\phi}}^{(n)}})}$. An arguably more sound approximation (and the one used by the original implementation of UnigramLM found in the SentencePiece library) is to look at the change in corpus log-likelihood when simply replacing ${v}$ by its best alternative segmentation, i.e., the best segmentation of the string ${g}({v})$ when ${v}$ is not in the vocabulary.

Formally, let ${\mathbf{{v}}}’ = {h}_{{{\boldsymbol{\phi}}^{(n)}_{-{v}}}}({g}({v}))$ be the best segmentation of the string ${\mathbf{{s}}}= {g}({v})$ under ${{\boldsymbol{\phi}}^{(n)}_{-{v}}}$.6 The approximate loss is then the change to corpus log-likelihood when replacing every use of ${P({V}={v};{{\boldsymbol{\phi}}^{(n)}})}$ with $\prod_{{t}=1}^{|{\mathbf{{v}}}’|}{P({V}={v}’_{t}; {{\boldsymbol{\phi}}^{(n)}_{-{v}}})}$ under the new renormalized unigram probabilities ${{\boldsymbol{\phi}}^{(n)}_{-{v}}}$. This loss can be computed concisely as: $$ \widehat{L}({v})\approx \widehat{c}_{{v}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}})\left[\log {P({V}={v};{{\boldsymbol{\phi}}^{(n)}})} - \log \prod_{{t}=1}^{|{\mathbf{{v}}}’|}{P({V}={v}’_{t}; {{\boldsymbol{\phi}}^{(n)}_{-{v}}})}\right] $$

Example 3 (Toy pruning example). Suppose our corpus contains the string ${\mathbf{{s}}}= \text{``internationalization''}$

and our vocabulary includes the tokens $$ \{ \text{international},\quad \text{inter},\quad \text{national},\quad \text{ization},\quad \text{al},\ \ldots \} $$ Assume that under the current parameters ${{\boldsymbol{\phi}}^{(n)}}$, the posterior expected corpus-level counts are $$ \widehat{c}_{\text{international}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}}) \ll \widehat{c}_{\text{inter}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}}), \widehat{c}_{\text{national}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}}), \widehat{c}_{\text{ization}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}}) $$

To approximate the contribution of ${v}_{\text{international}}$, we consider its best alternative segmentation when it is removed from the vocabulary. Let $$ {\mathbf{{v}}}’ = \langle \text{inter}, \text{national} \rangle $$ be the Viterbi segmentation of the string ${g}(\text{international})$ under the renormalized distribution ${{\boldsymbol{\phi}}^{(n)}_{-{v}}}$. The approximate loss associated with pruning $\text{international}$ is then $$ \begin{aligned} \widehat{L}(\text{international};{{\boldsymbol{\phi}}^{(n)}}) &\approx \widehat{c}_{\text{international}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}}) \log \frac{ {P({V}=\text{international};{{\boldsymbol{\phi}}^{(n)}})} }{ {P({V}=\text{inter}; {{\boldsymbol{\phi}}^{(n)}_{-{v}}})} \cdot {P({V}=\text{national}; {{\boldsymbol{\phi}}^{(n)}_{-{v}}})} }. \end{aligned} $$ Intuitively, if ${v}_{\text{international}}$ is both rare (small $\widehat{c}_{\text{international}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}})$) and easily replaced by a segmentation whose product of probabilities is similar to ${P({V}=\text{international};{{\boldsymbol{\phi}}^{(n)}})}$, then its (approximate) loss will be small, making it a good candidate for pruning.

While this approximation does not account for changes in other valid paths’ probabilities that might happen as a result of removing ${v}$ from the vocabulary, it seems to work fairly well in practice as a pruning heuristic (although I don’t believe that anyone has actually tried to run the algorithm with the real, brute-force loss computation).
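Here’s a sketch of this approximate loss in code, reusing `viterbi_segment` and the corpus-level expected counts from earlier. One simplification to flag: I skip renormalizing the remaining probabilities after dropping ${v}$ (reasonable when ${\phi_{v}}$ is small), whereas the math above calls for ${{\boldsymbol{\phi}}^{(n)}_{-{v}}}$:

```python
def approx_prune_losses(phi, corpus_counts):
    """Approximate L(v) for each multi-character piece: the log-likelihood cost
    of rerouting v's expected count to its best alternative segmentation."""
    losses = {}
    for v, count in corpus_counts.items():
        if len(v) == 1:
            continue  # keep single characters so every string stays segmentable
        phi_minus = {u: p for u, p in phi.items() if u != v}  # no renormalization
        alt = viterbi_segment(v, phi_minus)  # best segmentation of g(v) without v
        alt_logp = sum(math.log(phi_minus[u]) for u in alt)
        losses[v] = count * (math.log(phi[v]) - alt_logp)
    return losses

# to prune the k "least important" pieces:
# to_drop = sorted(losses, key=losses.get)[:k]
```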

Implementation in the SentencePiece library

In practice, the UnigramLM algorithm as we know it is largely defined by the public SentencePiece implementation, since Kudo (2018) gives only a high-level description and leaves many engineering choices under-specified. The library makes a number of concrete design decisions that go beyond the abstract EM + pruning picture above.

First, SentencePiece runs a fixed EM+prune schedule rather than iterating EM to convergence on a fixed vocabulary. Each outer iteration consists of a small fixed number of EM “sub-iterations” (typically two), after which the vocabulary is pruned by a fixed shrinking factor, and training stops once the target vocabulary size is reached. Second, the seed vocabulary is not “all substrings up to length $L$”: SentencePiece uses an Enhanced (Extended) Suffix Array procedure to mine a large lexicon of frequent substrings from the corpus (on the order of $10^6$ pieces by default), subject to length and frequency thresholds. Pruning is performed as described above in the approximations section, i.e., a piece’s loss is approximated by assuming that the removed piece’s probability mass transfers to its best alternative Viterbi segmentation (notably, pieces that do not appear in the current Viterbi segmentation of the corpus are pre-pruned). Third, SentencePiece does not use the plain MLE M-step update from Eq. (14). Instead, it performs a Bayesian-EM style update with an implicit Dirichlet(-process) prior, replacing expected counts by their digamma-transformed counterparts: ${\phi^{(n+1)}_{v}}\propto\exp(\Psi(\widehat{c}_{{v}}({\mathcal{C}};{{\boldsymbol{\phi}}^{(n)}})))$. Mathematically, this is equivalent to using the posterior expectation of $\log {\phi_{v}}$ under a Dirichlet posterior. Intuitively, this “softens” the update by damping very large counts and preventing rare pieces from collapsing to near-zero probability too early.
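For the digamma-based M-step, here’s a sketch of what I believe the update amounts to (the Dirichlet-posterior expectation of $\log {\phi_{v}}$); `scipy.special.digamma` is the standard implementation of $\Psi$:

```python
import numpy as np
from scipy.special import digamma

def bayesian_m_step(counts):
    """Bayesian-EM style update: phi_v ∝ exp(Ψ(c_v)), i.e. exp(E[log phi_v])
    under a Dirichlet posterior whose parameters are the expected counts."""
    pieces = list(counts)
    c = np.array([counts[v] for v in pieces])
    log_phi = digamma(c) - digamma(c.sum())  # E[log phi_v] for Dirichlet(c)
    phi = np.exp(log_phi)
    phi /= phi.sum()  # exp(E[log phi]) is not normalized on its own
    return dict(zip(pieces, phi))
```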

Arguably, some of the more critical design choices to be aware of are those pertaining to normalization and pretokenization, as these change which segmentations are feasible. By default, SentencePiece applies NFKC normalization, collapses whitespace, inserts a dummy-prefix marker, and treats whitespace (and often script/number boundaries) as explicit segmentation cues, i.e., as markers that can be suffixes or prefixes of pieces, but that pieces cannot cross. Most of these behaviors can be disabled via training flags, but the fact that they’re used is not well advertised.
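To see these knobs in one place, here’s what training a unigram model with the SentencePiece Python API looks like. The keyword arguments below are, to the best of my knowledge, the relevant flags shown with their default values, but treat the exact names and defaults as something to verify against the SentencePiece docs:

```python
import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="corpus.txt",                  # one sentence per line
    model_prefix="unigram_demo",
    model_type="unigram",
    vocab_size=8000,
    seed_sentencepiece_size=1_000_000,   # size of the mined seed vocabulary
    shrinking_factor=0.75,               # fraction of pieces kept per prune round
    num_sub_iterations=2,                # EM sub-iterations between prunes
    max_sentencepiece_length=16,
    normalization_rule_name="nmt_nfkc",  # NFKC-based normalization
    add_dummy_prefix=True,
    split_by_whitespace=True,
)

sp = spm.SentencePieceProcessor(model_file="unigram_demo.model")
print(sp.encode("internationalization", out_type=str))
```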

Taken together, these implementation details instantiate one particular, very specific version of the abstract UnigramLM model described above, albeit the one that people are typically referring to (rather than an implementation-free mathematical ideal) when talking about “the UnigramLM tokenization algorithm.”

Acknowledgments

As with pretty much any technical work I’ve written, Tiago Pimentel provided critical commentary and recommendations for this blogpost. Thanks, PhD sibling :)


  1. ${v}$ are also sometimes called subwords; we avoid this naming because ${v}$ need not align with orthographic words, in their typical definition. ↩︎

  2. $\circ$ denotes string concatenation and when applied to tokens, indicates the pieces’ symbols are concatenated together (perhaps with some special formatting if symbols from ${\Gamma}$ are present in the piece). ↩︎

  3. The distribution can also be made proper with the use of an EOS symbol, which is the more common way of specifying a language model. The use of ${M}$ in this situation (a non-autoregressive model) is a bit more general (if the distribution of ${M}$ follows a power law, then our distribution over token sequences could equivalently be represented using an EOS symbol). The use of ${M}$ though allows us to handle sequence length without adding a special token to our vocabulary. ↩︎

  4. There are several ways that this seed vocabulary can be created. The Enhanced Suffix Array is one of the more common algorithms. Often, pretokenization is performed on the corpus and one of the more common pretokenization rules splits on whitespace, preventing pieces from crossing whitespace boundaries, although that’s kind of an arbitrary rule… ↩︎

  5. Explicitly, relative to the original model, two coupled changes occur when removing ${v}$ from ${{\mathcal{V}}_{{n}}}$: (i) the feasible set of paths shrinks from ${\mathcal{T}}_{{{\mathcal{V}}_{{n}}}}({\mathbf{{s}}})$ to ${\mathcal{T}}_{{\mathcal{V}}_{n+1}}({\mathbf{{s}}})$ (all segmentations using ${v}$ are removed); (ii) the per-edge weights change after the renormalization of ${{\boldsymbol{\phi}}^{(n)}}$ and the marginal probabilities of remaining paths must be recomputed. ↩︎

  6. This segmentation may need to include an UNK token depending on the base vocabulary. ↩︎
