Constrained decoding with weighted automata
I want to share a take on grammar-constrained generation I’ve been working on for a while. The idea is to use weighted automata to read GLR parse stacks and compute token masks, with no parsing during mask generation. The goal is stable worst-case performance even on difficult inputs and complex grammars.
The problem
You have a grammar (often a JSON Schema) and you want the LLM to produce grammatically valid output.
At each decoding step, the LLM outputs logits over the vocabulary. You sample a token, append it to the output, and repeat.
Constrained decoding is: don’t sample tokens that would make the output grammatically invalid.
So we maintain a constraint state and, for each step, compute a mask of allowed tokens, apply it to the logits, and sample from what remains.
No more broken JSON.
The hard part is speed.
Two primitive operations
Think of constrained decoding as two functions:
- `get_mask(state)` returns a mask over the LLM vocabulary.
- `commit(state, token)` updates the constraint state with the chosen token.
At runtime, the loop looks like:
loop:
parallel:
GPU: logits
CPU: mask
apply mask
sample token
commit token to model and constraint
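To make the "apply mask, sample" part concrete, here is a minimal single-step sketch in numpy-flavored Python. The `get_mask`/`commit` callables, the boolean-mask representation, and the function name are placeholders for whatever constraint engine you plug in, and the GPU/CPU overlap shown in the loop above is deliberately ignored:

```python
import numpy as np

def constrained_decode_step(logits, constraint_state, get_mask, commit, rng):
    """One decoding step: mask the logits, sample, commit the token."""
    # get_mask returns a boolean vector over the vocabulary: True = allowed.
    mask = get_mask(constraint_state)

    # Disallowed tokens get probability zero by pushing their logits to -inf.
    masked = np.where(mask, logits, -np.inf)

    # Softmax over what remains (assumes at least one token is allowed).
    probs = np.exp(masked - masked.max())
    probs /= probs.sum()

    token_id = int(rng.choice(len(probs), p=probs))

    # Advance the constraint state (the model commits the token separately).
    return token_id, commit(constraint_state, token_id)

# usage: rng = np.random.default_rng()
#        token_id, state = constrained_decode_step(logits, state, get_mask, commit, rng)
```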
Why worst-case latency matters
LLM inference is batched: multiple user requests run on the GPU at once. And mask generation sits directly on the critical decoding path.
If you want to generate 1k tokens per second, you have about 1 millisecond to produce a mask. And if you miss even a single deadline (a mask takes 2ms instead of 1ms), it screws up scheduling: the GPU sits waiting on the straggler mask, and suddenly your p99 latency is trash.
Masks can’t be late; they have to arrive on time every time. It’s worst-case performance that matters.
It’s not “can you do masking fast on average?” It’s “can you do it fast every single step?”
Commit ~ incremental parsing
The commit op is analogous to incremental parsing: you’re incrementally parsing a growing prefix.
Many battle-tested incremental parsing systems like tree-sitter and Lezer use Generalized LR (GLR) parsing with a graph-structured stack (GSS).
- An LR parser maintains a stack of parser state IDs.
- It decides what to do by looking at the top of stack plus the next terminal symbol.
- The “Generalized” part handles ambiguity: instead of a single stack, you keep a graph of possible stacks sharing structure (the GSS). When ambiguity arises, stacks fork; when they converge to the same state, they merge; error paths are pruned.
- LLM Token: An element of the model vocabulary, a BPE token with an ID and byte string. ~200k of these in modern tokenizers.
- Grammar Terminal: The symbols your parser actually consumes. Might be bytes, characters, or lexer tokens like `--`.
- Parser State ID: An integer identifying the LR automaton's current state. The parser stack is a sequence of these.
Generating masks
Now you have a parse state (often a GSS), and you want a mask over LLM tokens.
A straightforward approach is:
For each LLM token t, check whether appending it keeps the parse valid.
But you don’t append an LLM token to the parser. You append terminal symbols.
And a single LLM token is a byte string which might:
- expand into multiple terminals (`"}\n"` could lex as `RBRACE NEWLINE`),
- end partway through a terminal (halfway through a string literal, halfway through a number, halfway through `true`),
- be lexically ambiguous,
- or produce no terminals yet (e.g. it adds bytes inside a string, but doesn’t close it).
So “try token t” is really “try every way t’s bytes could lex into some terminal sequence, and then try advancing the parser for that terminal sequence.”
Tries/trellises and why they still hurt in p99.9 land
A common direction is:
- Build a trie/trellis representing how LLM tokens map to terminal sequences.
- Traverse that trie while advancing the parser state, pruning invalid paths.
- Collect the set of LLM tokens that survive.
This can work well for many grammars and workloads. People implement this and get decent average performance, especially for “simple” grammars.
But for general CFG-ish constraints (especially when you care about worst-case), you run into two issues:
1) You end up doing real parser work “inside” masking
Even if your lexical side is deterministic (say you enforce longest-match, or you have a DFA lexer), a single grammar terminal shift can trigger a chain of reductions.
In an LR parser, “shift terminal a” is often really:
- while the top of the stack wants to reduce, do some reductions (pop some states, push a goto state),
- then shift a.
Those reduction chains are table-driven and fast in the happy path, but in GLR they can fan out. The operations are pointer-heavy and cache-unfriendly: you’re manipulating a GSS, merging nodes, pruning branches, etc.
If you do that inside masking—i.e. while traversing a trie of possible tokens—you’re effectively doing speculative GLR work for a large number of hypothetical tokens on every step.
That’s exactly the kind of work you do not want on the critical path.
2) You still pay for long/ugly tokens
Even with tries, some tokens’ byte strings correspond to long sequences of grammar terminals. Sometimes those get quickly pruned. Sometimes they don’t.
- Lexing (grammar tokenization): Splitting source text into grammar terminals like `MINUS`, `NUMBER`, `IDENT`. This is what your parser consumes.
- Tokenization (BPE/LLM tokenization): Splitting text into LLM tokens, subword units from the model's vocabulary (~200k entries). This is what the model produces and what we're masking.
For example, Python will happily accept monstrosities like a run of sixteen consecutive minus signs (unary minus stacked sixteen deep). Those sixteen hyphens may be tokenized (BPE) as a single LLM token, but that token lexes to 16 MINUS terminals.
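A rough illustration of the mismatch (the snippet and the BPE segmentation shown here are assumed for the sake of the example; the token ID 7535 is the one the next paragraph refers to):

```python
# Sixteen unary minuses in a row: ugly, but perfectly valid Python.
x = ----------------5

# The mismatch, schematically:
#   BPE tokenizer:  "----------------"  ->  one LLM token (id 7535)
#   grammar lexer:  MINUS x 16          ->  sixteen separate terminals
```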
Every time a mask is generated, the parser must ingest all 16 MINUS terminals of token 7535. This forces it to do a lot of terminal-level processing—16 shifts through the grammar, each of which may involve many reductions—triggered by a single LLM token.
You can transform the grammar to remove unit/null reductions. You can aggressively merge and simplify GSS nodes. You can add clever memoization, early exits. But in my experience, these pesky LLM tokens with complicated lexes are just really hard to engineer away.
I tried hard to make the “trie + incremental parse simulation” approach behave well in worst-case latency terms. In my experience, it’s a dead end if you’re aiming for predictable sub-millisecond masking on arbitrary inputs/grammars.
Invert the problem
Token validity depends on what’s on the stack. Instead of asking “does LLM token t work on this stack?” for each token, ask “given this stack, which LLM tokens work?”
That’s the reframing. Now we need a data structure that turns “stack → allowed tokens” into something we can execute quickly.
That’s where the weighted automaton comes in.
Weighted automata over parser states
Think of a finite automaton, except each transition carries a weight, and traversing paths accumulates weights.
In our weighted automata:
- Input symbols are LR parser state IDs (i.e. elements of the parse stack).
- Weights are bitsets over the LLM vocabulary.
- Traversing a transition applies an intersection (filter tokens).
- Merging paths onto one state applies a union (tokens valid via any path).
The automaton reads (a representation of) your current parse configuration:
- For a single LR stack: the sequence of state IDs from top to bottom.
- For GLR: a GSS, i.e. a DAG representing many possible stacks sharing suffixes.
A token is valid if it’s valid on any parse path in the GSS, so on a GSS you just union the results across paths.
Another way to see it
Another way to see it (which helped me):
- Fix an LLM token t.
- Consider the set of all possible lexes of t (each lex is a terminal sequence).
- For each such lex (terminal sequence), consider all stack configurations from which that sequence can be legally parsed without error.
That set of stacks is a regular language over parser state IDs. If you fix a terminal sequence w, the set of LR stack configurations from which w can be consumed without error can be characterized by a finite-state device over parser states (this is closely related to the classical “viable prefixes are regular” result for LR parsing).
So you can imagine, for each token t, an automaton that recognizes “stacks on which t is valid.”
Of course, if you do that for every token, you’d have 200k automata. That’s useless at runtime.
A weighted automaton is how you smash those 200k membership tests into one run:
- Transitions are annotated with “which tokens’ automata could still accept if we take this transition.”
- Running it once on the stack gives you exactly: “which tokens’ automata would accept this stack.”
Same computation, but ‘vectorized’ across tokens via bitsets.
A nice optimization: stop reading the stack early
In practice you rarely need the full stack. As you scan states, each token’s fate becomes fixed: it either already reaches an accepting state (so it will stay valid) or it has been filtered out forever. Once a token is decided, reading deeper into the stack won’t change anything for it.
So maintain a “decided” set as you go. Peel those tokens off the frontier. When everything left is decided, you can stop immediately—no need to read the rest of the stack.
Runtime sketch
For a single LR stack, the runtime looks like:
def get_mask(stack_state_ids_top_to_bottom):
# frontier: map automaton_state -> bitset(tokens)
frontier = { A.start: ALL_TOKENS }
for sid in stack_state_ids_top_to_bottom:
new = {}
for a_state, tokens in frontier.items():
for (a2, weight) in A.step(a_state, sid):
tokens2 = tokens & weight # ∩ (filter)
if tokens2.any():
new[a2] = new.get(a2, EMPTY) | tokens2 # ∪ (merge)
frontier = new
if decided(frontier):
break
return combine_accepting(frontier)
This is the key operational point:
You are not:
- 🐢 iterating over the vocabulary
- 🐢 traversing a token trie driven by parser simulation
- 🐢 simulating big chains of reductions
- 🐢 reading further down the stack than you need to
You are doing:
- ⚡ a scan over the relevant portion of the stack/GSS, with ops that are mostly
- ⚡ transition-table lookups
- ⚡ bitset intersections
- ⚡ unions on merge
(I actually use range sets rather than bitsets since they’re faster, but the idea is the same.)
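For concreteness, here is one toy way to fill in the pieces the sketch above assumes (`A`, `A.step`, `decided`, `combine_accepting`). Token sets are Python ints used as bitsets, and the helpers take the automaton explicitly; the class name and layout are mine, not the real implementation:

```python
from dataclasses import dataclass, field

@dataclass
class ParserDWA:
    """Toy stand-in for the compiled weighted automaton `A`."""
    start: int = 0
    accepting: frozenset = frozenset()
    # (automaton_state, parser_state_id) -> [(next_automaton_state, token_bitset), ...]
    transitions: dict = field(default_factory=dict)

    def step(self, a_state, parser_state_id):
        return self.transitions.get((a_state, parser_state_id), [])

def combine_accepting(A, frontier):
    # Union of the token sets that ended on accepting automaton states.
    mask = 0
    for a_state, tokens in frontier.items():
        if a_state in A.accepting:
            mask |= tokens
    return mask

def decided(A, frontier):
    # Simplified early-stop test: once every live automaton state is accepting,
    # every surviving token is already known to be valid, so deeper stack
    # symbols can't change the answer. (The real version tracks decided tokens
    # individually and peels them off, as described above.)
    return all(a_state in A.accepting for a_state in frontier)
```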
Runtime on a GSS
For a GSS, you do basically the same computation, but over a graph rather than a single list. Conceptually:
- each edge in the GSS is labeled by a parser state ID,
- each node represents a shared suffix,
- you propagate ∩/∪ through the product of:
- GSS structure (many stacks),
- automaton transitions.
A token is valid if it’s valid on any stack path, so whenever GSS paths merge you union their token sets, and whenever automaton paths merge you union there too.
The important part is: it’s still the same “bitset flows through a graph via ∩ and ∪” pattern. No backtracking, no “try token t” loop.
In practice, GSS nodes need to carry the propagated token sets (and lexer configurations), so merges combine values on nodes as well as on edges.
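Under the same toy representation, the GSS version might look like the sketch below. The GSS encoding is a hypothetical simplification: `tops` are the current stack tops, `edges[node]` lists `(parser_state_id, deeper_node)` pairs pointing toward the bottom, and tops are assumed to have no incoming edges:

```python
from collections import defaultdict, deque

def get_mask_gss(A, tops, edges, ALL_TOKENS):
    """Propagate token bitsets through a GSS: ∩ along edges, ∪ where paths merge."""
    # frontier[node]: automaton_state -> token bitset that reached this node
    frontier = defaultdict(lambda: defaultdict(int))
    indegree = defaultdict(int)
    for node, outs in edges.items():
        for _, deeper in outs:
            indegree[deeper] += 1

    for top in tops:
        frontier[top][A.start] |= ALL_TOKENS

    mask, queue = 0, deque(tops)
    while queue:
        node = queue.popleft()
        for a_state, tokens in frontier[node].items():
            if a_state in A.accepting:          # decided valid: fold into the result
                mask |= tokens
        for sid, deeper in edges.get(node, []):
            for a_state, tokens in frontier[node].items():
                for a2, weight in A.step(a_state, sid):
                    t2 = tokens & weight        # ∩ along the transition
                    if t2:
                        frontier[deeper][a2] |= t2   # ∪ where paths merge
            indegree[deeper] -= 1
            if indegree[deeper] == 0:           # all shallower paths have merged in
                queue.append(deeper)
    return mask
```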
So where does the automaton come from?
Up to now I’ve treated the weighted automaton as a given.
Concretely, it’s a compiled weighted automaton built by composing a lexer-side token→terminal machine with parser-side stack-effect machines.
At a high level, you want an automaton that answers:
given the current lexer+parser state (represented by the stack/GSS), which LLM tokens could lead to a valid continuation?
There are two distinct problems mixed together:
- Lexical: what grammar terminals could a given LLM token’s bytes produce (given where we are in the lexer)?
- Syntactic: if we fed those terminals to the LR parser, would they be legal given the current stack?
The thing we compile is basically a composition of two automata:
- a token→terminal-sequences automaton (“Terminal DWA”),
- a terminal→stack-effect automaton (“Template automaton”),
- composed into a final Parser DWA that reads stack state IDs and outputs valid-token bitsets directly.
1) Token → terminals (Terminal DWA)
An LLM token is a byte string.
A grammar tokenizer/lexer consumes a stream of bytes and emits grammar terminals. Crucially:
- the lexer may be in the middle of recognizing a terminal,
- and longest-match (greedy matching) rules mean a match might be extendable, so incremental lexing can require keeping multiple active lexer configurations (more on this later).
So the mapping “token → terminal sequence” is not a single fixed lookup. It depends on the current lexer state.
The Terminal DWA is the precomputed structure that answers:
from a given lexer configuration, which terminal sequences does each LLM token produce?
A practical way to build it is:
- build a trie over all vocabulary token byte strings,
- simulate the grammar lexer over that trie,
- record (state, terminal, next_state) edges weighted by “which vocabulary tokens realize this edge.”
Then determinize.
Because it’s built over a finite token trie, the Terminal DWA is acyclic: you only move forward along token bytes. The key thing the Terminal DWA buys you is: it collapses “iterate over 200k tokens and run the lexer” into “traverse a small automaton state space and get bitsets.”
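As a sketch of the build (this runs at compile time, so iterating the vocabulary directly is fine; the trie is a sharing optimization that this deliberately naive version skips), assuming a hypothetical lexer DFA interface with `step(state, byte)`, `accepts(state)`, and `initial`:

```python
from collections import defaultdict

def terminal_sequences_for_vocab(vocab, lexer, start_state):
    """For one starting lexer configuration, map each LLM token to the
    terminal sequence its bytes produce (plus the residual lexer state).

    Naive on purpose: no longest-match handling, and it records whole
    sequences per token rather than the shared, determinized edges the
    real Terminal DWA build records over a byte trie of the vocabulary.
    """
    by_sequence = defaultdict(int)   # (terminal tuple, residual lexer state) -> token bitset
    for token_id, byte_string in vocab.items():   # vocab: {token_id: bytes}
        state, terminals, dead = start_state, [], False
        for b in byte_string:
            nxt = lexer.step(state, b)
            if nxt is None:                       # can't extend the current match
                term = lexer.accepts(state)       # close the terminal we were in...
                if term is None:
                    dead = True                   # ...or this token can't lex from here
                    break
                terminals.append(term)
                state = lexer.step(lexer.initial, b)   # re-scan this byte from scratch
                if state is None:
                    dead = True
                    break
            else:
                state = nxt
        if not dead:
            by_sequence[(tuple(terminals), state)] |= 1 << token_id
    return by_sequence
```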
2) Terminal → stack effect (Template automata)
Now suppose the lexer says “the next terminal is a.” What does the LR parser do?
- Sometimes it just shifts a.
- Often it does a chain of reductions first (pop a bunch of states, push goto states), and then shifts a.
- Those reduction chains depend on the top of the stack, but importantly: they depend on a bounded suffix of the stack in well-behaved grammars.
You can precompute, for each terminal a, an automaton that reads stack state IDs (top down) and encodes the stack effects (pop/goto sequences) produced by any legal reduction path that ends in shifting a. These are the “template automata” from Aycock et al., Even Faster Generalized LR Parsing.
I represent those stack effects using the polycyclic monoid: an algebra of stack operations where each push/pop is a generator and composition yields either a net stack effect or zero when operations mismatch. This makes composition tractable because you can concatenate effects and cancel immediately. The key rule is: a push of a state q followed by a pop of the same q cancels (push(q) · pop(q) = 1), while a push of q followed by a pop of a different state r yields zero (push(q) · pop(r) = 0 for r ≠ q). That cancellation is the main simplification.
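A toy encoding of that cancellation, just to show the algebra in action; representing an effect as `(pops, pushes)` is my simplification, not the post's actual representation:

```python
ZERO = None   # the absorbing "illegal composition" element

def compose(effect_a, effect_b):
    """Compose stack effects written as (pops, pushes): pop `pops` (top first),
    then push `pushes` (last element ends up on top).
    Cancellation: push(q) then pop(q) is the identity; push(q) then pop(r)
    with r != q is zero."""
    if effect_a is ZERO or effect_b is ZERO:
        return ZERO
    pops_a, pushes_a = effect_a
    pops_b, pushes_b = effect_b
    pushes_a, pops_b = list(pushes_a), list(pops_b)
    while pushes_a and pops_b:
        if pushes_a[-1] != pops_b[0]:
            return ZERO                 # push(q) . pop(r), r != q  ->  0
        pushes_a.pop()                  # push(q) . pop(q)          ->  1
        pops_b.pop(0)
    # Leftover pops dig deeper into the original stack; leftover pushes stack up.
    return (tuple(pops_a) + tuple(pops_b), tuple(pushes_a) + tuple(pushes_b))

# e.g. "shift into state 7, then a reduction pops 7 and pushes 9":
assert compose(((), (7,)), ((7,), (9,))) == ((), (9,))
assert compose(((), (7,)), ((8,), (9,))) is ZERO   # mismatched pop
```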
So the template automaton for terminal a encodes the possible stack effects (pop/goto sequences) that lead to a legal shift of a while scanning a stack suffix.
With the grammar normalizations described in Aycock et al., these template automata can be made acyclic and therefore bounded in depth.
This is the piece that takes “pointer-heavy GLR reduction simulation” off the runtime path and turns it into precomputed transitions.
3) Compose them into a Parser DWA
Finally, you compose:
- Terminal DWA edges (which say “token t could yield terminal a from a given lexer state”) with
- Template automata (which say “terminal a is legal from a given stack suffix”)
to get one deterministic weighted automaton that:
- takes a lexer configuration + a parser stack suffix as input,
- and outputs the set of LLM tokens that can extend the parse.
This is the automaton you run in get_mask.
Since both the Terminal DWA and template automata are acyclic, their composition—the Parser DWA—is also acyclic (and hence bounded in depth), and this means there’s a hard cap on how far down the stack you ever need to look.
That makes the bounded-suffix story concrete: the composed automaton is a DAG, so it can only read a bounded stack suffix. Two reasons support this: viable prefixes are regular (so an automaton over stack states makes sense), and Aycock et al.’s grammar normalizations make the per-terminal template automata acyclic, which propagates to the composition.
At runtime, the real lexer+parser still updates the stack inside commit; get_mask is only a validity test. That means the automaton doesn’t need to carry full push sequences once they’ve filtered tokens—it only needs enough stack-effect information to decide which tokens remain possible.
Tokenization ambiguity (and longest-match) in streaming generation
Recall that “token” is overloaded: the LLM tokenizer produces subword units from a vocabulary, while the grammar lexer produces terminals. These don’t align.
Even if your lexer uses longest-match, longest-match is inherently forward-looking: you can’t know a match is final until you see what comes next.
Classic example:
"+"followed by"+"should yieldINCREMENT("++"), not twoPLUS.
In a streaming setting, after you see the first "+", you’re in a state where:
- you have a valid match for `PLUS`,
- but it’s extendable to `INCREMENT`.
So you have to represent that uncertainty somehow.
A standard incremental-lexing trick is to treat extendable matches as inhibited terminals:
- when a terminal matches but remains extendable, you fork:
- one branch commits now (but marks it “inhibited”: it will be invalidated if a longer match appears),
- another branch waits for more bytes.
- if later input yields a longer match, the premature branch is pruned.
This plays nicely with GLR+GSS, because “fork and prune” is already the parser’s native move.
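A minimal sketch of the fork step, under the same hypothetical lexer DFA interface as the earlier sketches (`step`, `accepts`, `initial`); the bookkeeping that marks the committed branch as inhibited and prunes it when a longer match materializes is left out:

```python
def advance_lexer_configs(lexer, configs, byte):
    """Advance active lexer configurations by one byte, forking on extendable
    matches. Each config is (lexer_state, emitted_terminals)."""
    out = []
    for state, emitted in configs:
        nxt = lexer.step(state, byte)
        term = lexer.accepts(state)
        if nxt is not None:
            # Branch 1: keep extending toward a longer match (PLUS -> INCREMENT).
            out.append((nxt, emitted))
        if term is not None:
            # Branch 2: commit the match now ("inhibited" if branch 1 also exists),
            # then re-scan this byte from the lexer's start state.
            restart = lexer.step(lexer.initial, byte)
            if restart is not None:
                out.append((restart, emitted + (term,)))
    return out
```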
Practically, the constraint state you carry around is not just “parser stack(s)”; it’s “(parser stack(s), lexer state(s))”. get_mask needs to condition on both, which is why the Parser DWA has multiple initial states (one per active lexer state).
Putting it together
commit(state, token) uses GLR machinery:
- feeds the token’s bytes into the grammar lexer incrementally,
- keeps ambiguity compact with GSS,
- can still succumb to pathological cases, but you at least get decades of research behind it.
get_mask(state) uses the precompiled weighted automaton:
- treats the current GSS as the “input”,
- propagates token bitsets through the automaton with ∩/∪,
- stops early when deeper stack symbols can’t change the result,
- returns a vocabulary mask.
So you get a split where:
- `commit` is responsible for advancing the parser state with a chosen LLM token.
- `get_mask` is responsible for cheaply answering “which LLM tokens could come next?”
Why this feels close to optimal
I’m deliberately avoiding implying there’s a “this is optimal” theorem. That’d be misleading. Parsing highly ambiguous CFGs isn’t a domain where you get many satisfying worst-case optimality results.
But it feels close to optimal for two reasons:
On the commit side: GLR/GSS is a strong local optimum
For incremental CFG parsing with ambiguity, GLR is a pretty hard baseline to beat. There’s a reason people working on these problems keep converging on GLR-like techniques:
- https://tratt.net/laurie/blog/2020/which_parsing_approach.html
- https://marijnhaverbeke.nl/blog/lezer.html
- https://tree-sitter.github.io/tree-sitter/creating-parsers/3-writing-the-grammar.html
- https://www2.eecs.berkeley.edu/Pubs/TechRpts/1997/CSD-97-946.pdf
You still inherit GLR’s lack of comforting worst-case bounds in “fully adversarial ambiguity” settings. In theory, GLR can degrade badly on highly ambiguous grammars/inputs. In practice, for JSON schemas and “programming-language-ish” grammars with reasonable disambiguation, it behaves well.
On the mask side: “read once, never backtrack” is what you want
The weighted-automaton mask computation does the minimum work you’d reasonably hope for:
- scan a bounded portion of the stack once,
- do predictable inner-loop ops (bitset ∩ and ∪),
- stop as soon as the result is determined,
- never iterate over the vocabulary,
- never do speculative parser reductions for hypothetical tokens.
In other words: the runtime work scales with the size of the parse configuration, not with vocabulary size and not with “how gnarly are the tokens.”
Fast run, slow compile
The cost shifts to compile time.
Precompiling the grammar into this Parser DWA involves determinization and simplification in a bitset weights setting (union at merges, intersection along paths). If you’re not careful, large grammars can blow up in memory/time.
Getting compile-time and memory to behave took most of my engineering effort:
- determinization/minimization order matters a lot,
- when you push bitsets around you must be careful about sharing/copying (interning helps a lot here),
- you need aggressive pruning/cancellation,
- you need to normalize certain grammar patterns to keep reduction chains bounded.
But that’s a topic for another time.
If you’re building constrained decoding and you care about p99.9 latency, my main takeaway is this:
Don’t put real parsing work in get_mask.
Keep real parsing in commit, and keep get_mask as a fast, precomputed filter. Everything else is just how much compile-time complexity you’re willing to pay to buy runtime certainty.