Constrained decoding with weighted automata

This is a take on grammar-constrained generation I've been working on for a while.

Grammar constraints: What are they and why does performance matter?

When I say "grammar constraints," I mean the setup where the model only sees tokens that keep the output valid. At each step you compute a mask and zero out anything that would violate a grammar. That's how JSON Schema-guided output works, how you keep SQL generation syntactically sane, and how code completion can be forced to stay inside a language's rules. The payoff is obvious: structured output you can reliably parse without a cleanup pass.
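Concretely, "zero out" here usually means setting the logits of invalid tokens to negative infinity before sampling, so they get probability zero after the softmax. A minimal NumPy sketch (function and variable names are mine, not from any particular library):

```python
import numpy as np

def apply_mask(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Force grammar-invalid tokens to -inf so they can never be sampled.
    `mask` is a boolean array over the vocabulary: True = valid."""
    out = logits.copy()
    out[~mask] = -np.inf
    return out

logits = np.array([1.0, 2.0, 0.5])
mask = np.array([True, False, True])  # token 1 would violate the grammar
masked = apply_mask(logits, mask)     # [1.0, -inf, 0.5]
```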

The catch is that this mask sits directly on the decoding path. It has to run every token, which is why performance becomes the whole game.

LLM inference is batched: multiple user requests run on the GPU at once. And the GPU is kept on a tight schedule.

You might think: if the mask is late, you could sample a token without it. It might work out fine. But what if that token is invalid under the grammar? Well oopsie, now you've broken the constraint, and when downstream code tries to parse your JSON, it won't parse.

If you want to generate 1k tokens per second, you have 1 millisecond to produce each mask: one mask per millisecond, 1000 times a second. And if even a single mask takes 2ms instead of 1, the whole batch stalls and you've wasted expensive GPU time.

The mask is on the critical path, and it's worst-case performance that matters. Not p95 — p99.9.

The two primitive operations

I think of it as two functions:

- commit(token): advance the constraint's internal state once a token has been chosen.
- get_mask(): compute the set of tokens that are valid right now.

At runtime the loop looks like:

loop:
  parallel:
    logits = llm.step()          # GPU
    mask = constraint.get_mask() # CPU

  logits = apply_mask(logits, mask)
  token = sample(logits)
  llm.commit(token)
  constraint.commit(token)

Obviously I am simplifying a bit here—tokenizer state matters too, but this is the basic idea.
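The constraint side of the loop above can be captured in a tiny interface. Here is a toy implementation (the grammar, vocabulary, and class name are all invented for illustration; a real constraint would hold parser state):

```python
class DigitsOnly:
    """Toy constraint: every sampled token must be a digit string.
    Stands in for a real grammar-backed constraint object."""

    def __init__(self, vocab: list[str]):
        self.vocab = vocab

    def get_mask(self) -> list[bool]:
        """Boolean mask over the vocabulary: True = token keeps the
        output valid. Must be fast; it runs on every decoding step."""
        return [tok.isdigit() for tok in self.vocab]

    def commit(self, token: int) -> None:
        """Advance internal state after a token is sampled. Can run
        while the GPU computes the next logits."""
        pass  # a real constraint would update its parser state here

c = DigitsOnly(["7", "x", "42", ";"])
```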

Commit == incremental parsing

Commit is just incremental generalized LR (GLR) parsing. The parser receives consecutive chunks of input and updates its internal state as they come. Ambiguity is represented in a graph-structured stack (GSS), which keeps the state compact.
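A graph-structured stack can be pictured as a DAG of stack nodes, where ambiguous parses share the stack tails they agree on. A toy sketch of just the sharing idea (illustrative only; real GLR implementations layer shift/reduce machinery on top of this):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class GssNode:
    """One node of a graph-structured stack: an LR state plus links to
    the possible stack tails beneath it. Ambiguous parses share tails,
    which is what keeps the representation compact."""
    state: int
    parents: tuple["GssNode", ...] = ()

# two ambiguous parse stacks that agree below their tops share one tail
root = GssNode(0)
tail = GssNode(3, (root,))
top_a = GssNode(5, (tail,))   # parse 1: states 0 -> 3 -> 5
top_b = GssNode(7, (tail,))   # parse 2: states 0 -> 3 -> 7
```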

Computing the mask

You have a GSS encoding the current parse state, and you need a mask over LLM tokens. How do you compute it?

A naive implementation might check each token by feeding its terminal sequence (or sequences) into the GSS and seeing whether it can be accepted. But this has... issues.

Fun fact: The longest non-space token in c200k (GPT-5's tokenizer) is ----------------------------------------------------------------------------------------------------------------. That's 112 dashes, for those counting. And in JavaScript, this is perfectly valid input. To make matters worse, since both - and -- are valid terminals (the latter being the decrement operator), the number of ways to tokenize it is the 113th Fibonacci number—about 10^23.
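You can sanity-check that count with a two-term recurrence: a run of n dashes ends in either a final `-` or a final `--`, so ways(n) = ways(n-1) + ways(n-2), which is exactly the Fibonacci recurrence. A quick sketch (function name is mine):

```python
def dash_tokenizations(n: int) -> int:
    """Number of ways to split a run of n dashes into the terminals
    '-' and '--'. ways(n) = ways(n-1) + ways(n-2), i.e. Fibonacci."""
    a, b = 1, 1  # ways(0), ways(1)
    for _ in range(n - 1):
        a, b = b, a + b
    return b

dash_tokenizations(112)  # ≈ 1.8e23
```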

But wait, there's more. Hundreds of thousands more. As its name suggests, c200k consists of around 200k tokens. They're not all as bad as our 112-dash friend above, but it's a lot.

To make this sane, most constraint libraries build a trie or trellis representing all possible tokenizations of all LLM tokens in a compact way. To distinguish between LLM tokens, each edge in the trellis can be conditioned on a set of LLM tokens.
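As a toy illustration of the sharing this buys you, here is a plain character trie where each node records the set of token ids passing through it, so one walk decides validity for whole groups of tokens at once. (Real libraries use compressed trellises and bitsets; all names here are mine.)

```python
def build_trie(vocab: dict[int, str]) -> dict:
    """Character trie over the vocabulary. Each node stores the token
    ids whose spelling passes through it, so tokens sharing a prefix
    share the work of checking that prefix."""
    trie = {"children": {}, "tokens": set()}
    for tid, text in vocab.items():
        node = trie
        for ch in text:
            node = node["children"].setdefault(
                ch, {"children": {}, "tokens": set()}
            )
            node["tokens"].add(tid)
    return trie

vocab = {0: "--", 1: "-", 2: "-a"}
trie = build_trie(vocab)
# all three tokens share the first '-' edge
```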

Unfortunately, even that's rarely enough. Each token can invoke many, many reductions before it shifts. So, even if you're only processing 300 tokens in a single mask step, that can sum up to thousands of GSS operations, which aren't cheap. Plus the GSS fragments over time, so you have to simplify it continually, which costs even more.

All this is to say that the naive approach of iterating over the vocabulary and all possible tokenizations is a non-starter.

I've tried to make this route fast, and I keep coming back to the same conclusion: running the GLR parser over a trellis for mask generation is a dead end.

Weighted automata for the mask

Token validity is a function of the stack. Instead of feeding terminals into the GSS per token, read the stack directly and filter tokens in one pass.

A weighted automaton is like a finite automaton, but each transition carries a weight that accumulates as you traverse. Here the weights are bitsets of tokens. Traversing a transition intersects the token set; merging paths onto one state takes their union.

Weighted automata are well-studied, with known determinization and minimization strategies.

live = ALL_TOKENS          # tokens not yet decided either way
accepted = EMPTY           # tokens proven valid so far
state = start
for sym in stack_top_to_bottom:
  live &= transition_bits[state, sym]    # intersect along the transition
  state = next_state[state, sym]
  accepted |= accept_bits[state] & live  # tokens decided valid at this depth
  live &= ~accepted                      # stop tracking decided tokens
  if live is empty:
    break                                # the mask is fully determined
return accepted
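To make the semantics concrete, here is the same loop as runnable Python, using plain integers as bitsets over a toy 4-token vocabulary. All tables are invented for illustration:

```python
ALL = 0b1111  # toy vocabulary of 4 tokens, one bit per token

# made-up tables for a 3-state automaton over stack symbols "A" and "B"
transition_bits = {(0, "B"): 0b1011, (1, "A"): 0b1110}
next_state = {(0, "B"): 1, (1, "A"): 2}
accept_bits = {0: 0b0000, 1: 0b0001, 2: 0b1110}

def mask_from_stack(stack_top_to_bottom: list[str]) -> int:
    live, accepted, state = ALL, 0, 0
    for sym in stack_top_to_bottom:
        live &= transition_bits[(state, sym)]    # intersect along transition
        state = next_state[(state, sym)]
        accepted |= accept_bits[state] & live    # tokens decided at this depth
        live &= ~accepted                        # stop tracking decided tokens
        if not live:
            break                                # mask fully determined
    return accepted

mask_from_stack(["B", "A"])  # 0b1011: tokens 0, 1 and 3 are valid
```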

Why this feels close to optimal

On the commit side, GLR-style techniques are a strong fit for fast incremental context-free grammar (CFG) parsing. They use table-driven decisions and represent ambiguity with a compact GSS. This is the family of techniques behind incremental parsers like tree-sitter. If you want a parser that updates quickly as new tokens arrive, this design is a natural place to start.

On the mask side, the bitset-weighted automaton does close to the minimum work you can reasonably hope for: it reads the stack once, top down, never backtracks, and reads only as deep as needed to decide the mask. The computation is driven purely by the information the stack actually contains, and it terminates as soon as that information is sufficient.

The hard part in practice

This is the part that took most of the engineering time for me. The idea is simple; getting it to behave on large grammars took a bag of tricks to avoid exponential blowups—especially around determinization and minimization—and to get compile times down.

Maybe I'll write about that next.