Speculative Decoding: 2x to 4x speedup of LLMs without quality loss

TL;DR

Speculative decoding uses a small draft model or heuristic to propose several tokens ahead, and the large LLM checks all of them in a single parallel pass, keeping the ones it agrees with and correcting the first one it does not. This parallelism makes generation much faster than token-by-token decoding and, unlike many other optimizations, it does not sacrifice output quality.

The Problem

Large language models are auto-regressive: they generate one token at a time, feeding all past tokens back in as input at each step, which is inherently slow. The main reason is not that the mathematical computations (the FLOPs) for predicting a single token are overwhelming. Instead, the critical bottleneck is memory bandwidth.

Modern LLMs can occupy hundreds of gigabytes or even terabytes of memory. These parameters typically reside in High-Bandwidth Memory (HBM) associated with the accelerator (like a GPU or TPU). However, to perform the computations needed to predict the next token, these parameters must be loaded into the much smaller, faster on-chip memory (cache or SRAM) of the compute units.

In autoregressive decoding, this massive data transfer - loading potentially terabytes of weights from HBM to cache - happens for every single token generated. This memory movement completely dominates the time taken for each step. The powerful compute cores of the accelerator end up spending most of their time waiting for data to arrive, leading to:

  • Severe underutilization of compute resources
  • Low arithmetic intensity (the ratio of computational operations to memory access)

[Figure: Performance Comparison]

The LLM decoding operation is memory-bound as described above, and hence making compute units faster will not make the inference faster proportionally. Any truly effective acceleration strategy for the decoding phase must find a way to reduce the number of these costly memory-transfer cycles required per generated token. Hence, we would need to devise a way to parallelize an inherently sequential process.
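A back-of-the-envelope estimate makes the imbalance concrete. The sketch below assumes a hypothetical dense 7B-parameter model stored in 16-bit weights, running on an accelerator with 1 TB/s of memory bandwidth and 100 TFLOP/s of peak compute. These numbers and the function name are illustrative assumptions, not measurements, and the estimate ignores the KV cache and activations.

```python
def decode_step_estimate(n_params, bytes_per_param, mem_bandwidth, peak_flops,
                         tokens_per_pass=1):
    """Rough lower bounds (in seconds) for one forward pass of a dense model."""
    bytes_moved = n_params * bytes_per_param   # every weight is read once per pass
    flops = 2 * n_params * tokens_per_pass     # about 2 FLOPs per weight per token
    return {
        "memory_s": bytes_moved / mem_bandwidth,
        "compute_s": flops / peak_flops,
        "arithmetic_intensity": flops / bytes_moved,  # FLOPs per byte moved
    }


# Assumed hardware and model numbers (hypothetical, for illustration only)
one_token = decode_step_estimate(7e9, 2, 1e12, 1e14)
five_tokens = decode_step_estimate(7e9, 2, 1e12, 1e14, tokens_per_pass=5)

print(one_token)
print(five_tokens)
```

Under these assumptions the weight transfer takes about 100 times longer than the arithmetic for a single token, and scoring five tokens in the same pass leaves the memory time unchanged while raising the arithmetic intensity fivefold.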

The Solution

Speculative Decoding (SD) does exactly that - it speeds up the sequential decoding by generating multiple tokens for roughly the cost of one target model memory load cycle, and as a result, it significantly improves the efficiency of the decoding process.

It uses a draft-then-verify approach:

  • Drafting produces multiple next tokens using a faster/smaller model
  • Verification checks these tokens in parallel in the target/large LM

The key players:

  • The Target Model (M_q): This is the large, powerful, and accurate LLM whose output distribution we want to reproduce exactly, but which is very slow to run.
  • The Draft Model (M_p): This is a significantly smaller, and therefore much faster, language model or n-gram model. It might be a smaller version from the same model family, a distilled version, or even a different architecture altogether. Its role is not to be perfectly accurate, but to quickly generate plausible candidate tokens (the “draft”). Note that this post writes q for the target and p for the draft, which is the reverse of the notation in Leviathan et al.

The core idea relies on two key observations:

  • Many tokens are “easy”: Not all tokens in a generated response need the full power of a huge LLM. Common phrases, grammatical structures, or repetitive sequences can be generated correctly by a much smaller draft model with high confidence. The large model is only really needed for the “hard” tokens, such as those that require reasoning or factual recall.

  • Parallel Verification is Cheap: While generating one token with the large target model M_q is slow due to the memory bandwidth bottleneck, verifying multiple candidate tokens (say, K draft tokens plus the original context) in a single, parallel forward pass through M_q takes roughly the same amount of time as generating one. This is because the dominant cost, loading the model weights, is paid only once for the entire verification pass.

Speculative decoding leverages these observations. If the draft model’s predictions align well with what the target model would have generated, multiple tokens can be accepted in a single step, effectively amortizing the cost of the expensive target model inference over several tokens.
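To get a feel for how much can be amortized, here is a minimal sketch of the expected number of tokens produced per target-model pass. It relies on a simplifying assumption taken from the analysis in Leviathan et al.: every draft token is accepted independently with the same probability `alpha`. The count includes the one token the target model always contributes itself. The function name is hypothetical.

```python
def expected_tokens_per_step(alpha, k):
    """Expected tokens emitted per target pass with k draft tokens.

    alpha: assumed per-token acceptance probability (i.i.d. simplification).
    The result is 1 + alpha + alpha**2 + ... + alpha**k.
    """
    if alpha >= 1.0:
        return float(k + 1)  # every draft token accepted, plus one from the target
    return (1 - alpha ** (k + 1)) / (1 - alpha)


for alpha in (0.5, 0.8, 0.9):
    print(alpha, round(expected_tokens_per_step(alpha, k=4), 3))
```

Even a perfect draft cannot yield more than K+1 tokens per pass, and a draft that is never accepted still yields one token, so the target model always makes progress.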

[Figure: Draft vs Target Model]

The drafting part is straightforward: you just generate K tokens from the draft model auto-regressively. The verification part works for transformers because a single forward pass computes the next-token logits at every position of a sequence at once, which sidesteps the step-by-step dependency of auto-regressive decoding. Let’s look at some familiar code to clarify:

Auto-Regressive Decoding (Single Token)

from transformers import AutoModelForCausalLM, AutoTokenizer

# Setup model and tokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Prepare input context
context = "The cat sat"
input_ids = tokenizer(context, return_tensors="pt").input_ids

# Generate single token prediction
outputs = model(input_ids)
next_token_logits = outputs.logits[:, -1, :] # Logits for the next token

Speculative Decoding (Draft + Verify)

# Step 1: Draft Model - Generate candidate tokens
drafted_tokens = ["on", "the", "mat"]
drafted_sequence = context + " " + " ".join(drafted_tokens)

# Step 2: Prepare full sequence for verification
full_input_ids = tokenizer(drafted_sequence, return_tensors="pt").input_ids

# Step 3: Verify with target model (single forward pass)
outputs = model(full_input_ids)
logits = outputs.logits # Shape: [batch_size, seq_len, vocab_size]

# Step 4: Extract verification probabilities for each drafted token
draft_start = input_ids.shape[1] # First drafted token position

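for i, token in enumerate(drafted_tokens):
    position = draft_start + i - 1 # logits at position predict next token
    # Read the id from the tokenized sequence: GPT-2 folds the leading space
    # into the token (" on" -> "Ġon"), so looking up the bare string "on" would
    # return a different id. Assumes each drafted word is a single token, as here.
    token_id = full_input_ids[0, draft_start + i].item()

    # Probability the verifier would assign to the draft token
    # (softmax over the whole vocabulary first, then pick the token)
    prob = logits[0, position, :].softmax(dim=-1)[token_id].item()
    print(f"Verifier probability for '{token}': {prob:.4f}")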

As you can see, we can directly get the assigned probabilities for each token at once, and hence compute the overall probability of the sequence in a single pass.

The last step is the token-by-token acceptance/rejection that decides which draft tokens to keep. Following the idea above, we take the verifier logits at each drafted position and check whether the drafted token is the verifier’s top-1 (argmax) token, stopping at the first mismatch. This is the simplest acceptance method, called top-1 verification, and it reproduces the target model’s greedy decoding. When sampling with a temperature, we would instead use the general speculative sampling algorithm, a rejection-sampling scheme with a Metropolis-Hastings-style acceptance criterion.

# Top-1 Verification Implementation

# Setup for token verification
accepted_tokens = []
draft_start = input_ids.shape[1] # First position where drafts start

# Step 1: Iterate through each drafted token for verification
for i, token in enumerate(drafted_tokens):
    # Position in logits where the verifier predicts the NEXT token
    position = draft_start + i - 1 # position -1 because logits predict the next token
    
    # Step 2: Get verifier's top prediction
    verifier_top_token_id = logits[0, position, :].argmax().item()
    verifier_top_token = tokenizer.convert_ids_to_tokens([verifier_top_token_id])[0]

    # Step 3: Get drafted token ID for comparison (read from the tokenized
    # sequence so GPT-2's leading-space tokens such as "Ġon" are matched correctly)
    drafted_token_id = full_input_ids[0, draft_start + i].item()

    # Step 4: Accept/reject based on exact match with top prediction
    if verifier_top_token_id == drafted_token_id:
        accepted_tokens.append(token)
        print(f"ACCEPTED token '{token}' at position {i+1}")
    else:
        print(f"REJECTED token '{token}' at position {i+1}")
        break # Stop at first mismatch
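One detail the loop above leaves implicit: at the first mismatch, the verifier’s own top token is a valid next token, so every verification pass yields at least one new token. A minimal, framework-free sketch of that logic follows. The function name and the plain-list inputs are assumptions made for illustration; `target_top_ids` stands for the argmax ids read off the single verification pass.

```python
def greedy_verify(draft_ids, target_top_ids):
    """Top-1 verification over K drafted token ids.

    draft_ids:      the K token ids proposed by the draft model.
    target_top_ids: the target model's argmax id at each of those K positions.
    Returns (accepted_ids, correction_id). correction_id is the target's token
    at the first mismatch, or None when every draft token matched.
    """
    accepted = []
    for drafted, top in zip(draft_ids, target_top_ids):
        if drafted != top:
            return accepted, top   # keep the prefix, substitute the target's choice
        accepted.append(drafted)
    return accepted, None


print(greedy_verify([5, 7, 9], [5, 7, 2]))  # third draft token is rejected
print(greedy_verify([5, 7, 9], [5, 7, 9]))  # all draft tokens accepted
```

Predictions after the first mismatch are discarded because they were conditioned on a token that did not make it into the output.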

Rejection Sampling Approach

Instead of just checking “is the token the same as the verifier’s top prediction?”, this approach accepts a draft token with some probability even when the draft and verifier disagree, where that probability depends on how much they disagree (probabilistic acceptance).

Metropolis-Hastings Acceptance Criteria

  1. For each drafted token, get the draft probability “p_t” and the verifier probability “q_t” of that token (probabilities after softmax, not raw logits)
  2. Pick a fresh random number “r” from the uniform distribution [0,1]
  3. Accept the draft token if r < min(1, q_t/p_t). If the verifier finds the token at least as likely as the draft did (q_t >= p_t), it is always accepted; the more the draft over-estimated the token relative to the verifier, the lower the chance of acceptance. This is where we see how much the models agree with each other.

The following code illustrates the approach, along with how we then correct the remaining draft tokens if a token is rejected before reaching the end of current draft tokens.

The correction happens by sampling from the residual distribution max(0, q(x) - p(x)), renormalized to sum to 1. Intuitively, this removes every token whose draft probability is at least as high as its verifier probability, and keeps only the tokens the draft model under-proposed.

For example, if -

Draft model p(x) says:

  • “cat”: 0.4
  • “dog”: 0.3
  • “mouse”: 0.3

Verifier q(x) says:

  • “cat”: 0.2
  • “dog”: 0.5
  • “mouse”: 0.3

Draft model picked “cat”, but it got rejected. Now, we calculate residuals:

  • residual(cat) = max(0, 0.2 - 0.4) = 0
  • residual(dog) = max(0, 0.5 - 0.3) = 0.2
  • residual(mouse) = max(0, 0.3 - 0.3) = 0

In this case, “dog” remains the only option. Sampling the replacement from the residual rather than from q(x) is what keeps the procedure lossless: “cat” already received its share of probability through the acceptance step, so drawing from q(x) again would over-sample it and bias the output away from the verifier’s distribution. There is one edge case: the residual is zero everywhere only when p(x) and q(x) are identical, in which case the draft token is always accepted and no correction is needed. Floating-point error can still produce an all-zero residual, so an implementation should fall back to sampling from the verifier’s full distribution q(x), as if no draft proposal had happened.
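The worked example can be checked numerically. The sketch below takes p as the draft distribution and q as the verifier distribution, accepts a drafted token x with probability min(1, q(x)/p(x)), and on rejection samples from the normalized residual. It then computes the distribution of the token that is finally emitted, which should come out equal to q. The function names are hypothetical, and the toy example assumes p(x) > 0 for every token.

```python
def accept_prob(q_x, p_x):
    """Probability of keeping a drafted token (q = verifier, p = draft)."""
    return min(1.0, q_x / p_x)


def residual_distribution(q, p):
    """Normalized max(0, q - p); falls back to q if nothing is left over."""
    raw = {tok: max(0.0, q[tok] - p[tok]) for tok in q}
    total = sum(raw.values())
    if total == 0.0:
        return dict(q)
    return {tok: value / total for tok, value in raw.items()}


def output_distribution(q, p):
    """Distribution of the emitted token: accepted draft or residual sample."""
    accepted = {tok: p[tok] * accept_prob(q[tok], p[tok]) for tok in p}
    reject_mass = 1.0 - sum(accepted.values())
    residual = residual_distribution(q, p)
    return {tok: accepted[tok] + reject_mass * residual[tok] for tok in p}


p = {"cat": 0.4, "dog": 0.3, "mouse": 0.3}  # draft model
q = {"cat": 0.2, "dog": 0.5, "mouse": 0.3}  # verifier

out = output_distribution(q, p)
print(out)
```

Up to floating-point rounding, the printed distribution matches q: “cat” is kept only half of the time it is drafted, and all of the rejected mass flows to “dog”.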

import random
import torch

# Setup for Metropolis-Hastings-style acceptance.
# Assumes `draft_dists` is a [K, vocab_size] tensor holding the draft model's
# probability distribution at each drafted position (record it while drafting).
accepted_tokens = []
draft_start = input_ids.shape[1]

# Step 1: Process drafted tokens with probabilistic acceptance
for i, token in enumerate(drafted_tokens):
    position = draft_start + i - 1

    # Step 2: Get probability distributions from both models
    verifier_probs = logits[0, position, :].softmax(dim=-1)
    draft_token_id = full_input_ids[0, draft_start + i].item()

    # Extract specific token probabilities
    q_t = verifier_probs[draft_token_id].item() # verifier prob
    p_t = draft_dists[i, draft_token_id].item() # draft prob (> 0, the draft sampled this token)

    # Step 3: Apply the acceptance rule: accept with probability min(1, q_t / p_t)
    r = random.uniform(0, 1)
    acceptance_threshold = min(1, q_t / p_t)

    if r < acceptance_threshold:
        accepted_tokens.append(token)
        print(f"ACCEPTED token '{token}' with r={r:.4f}, threshold={acceptance_threshold:.4f}")
    else:
        print(f"REJECTED token '{token}' with r={r:.4f}, threshold={acceptance_threshold:.4f}")
        break # Stop at first rejection

# Step 4: Handle rejection with residual sampling
if len(accepted_tokens) < len(drafted_tokens):
    rejected = len(accepted_tokens) # index of the rejected draft token
    # verifier_probs still holds the verifier distribution at the rejected position

    # Calculate and normalize residual distribution
    residual = (verifier_probs - draft_dists[rejected]).clamp(min=0)
    if residual.sum() > 0:
        residual = residual / residual.sum() # Normalize
    else:
        residual = verifier_probs # Numerical edge case: fall back to q(x)

    # Sample from residual distribution
    next_token_id = torch.multinomial(residual, 1).item()
    next_token = tokenizer.convert_ids_to_tokens([next_token_id])[0]
    accepted_tokens.append(next_token)

print("Final accepted sequence:", accepted_tokens)

Bonus Token

If all K draft tokens are accepted, indicating that the draft and target models are aligned, we can sample one additional token from the target model. The verification pass has already computed the target’s next-token distribution at the position after the last draft token, so we get K+1 tokens from a single target pass: a bonus token for free, further maximizing efficiency.
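Putting the pieces together, one full speculative step (draft, verify, accept or correct, bonus token) can be sketched in plain Python. This is a toy illustration under stated assumptions, not a production implementation: `draft_dist` and `target_dist` are hypothetical callables that map a list of tokens to a `{token: probability}` dict, and the target is queried in a loop only for readability, where a real system would obtain all K+1 distributions from one forward pass.

```python
import random


def sample(dist, rng):
    """Draw one token from a {token: weight} dict."""
    tokens = list(dist)
    return rng.choices(tokens, weights=[dist[t] for t in tokens], k=1)[0]


def speculative_step(context, draft_dist, target_dist, k, rng):
    # 1. Draft k tokens autoregressively, remembering each draft distribution.
    drafted, draft_dists = [], []
    for _ in range(k):
        p = draft_dist(context + drafted)
        drafted.append(sample(p, rng))
        draft_dists.append(p)

    # 2. Score every prefix with the target (one batched pass in practice).
    target_dists = [target_dist(context + drafted[:i]) for i in range(k + 1)]

    # 3. Accept or reject left to right.
    out = []
    for i, tok in enumerate(drafted):
        p, q = draft_dists[i], target_dists[i]
        if rng.random() < min(1.0, q.get(tok, 0.0) / p[tok]):
            out.append(tok)
            continue
        # Rejected: emit a correction from the residual max(0, q - p), then stop.
        residual = {t: max(0.0, q[t] - p.get(t, 0.0)) for t in q}
        if sum(residual.values()) == 0.0:
            residual = q  # numerical edge case
        out.append(sample(residual, rng))
        return out

    # 4. Everything accepted: take the bonus token from the last distribution.
    out.append(sample(target_dists[k], rng))
    return out


def coin(tokens):      # toy model: ignores the context
    return {"a": 0.5, "b": 0.5}


def always_a(tokens):  # toy model: always predicts "a"
    return {"a": 1.0}


rng = random.Random(0)
print(speculative_step([], coin, coin, k=3, rng=rng))      # perfect draft
print(speculative_step([], coin, always_a, k=3, rng=rng))  # imperfect draft
```

With a draft identical to the target, every step returns K+1 tokens. With the imperfect draft, every emitted token is still one the target would produce, and each step returns at least one token.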

KV Cache Efficiency

A critical aspect for efficiency, especially during the parallel verification step, is managing the Key-Value (KV) cache. Standard autoregressive decoding maintains a cache of activations from previous steps to avoid recomputing them. In speculative decoding, the target model processes multiple potential future states simultaneously. A naive implementation might require replicating the KV cache for each potential path, leading to excessive memory usage.

Optimized inference engines address this challenge. Techniques like PagedAttention allow sharing parts of the KV cache between different sequences in the verification batch. This works similarly to virtual memory paging in operating systems, dividing the cache into blocks that can be shared and managed efficiently, preventing memory bloat and maintaining throughput even at larger batch sizes.
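The block-sharing idea can be illustrated with a toy block table. The class below is a hypothetical sketch written for this post and is not the API of any inference engine: it only tracks which fixed-size blocks each sequence points to, shares them by reference count when a sequence is forked, and copies a partially filled block on the first write.

```python
class PagedKVCache:
    """Toy block table: sequences point at fixed-size blocks shared by refcount."""

    def __init__(self, block_size):
        self.block_size = block_size
        self.ref_counts = {}   # block id -> number of sequences using it
        self.tables = {}       # sequence id -> list of block ids
        self.lengths = {}      # sequence id -> number of cached tokens
        self._next_block = 0

    def _new_block(self):
        block = self._next_block
        self._next_block += 1
        self.ref_counts[block] = 1
        return block

    def append(self, seq, n_tokens=1):
        table = self.tables.setdefault(seq, [])
        length = self.lengths.get(seq, 0)
        for _ in range(n_tokens):
            if length % self.block_size == 0:
                table.append(self._new_block())      # last block is full
            elif self.ref_counts[table[-1]] > 1:
                self.ref_counts[table[-1]] -= 1      # copy-on-write
                table[-1] = self._new_block()
            length += 1
        self.lengths[seq] = length

    def fork(self, parent, child):
        """Child starts out sharing every block of the parent."""
        self.tables[child] = list(self.tables[parent])
        self.lengths[child] = self.lengths[parent]
        for block in self.tables[child]:
            self.ref_counts[block] += 1

    def blocks_in_use(self):
        return sum(1 for count in self.ref_counts.values() if count > 0)


cache = PagedKVCache(block_size=4)
cache.append("context", n_tokens=10)  # 3 blocks
cache.fork("context", "draft")        # shares all 3 blocks
cache.append("draft", n_tokens=3)     # copies the partial block, adds one more
print(cache.tables, cache.blocks_in_use())
```

In this toy run the two sequences occupy 5 blocks in total instead of the 7 they would need with fully separate caches, because the full prefix blocks are stored once.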

Choosing Draft and Target Models

Usually draft models can be:

  • A smaller model in the same LLM family (Gemma 2B -> Gemma 9B). Additionally, the smaller models can be further aligned through knowledge distillation.

  • N-grams: Proposals are drawn by matching the most recent tokens (an n-gram up to a chosen size) against earlier text in the prompt and proposing the tokens that followed the match, effectively reusing previously seen text patterns.

  • Custom MLP-based speculator networks: These condition on both the context vectors and the sampled tokens to predict future tokens, enabling learned proposal distributions beyond simple draft-model outputs.

  • EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency): A lightweight draft head extrapolates the target model’s own hidden features to propose continuations, which the target model then verifies, giving efficient, lossless sampling.
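Of these options, the n-gram drafter is simple enough to sketch in a few lines. The version below is a minimal illustration with hypothetical names and defaults: it looks for the longest trailing n-gram that also occurs earlier in the token sequence and proposes the tokens that followed that earlier occurrence.

```python
def ngram_draft(token_ids, max_ngram=3, k=4):
    """Propose up to k tokens by matching the trailing n-gram earlier in token_ids."""
    for n in range(max_ngram, 0, -1):        # prefer the longest match
        if len(token_ids) <= n:
            continue
        suffix = token_ids[-n:]
        # Scan right to left so the most recent earlier occurrence wins.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == suffix:
                return token_ids[start + n:start + n + k]
    return []                                # no match: nothing to propose


history = [1, 2, 3, 4, 5, 9, 1, 2, 3]
print(ngram_draft(history))      # the trailing [1, 2, 3] also appears at the start
print(ngram_draft([1, 2, 3]))    # no repeated pattern
```

This kind of drafter costs almost nothing to run, and it tends to work best when the output copies spans of the input, as in summarization, code editing, or retrieval-augmented prompts. When it finds no match, decoding simply falls back to one target token per step.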

Key Hyperparameters

  • K (number of draft tokens):

    • Number of tokens proposed by the draft model per iteration
    • Important because it affects the balance of more accepted tokens vs. higher draft cost & rejection probability
    • Optimal often 3-5. Depends on draft speed/accuracy
  • Draft Model Size/Arch:

    • Choice of the smaller model (parameters, architecture)
    • Directly impacts draft latency (critical) and baseline acceptance rate
    • Aim for low latency. 10-20x smaller than target or use specialized architectures
  • Draft Model Alignment:

    • How well the draft model predicts the target model’s distribution
    • Primary driver of acceptance rate. Higher alignment = higher acceptance = more speedup
    • Use models from the same family, or apply knowledge distillation
  • Temperature:

    • Controls randomness during sampling
    • Higher temp -> lower predictability -> lower acceptance rate -> lower speedup
    • Performance often peaks at low/mid temps. Align KD temp with inference temp
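The interaction between K, draft cost, and acceptance rate can be explored with the expected-speedup expression from Leviathan et al. The sketch below rests on simplifying assumptions: every draft token is accepted independently with probability `alpha`, one draft step costs a fraction `c` of one target pass, and verifying K tokens costs the same as one target pass. Function names and the example values are hypothetical.

```python
def expected_speedup(alpha, k, c):
    """Expected wall-clock speedup over plain autoregressive decoding.

    alpha: assumed per-token acceptance probability (i.i.d. simplification)
    k:     number of draft tokens per step
    c:     cost of one draft step relative to one target pass
    """
    if alpha >= 1.0:
        tokens = k + 1
    else:
        tokens = (1 - alpha ** (k + 1)) / (1 - alpha)
    return tokens / (k * c + 1)   # k draft steps plus one target pass per step


def best_k(alpha, c, max_k=16):
    """K in 1..max_k with the highest expected speedup."""
    return max(range(1, max_k + 1), key=lambda k: expected_speedup(alpha, k, c))


for k in range(1, 7):
    print(k, round(expected_speedup(alpha=0.6, k=k, c=0.1), 3))
print("best K:", best_k(alpha=0.6, c=0.1))
```

With these example values the speedup peaks at a small K and then declines, because each extra draft token costs time but is increasingly unlikely to be accepted. A more accurate or cheaper draft model pushes the optimum toward larger K.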

[Figure: Speculative Decoding Process]

References

  • Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding. ICML.
  • Chen, X., et al. (2023). Accelerating Large Language Model Decoding with Speculative Sampling. arXiv preprint.
  • Spector, B., & Murray, K. (2023). Accelerating LLM Inference with Staged Speculative Decoding. NeurIPS.