LLM Reasoning Series: Deep Dive into rStar-Math and Monte Carlo Tree Search
How a novel approach allows compact models to rival giants like OpenAI’s o1 through strategic “deep thinking.”
In this post, we will go over the paper "rStar-Math: Small LLMs Can Master Math Reasoning with Self-Evolved Deep Thinking", which shows how small language models can rival or even surpass the math reasoning capability of OpenAI o1 without distillation from superior models.
In a time when big AI models like GPT-4 are often in the news, this paper questions the idea that larger models are automatically better. It presents a new way for smaller language models (SLMs) to compete with, and even outperform, large models like OpenAI’s o1, without needing expensive training methods. Let’s take a closer look at how this team came up with new ways to solve problems for smaller models.
The Problem: Why Small Models Hit a Wall in Math
Large language models (LLMs) are great at solving complex problems, but they are expensive to run and not easy to access for everyday use. Smaller models have two main issues.
First, their training data often contains subtle mistakes. Many math datasets come from larger models, which sometimes make errors in their steps. It's like a student copying homework from a teacher who makes slight mistakes: the errors add up and lead the student in the wrong direction.
Second, smaller models usually take a "one-shot" approach to solving problems, similar to quick, instinctive thinking. This is fast, but it doesn't work well for complicated problems that need careful, step-by-step reasoning. It's like a student quickly writing an answer without checking it.
This led the rStar-Math team to ask an important question: can smaller models learn to think more deeply, explore different solutions, and fix their mistakes on their own, without a lot of human help or computing power?
The Solution: Mimicking Human-Like “Deep Thinking”
The answer lies in three synergistic innovations, blending code execution, strategic preference learning, and a search algorithm borrowed from game-playing AI.
1. Code-Augmented Verification: No More Lucky Guesses
Every reasoning step is paired with executable code (e.g., Python or SymPy snippets). For instance, when solving `2x + 4 = 10`, the model doesn’t just write “Subtract 4 from both sides” — it generates:
# Step 1: Subtract 4 from both sides
two_x = 10 - 4     # executing this code verifies 2x = 6
assert two_x == 6  # an arithmetic slip here would raise an error
Only steps with error-free code survive, filtering out guesswork. This mirrors how a meticulous student checks each step with a calculator, ensuring no misstep goes unnoticed.
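To make the filtering idea concrete, here is a minimal Python sketch of how code-verified steps could be kept and failing ones discarded. The helper name and candidate strings are illustrative, not the paper's actual implementation.
# A minimal sketch of step filtering via code execution (hypothetical helper,
# not the paper's implementation).
def step_passes_verification(step_code: str) -> bool:
    """Keep a reasoning step only if its attached code runs without errors."""
    try:
        exec(step_code, {})  # run the snippet in an isolated namespace
        return True
    except Exception:        # syntax errors, failed assertions, etc.
        return False

candidate_steps = [
    "two_x = 10 - 4; assert two_x == 6",  # correct step: subtract 4 from both sides
    "two_x = 10 - 2; assert two_x == 6",  # arithmetic slip: the assertion fails
]
verified = [s for s in candidate_steps if step_passes_verification(s)]  # keeps only the first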
2. Process Preference Model (PPM): Teaching “Good Habits”
Even correct steps can be suboptimal. The PPM acts like a seasoned tutor, rewarding methods that align with long-term success. For example, after simplifying to `2x = 6`, both “Divide by 2” and “Multiply by 0.5” are mathematically valid. However, the PPM favors division — a more intuitive and scalable strategy for learners. This nudges the model toward systematic reasoning, much like teachers prioritize foundational techniques over clever shortcuts.
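As a rough illustration of the preference idea, the sketch below shows a pairwise ranking loss of the kind commonly used to train preference models. The function name and the toy scores are made up for the example, not taken from the paper.
# A minimal pairwise-preference sketch in PyTorch (illustrative, not the paper's PPM).
import torch
import torch.nn.functional as F

def ppm_pairwise_loss(score_preferred: torch.Tensor, score_rejected: torch.Tensor) -> torch.Tensor:
    """Bradley-Terry style loss: push the score of the preferred step
    (e.g., "Divide by 2") above the rejected one (e.g., "Multiply by 0.5")."""
    return -F.logsigmoid(score_preferred - score_rejected).mean()

# Toy scores the preference model might assign to two mathematically valid steps.
loss = ppm_pairwise_loss(torch.tensor([1.2]), torch.tensor([0.4]))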
3. Self-Evolution via Monte Carlo Tree Search (MCTS): Learning Through Exploration
Inspired by AlphaGo’s success, rStar-Math adapts MCTS — a decision-making algorithm that balances exploration and exploitation — for math problems. Here’s how it works in practice:
- Building the Tree: Starting with a root problem (e.g., `2x + 4 = 10`), the model generates multiple candidate actions (e.g., “Subtract 4,” “Divide by 2”). Each action becomes a branch.
- Simulating Outcomes: The algorithm explores paths, simulating solutions while the PPM evaluates each step’s quality. Successful paths (like solving for `x = 3`) are reinforced, while dead-ends (e.g., incorrect “Subtract 2” steps) are deprioritized.
- Backpropagation: Insights from successful simulations propagate backward, updating the tree to reflect which steps are most promising.
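Before walking through a concrete example, here is a minimal sketch of the UCT-style selection rule that drives the "balance exploration and exploitation" step and that appears in the rollouts below. The exploration constant and field names are illustrative, not the paper's exact formulation.
# A minimal UCT (Upper Confidence bounds applied to Trees) selection sketch.
import math

def uct_score(total_value: float, visits: int, parent_visits: int, c: float = 1.4) -> float:
    """Score a candidate step: exploit branches with high average value,
    but keep exploring branches that have been tried only rarely."""
    if visits == 0:
        return float("inf")           # always try an unexplored action at least once
    exploit = total_value / visits    # average reward seen from this branch
    explore = c * math.sqrt(math.log(parent_visits) / visits)
    return exploit + explore

# At each selection step, the search picks the child with the highest uct_score.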
Example Problem: Solve 2x + 4 = 10
Assume the policy SLM generates 3 candidate steps (for simplicity) at each node, and we run 2 rollouts (iterations) to illustrate the process.
Rollout 1: First Exploration
1- Root Node (Step 0):
— Candidates:
- 2x + 4 = 10 → Subtract 4
- 2x + 4 = 10 → Divide by 2
- 2x + 4 = 10 → Guess x = 3
2- Selection:
— The PPM and UCT formula select "Subtract 4" (highest initial score).
3- Expansion:
— Create a new node: Step 1 after subtracting 4: 2x = 6.
— Generate 3 new candidates for Step 1:
- 2x = 6 → Divide by 2
- 2x = 6 → Multiply by 0.5
- 2x = 6 → Subtract 2 (incorrect step).
4- Simulation:
— Simulate from Step 1 (2x = 6) using the policy SLM:
- Select "Divide by 2" → x = 3.
- Code execution verifies x = 3 is correct.
5- Backpropagation:
— Update scores for all nodes in the path:
- Root node (Step 0): Q-value increases (since the path led to success).
- Step 1 node: Q-value increases.
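A minimal sketch of what that backpropagation step could look like, assuming a simple node structure with visit counts and running-average Q-values (illustrative, not the paper's implementation):
# Backpropagation sketch: push the rollout reward from the leaf back to the root.
class Node:
    def __init__(self, parent=None):
        self.parent = parent
        self.visits = 0
        self.q_value = 0.0  # running average of rollout rewards

def backpropagate(leaf: Node, reward: float) -> None:
    """Walk from the leaf back to the root, nudging each node's Q-value
    toward the rollout reward (e.g., 1.0 for a code-verified x = 3)."""
    node = leaf
    while node is not None:
        node.visits += 1
        node.q_value += (reward - node.q_value) / node.visits  # incremental mean
        node = node.parent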
Rollout 2: Second Exploration
1- Root Node (Step 0):
— Candidates (now with updated Q-values):
- "Subtract 4" (high score from Rollout 1).
- "Divide by 2" (unexplored).
- "Guess x=3x=3" (unexplored).
2- Selection:
— UCT balances exploration/exploitation. It picks "Divide by 2" (to explore a new path).
3- Expansion:
— Create a new node: Step 1 after dividing by 2: x + 2 = 5.
— Generate 3 new candidates:
- x + 2 = 5 → Subtract 2
- x + 2 = 5 → Multiply by 1 (redundant)
- x + 2 = 5 → Add 3 (incorrect).
4- Simulation:
— Simulate from Step 1 (x + 2 = 5):
- Select "Subtract 2" → x = 3.
- Code execution verifies correctness.
5- Backpropagation:
— Update scores for the new path:
- Root node (Step 0): Q-value for "Divide by 2" increases.
- Step 1 node (new branch): Q-value increases.
Final Tree State
After 2 rollouts, the search tree looks like this:
Root (Step 0: 2x+4=10)
├── Path 1: Subtract 4 → Step 1 (2x=6) → x=3 (success)
└── Path 2: Divide by 2 → Step 1 (x+2=5) → x=3 (success)
Both paths lead to the correct answer, but their Q-values differ based on PPM rankings (e.g., "Subtract 4" might be preferred for clarity).
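A tiny sketch of how the final answer could be read off the finished tree, picking the trajectory with the best PPM-weighted score (the scores below are made up for illustration):
# Hypothetical final scores for the two complete solution paths above.
paths = {
    "Subtract 4 -> 2x = 6 -> Divide by 2 -> x = 3": 0.92,
    "Divide by 2 -> x + 2 = 5 -> Subtract 2 -> x = 3": 0.81,
}
best_path = max(paths, key=paths.get)  # the clearer "Subtract 4" path wins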
The paper walks through another, more detailed example of this search process.
Training: From Bootstrap to Self-Evolution
The model undergoes four evolutionary rounds:
1. Bootstrap (round 1): Initial training uses high-quality data from DeepSeek-Coder, a robust external model.
2. Refinement (rounds 2-4): The model tackles progressively harder problems (e.g., Olympiad-level questions). MCTS generates improved solutions, which then train better versions of the model and PPM. This creates a virtuous cycle: better data → better models → better data.
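A rough sketch of that cycle, with placeholder functions standing in for the actual data generation and fine-tuning steps (everything here is hypothetical scaffolding, not the paper's code):
# Self-evolution loop sketch: each round produces better data, which trains better models.
def run_mcts(policy, ppm, problems):
    """Placeholder: search each problem and return code-verified reasoning traces."""
    return [f"verified trace for {p}" for p in problems]

def improve(model, traces):
    """Placeholder: stands in for fine-tuning the policy SLM or retraining the PPM."""
    return model + len(traces)

policy, ppm = 0, 0                                        # round 1: bootstrapped models
problems = ["2x + 4 = 10", "an Olympiad-level problem"]   # pool grows harder each round
for round_id in range(2, 5):                              # rounds 2-4: refinement
    traces = run_mcts(policy, ppm, problems)              # better models generate better data...
    policy = improve(policy, traces)                      # ...which trains a better policy...
    ppm = improve(ppm, traces)                            # ...and a better preference model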
Results: Punching Above Their Weight
- MATH Benchmark: A 7B-parameter model achieved 90% accuracy, outperforming OpenAI’s o1-preview (85.5%).
- AIME (American Invitational Mathematics Examination): Solved 53.3% of problems on average, rivalling top high school competitors.
- Self-Correction: The system developed an intrinsic ability to backtrack from errors, akin to a student realizing a miscalculation mid-problem.
Challenges and Future Horizons
While promising, rStar-Math isn’t flawless. MCTS demands heavy GPU resources, limiting accessibility. Geometry problems — reliant on visual/spatial reasoning — remain a hurdle, much as they do for humans without diagrams. Future work may explore integrating multimodal inputs or optimizing MCTS efficiency.
Conclusion: Smarter, Not Just Bigger
rStar-Math demonstrates that strategic design can elevate small models to elite performance. By emulating human-like deliberation — exploring multiple pathways, verifying steps, and learning from mistakes — it challenges the narrative that AI progress hinges on scale. This isn’t just a leap for math reasoning; it’s a blueprint for democratizing advanced AI, proving that ingenuity can trump brute computational force.