MEDUSA: Making Large Language Models Generate Text Faster

Isaac Kargar
Jan 1, 2025

Introduction

Large Language Models (LLMs) are powerful but slow at inference because they generate text autoregressively: every new token requires another full forward pass through the model. Traditional solutions like speculative decoding use a smaller draft model to propose tokens and the larger model to verify them. MEDUSA takes a different approach. Instead of using two models, it adds extra prediction heads on top of the main model so that several future tokens can be proposed from a single forward pass. Think of these heads as specialized predictors that work together with the main model, making generation faster while maintaining quality.

Training

MEDUSA’s training process is fascinating in its simplicity. Imagine teaching multiple students to complete sentences, but each student specializes in predicting words at different positions. That’s essentially what MEDUSA does with its heads.

The model starts with a base LLM and adds several prediction heads on top. Each head is a small feed-forward layer that reads the base model's last hidden state, and each has a specific job. While the main model predicts the next token, the first MEDUSA head learns to predict the token two positions ahead, the second head three positions ahead, and so on. For example, given “2025 would be,” the main model learns to predict “a,” the first MEDUSA head learns “great,” the second “year,” and the third “for.”
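The label alignment can be sketched in a few lines of Python. This is a minimal illustration under simplifying assumptions, not the paper's training code. `head_targets` is a hypothetical helper, and whole words stand in for tokens.

```python
def head_targets(tokens, position, num_heads):
    """Return the training labels seen at `position` (hypothetical helper).

    Index 0 is the base LM head's label (the next token); index k is the
    label for MEDUSA head k, which looks k + 1 tokens ahead.
    """
    labels = []
    for offset in range(1, num_heads + 2):  # base head + num_heads MEDUSA heads
        target = position + offset
        # Near the end of the sequence some heads have no label to learn from.
        labels.append(tokens[target] if target < len(tokens) else None)
    return labels


tokens = ["2025", "would", "be", "a", "great", "year", "for"]
# Labels at the position of "be" (index 2) for the base head and 3 MEDUSA heads.
print(head_targets(tokens, position=2, num_heads=3))  # ['a', 'great', 'year', 'for']
```

Every head is trained on the same sequence. Only the offset between the input position and the label changes.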

MEDUSA offers two training approaches. The simpler version, MEDUSA-1, keeps the base model frozen and only trains the new heads. It’s like keeping an experienced teacher unchanged while training new assistant teachers. The more advanced version, MEDUSA-2, trains everything together, allowing the base model and heads to learn from each other.
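Here is a rough sketch of how the two variants differ. It assumes the per-head losses have already been combined into a single number. The MEDUSA-2 form shown here follows the general idea of keeping the backbone's usual next-token loss and adding a down-weighted heads' loss. `training_objective`, the parameter-group labels, and the default `lambda_0` are hypothetical placeholders.

```python
def training_objective(mode, lm_loss, heads_loss, lambda_0=0.2):
    """Return (loss, trainable parameter groups) for a training mode.

    Hypothetical helper: `lambda_0` is an illustrative value, and the
    parameter-group names are labels, not real module names.
    """
    if mode == "medusa-1":
        # Backbone frozen: only the heads' loss matters and only heads update.
        return heads_loss, ["medusa_heads"]
    if mode == "medusa-2":
        # Joint training: keep the backbone's own next-token loss so its
        # predictions do not degrade, and add the down-weighted heads' loss.
        return lm_loss + lambda_0 * heads_loss, ["backbone", "medusa_heads"]
    raise ValueError(f"unknown mode: {mode}")


print(training_objective("medusa-2", lm_loss=2.0, heads_loss=3.0, lambda_0=0.5))
# (3.5, ['backbone', 'medusa_heads'])
```

Keeping the LM loss term in MEDUSA-2 prevents joint training from trading the base model's own quality for better head predictions.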

What’s particularly clever is how they handle cases where the original training data isn’t available. They use a self-distillation approach — having the model generate its own examples and then using those to train the heads. It’s like having a master chef teach apprentices by first creating dishes and then using those dishes as teaching examples.
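Here is a minimal sketch of the self-distillation idea. It assumes a `generate` function that returns the base model's continuation for a prompt; that function is hypothetical, and a toy function plays its role here. Each training sequence is the prompt plus the model's own answer, and head targets are then derived from it as usual.

```python
def build_self_distillation_data(prompts, generate):
    """Build training sequences from the model's own outputs.

    `generate` is a hypothetical stand-in for sampling a response from the
    base model; in practice this would be the model you are accelerating.
    """
    dataset = []
    for prompt in prompts:
        response = generate(prompt)
        # The heads are trained on the model's own continuation, so their
        # targets match what the base model actually tends to produce.
        dataset.append(prompt + response)
    return dataset


def toy_generate(prompt):
    # Toy "model": always continues with the same three tokens.
    return ["a", "great", "year"]


data = build_self_distillation_data([["2025", "would", "be"]], toy_generate)
print(data)  # [['2025', 'would', 'be', 'a', 'great', 'year']]
```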

The training process gives decreasing importance to predictions further in the future, acknowledging that predicting the immediate next word should be more accurate than predicting five words ahead. This weighted approach helps maintain prediction quality while still benefiting from the speed of parallel generation.
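This decaying weighting can be expressed with a geometric schedule, where head k's loss is scaled by 0.8^k. The paper uses a power of a constant of this kind, but treat the exact value here as illustrative. `weighted_heads_loss` is a hypothetical helper that takes, for each head, the probability it assigned to its correct token.

```python
import math


def weighted_heads_loss(head_probs, decay=0.8):
    """Combine per-head losses with weights that shrink for farther heads.

    `head_probs[k - 1]` is the probability MEDUSA head k assigned to its
    correct target token. The weight for head k is decay ** k, so head 1
    counts most. The default 0.8 is an illustrative choice.
    """
    total = 0.0
    for k, p in enumerate(head_probs, start=1):
        total += (decay ** k) * -math.log(p)  # weighted cross-entropy term
    return total


# The same mistake costs more on head 1 than on head 3.
print(weighted_heads_loss([0.5, 1.0, 1.0]) > weighted_heads_loss([1.0, 1.0, 0.5]))  # True
```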

Inference

The inference process is where MEDUSA truly excels. Imagine it as a team of forecasters collaborating to predict future events. However, instead of weather, they’re predicting text continuations.

Understanding MEDUSA’s Forward Passes

Let’s explore a detailed example to see how MEDUSA processes forward passes:

Input: “The future of artificial intelligence will”

First Forward Pass:

  • Main model predicts: “transform”
  • MEDUSA head 1 predicts: “the”
  • MEDUSA head 2 predicts: “world”
  • MEDUSA head 3 predicts: “in”

If all tokens are accepted, the updated context becomes: “The future of artificial intelligence will transform the world in.”

Second Forward Pass: The four tokens above would normally require four separate forward passes, one per token, but MEDUSA proposed them in a single pass. From the updated context, the process repeats.

This parallel prediction drastically reduces the number of forward passes needed. Traditional LLMs require 20 forward passes to generate 20 tokens. In contrast, MEDUSA only needs five passes if each successfully predicts and accepts four tokens.
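The pass counting can be checked with a toy simulation. It assumes greedy verification, meaning a head's guess counts only if it equals what the base model itself would produce next. It also ignores the fact that, in real MEDUSA, verification happens inside the next forward pass. `count_forward_passes` and the guess functions are hypothetical, and integers stand in for tokens.

```python
def count_forward_passes(reference, head_guess, num_heads=3):
    """Simulate the pass count for a toy MEDUSA-style decoder.

    `reference` plays the base model's own greedy output, so its next token is
    always accepted. `head_guess(pos, k)` is a hypothetical stand-in for MEDUSA
    head k's guess of reference[pos + k]; a guess is kept only while every
    earlier guess in the same step was also correct.
    """
    output, passes = [], 0
    while len(output) < len(reference):
        passes += 1
        pos = len(output)
        accepted = [reference[pos]]  # the base model's next token
        for k in range(1, num_heads + 1):
            idx = pos + k
            if idx >= len(reference) or head_guess(pos, k) != reference[idx]:
                break  # the first wrong guess ends this step's accepted prefix
            accepted.append(reference[idx])
        output.extend(accepted)
    return output, passes


def perfect(pos, k):
    return pos + k  # always matches reference[pos + k] below


def useless(pos, k):
    return -1  # never matches


reference = list(range(20))  # 20 tokens, represented as integers
print(count_forward_passes(reference, perfect)[1])  # 5
print(count_forward_passes(reference, useless)[1])  # 20
```

In both cases the generated text is identical to plain decoding. Only the number of passes changes, and it depends on how often the heads guess right.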

Tree-Based Prediction and Attention

The tree construction process further enhances efficiency by retaining multiple top predictions for each position:

Context: “The future of AI will”

  • Level 1 (Main Model top-2): “transform”, “change”
  • Level 2 (Head 1 top-2): “the”, “our”
  • Level 3 (Head 2 top-2): “world”, “industry”
  • Level 4 (Head 3 top-2): “in”, “by”
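The candidate count is easy to check. This sketch enumerates the example's tree with `itertools.product`, assuming every combination of the top-2 choices is kept; a real tree may be pruned to fewer paths. Because paths share prefixes, the tree holds fewer token slots than the paths would if listed one by one.

```python
from itertools import product
from math import prod

# Top-2 candidates per level, taken from the example above.
levels = [
    ["transform", "change"],  # main model
    ["the", "our"],           # MEDUSA head 1
    ["world", "industry"],    # MEDUSA head 2
    ["in", "by"],             # MEDUSA head 3
]

# Every root-to-leaf path is one candidate continuation.
paths = list(product(*levels))
print(len(paths))  # 16

# In a tree, paths share prefixes: level d holds one node per distinct prefix.
tree_nodes = sum(prod(len(level) for level in levels[:d]) for d in range(1, len(levels) + 1))
print(tree_nodes)  # 30, versus 16 * 4 = 64 tokens if each path were listed separately
```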

The tree-based attention mechanism is key to this efficiency. Instead of processing each path independently, it employs an attention mask that restricts tokens to only attend to their predecessors within the same path. For instance:

  • When verifying “the” after “transform,” it references only “transform” and the original context.
  • When verifying “world,” it considers “transform the” and the original context.
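The mask rule above can be sketched as a boolean matrix over the candidate tokens. In this sketch, `tree_attention_mask` and the parent-index encoding are hypothetical choices, and the shared context is left out because every node may attend to all of it. A real implementation also has to give each node a position index based on its depth in the tree; that part is omitted.

```python
def tree_attention_mask(parents):
    """Build a tree attention mask over candidate tokens.

    parents[i] is the index of node i's parent, or -1 if node i attaches
    directly to the shared context. mask[i][j] is True when node i may attend
    to node j: j is i itself or one of i's ancestors. Every node may also attend
    to the whole original context, which is not represented here.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk from node i up to the root
            mask[i][j] = True
            j = parents[j]
    return mask


# Nodes: 0 "transform", 1 "change", 2 "the" (after "transform"),
# 3 "our" (after "transform"), 4 "world" (after "transform the").
tokens = ["transform", "change", "the", "our", "world"]
parents = [-1, -1, 0, 0, 2]
mask = tree_attention_mask(parents)
for tok, row in zip(tokens, mask):
    print(f"{tok:>9}: {[t for t, ok in zip(tokens, row) if ok]}")
```

The row for “world” allows only “transform,” “the,” and itself. This is exactly the restriction described in the bullets, so sibling branches such as “change” or “our” never leak into its context.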

This approach allows MEDUSA to verify multiple potential continuations simultaneously, avoiding the inefficiency of individually checking each path.

Efficient Verification and Acceptance

MEDUSA’s verification process is equally ingenious. Rather than relying on a separate draft or verifier model, the base model itself processes all candidate tokens in a single forward pass. The tree-based attention mask ensures that each candidate is scored only in the context of its own path. This is akin to a quality control system inspecting multiple items at once instead of one by one.

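To further enhance efficiency, MEDUSA introduces a “typical acceptance” scheme for selecting predictions. Instead of requiring candidates to match exactly what the model would have generated token by token, it accepts a candidate token when the base model assigns it a probability above a threshold. That threshold drops as the base model's prediction becomes more uncertain (higher entropy). The candidate with the longest accepted prefix is kept. This balances generation speed and output quality by tolerating slight deviations from the base model's own choices.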
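Here is a minimal sketch of typical acceptance. It assumes the criterion p(x) > min(ε, δ·exp(−H(p))), where p is the base model's distribution at that position and H is its entropy. The names `typical_accept` and `accepted_length`, the dictionary input format, and the ε and δ values are illustrative assumptions. The sketch also leaves out details such as how the first token of each step is chosen.

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of a distribution given as {token: prob}."""
    return -sum(p * math.log(p) for p in probs.values() if p > 0)


def typical_accept(probs, token, epsilon=0.3, delta=0.5):
    """Accept `token` if the base model finds it probable enough.

    Threshold = min(epsilon, delta * exp(-entropy)): the bar is highest (up to
    epsilon) when the base model is confident and drops as its distribution
    flattens. epsilon and delta here are illustrative, not tuned settings.
    """
    threshold = min(epsilon, delta * math.exp(-entropy(probs)))
    return probs.get(token, 0.0) > threshold


def accepted_length(candidate, distributions, **kwargs):
    """Length of the longest prefix of `candidate` that passes the test.

    distributions[i] is the base model's distribution at the position where
    candidate[i] would be placed (hypothetical input format).
    """
    length = 0
    for token, probs in zip(candidate, distributions):
        if not typical_accept(probs, token, **kwargs):
            break  # stop at the first rejected token
        length += 1
    return length


confident = {"a": 0.9, "b": 0.05, "c": 0.05}
flat = {"w": 0.25, "x": 0.25, "y": 0.25, "z": 0.25}
print(accepted_length(["a", "b"], [confident, confident]))  # 1
print(accepted_length(["a", "w"], [confident, flat]))       # 2
```

With these values, the confident distribution uses the full threshold ε = 0.3, so the unlikely “b” is rejected. The flat distribution lowers the bar to 0.125, so any of its four options passes.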

Iterative Process

Once predictions are verified, the accepted tokens are added to the context. If MEDUSA accepts four tokens in one pass, the next prediction starts from this extended context, saving three forward passes compared with token-by-token decoding.

By combining parallel prediction, tree-based attention, and a pragmatic acceptance scheme, MEDUSA speeds up text generation while keeping output quality close to that of standard decoding.

Results

The performance improvements MEDUSA achieves are impressive. Without compromising generation quality, MEDUSA-1 achieves a 2.2× speedup over standard LLM inference. MEDUSA-2 pushes this further to 2.3–2.8× faster generation.

What’s particularly interesting is how MEDUSA performs differently across various tasks. It shows exceptional performance in coding tasks, achieving a 3.29× speedup, and data extraction tasks, reaching a 3.62× speedup. This suggests that tasks with more structured or predictable outputs benefit most from MEDUSA’s parallel prediction approach.

The results are especially impressive considering that these improvements come without requiring a separate draft model. This means simpler deployment, less memory usage, and easier integration into existing systems. The paper demonstrates these speedups across different model sizes, from 7B to 33B parameters, showing MEDUSA’s scalability.

Most importantly, these speed improvements come without sacrificing generation quality. The paper shows comparable or even slightly improved quality scores on benchmark tests. This suggests that adding the heads and the acceptance scheme speeds up decoding without degrading the base model's outputs.

Written by Isaac Kargar

Co-Founder and Chief AI Officer @ Resoniks | Ph.D. candidate at the Intelligent Robotics Group at Aalto University | https://kargarisaac.github.io/

No responses yet