Monte Carlo Tree Search LLM
A post on how to combine an LLM with the Monte Carlo Tree Search (MCTS) algorithm. There are numerous versions of MCTS; I follow the implementation from Planning with Large Language Models for Code Generation, with the following additions:
- Allow step-level expansions instead of just token-level (for more discussion on token vs. step level, see section 3 of this paper).
- Accept generic callables for candidate generation, simulation, and reward.
Code is available here, heavily inspired by this excellent blog post.
MCTS
First, I outline appendix D.1 (PG-TD) of Planning with Large Language Models for Code Generation, which covers the adapted MCTS algorithm.
Select
- Starting from the root node (initialized as an empty string: “” or a prompt: “The dog ran”), recursively select subtrees until finding a node that has not previously been expanded.
- Each node maintains a cache $Q(s, a)$ which is the maximum reward (could also be the average reward) obtained by starting from a state $s$ and taking action $a$.
- Selection is defined as:
\[a^* = \arg\max_{a} \text{P-UCB}(s, a)\]using the selection criterion $\text{P-UCB}$ defined as:
\[\text{P-UCB}(s, a) = Q(s, a) + \beta(s) \cdot P_{\text{transformer}}(a | s) \cdot \frac{\sqrt{\log{s.visits}}}{1 + s'.visits}\]where $s$ is the parent state, $s’$ is the new state after taking action $a$ from $s$, and $\beta(s)$ is computed as:
\[\beta(s) = \log{\frac{s.visits + c_{\text{base}} + 1}{c_{\text{base}}}} + c\]And finally $c$ is a constant that encourages exploration.
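The selection criterion above can be sketched in a few lines. This is a minimal sketch of the formulas, not the repo's actual API; the default values for $c_{\text{base}}$ and $c$ are illustrative.

```python
import math


def p_ucb(q, prior, parent_visits, child_visits, c_base=10.0, c=4.0):
    """P-UCB score for one (state, action) edge.

    `q` is Q(s, a), `prior` is P_transformer(a | s); `c_base` and `c` are
    exploration constants (values here are illustrative assumptions).
    """
    beta = math.log((parent_visits + c_base + 1) / c_base) + c
    return q + beta * prior * math.sqrt(math.log(parent_visits)) / (1 + child_visits)


def select_action(q_values, priors, parent_visits, child_visits):
    """Pick the action maximizing P-UCB, as in the selection step."""
    return max(
        q_values,
        key=lambda a: p_ucb(q_values[a], priors[a], parent_visits, child_visits[a]),
    )
```

With equal $Q$ values and equal visit counts, the action with the higher prior under the transformer wins, which is how the prior biases the search toward likely continuations.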
Expansion
- Once at the selected node, add children to this node with a set of candidate actions.
- The top-k candidate actions are taken from the conditional distribution produced by the transformer:
- In the case of token-level MCTS, we only sample the top-k tokens conditioned on the selected node’s state.
- For step-level MCTS, we can sample an entire path (multiple steps of tokens), recording the top-k candidates at each step. For example, we may sample:
```
Parent node's state: The dog ran
===
Expansion token 1: {very, extremely, slightly}
Expansion token 2 | very: {fast, slow, randomly}
===
1st candidate node's state: The dog ran very
2nd candidate node's state: The dog ran extremely
3rd candidate node's state: The dog ran slightly
4th candidate node's state: The dog ran very fast
5th candidate node's state: The dog ran very slow
6th candidate node's state: The dog ran very randomly
```
This may vary depending on the LLM's generation API, but the above is compatible with any model that returns the top-k candidate tokens at each generation step.
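As a rough illustration, step-level candidates can be assembled from per-step top-k alternatives like this. The response shape here is a hand-rolled stand-in (a list of per-token dicts), not any particular provider's API:

```python
import math


def step_level_candidates(state, steps):
    """Enumerate candidate states from top-k alternatives at each step.

    `steps` is a list of per-token dicts: {"token": sampled, "top": {alt: logprob}}.
    At step i, each alternative is appended to the state extended with the
    sampled tokens from steps 0..i-1, mirroring the example above.
    Returns (candidate_state, probability) pairs.
    """
    candidates = []
    prefix = ""
    for step in steps:
        for token, logprob in step["top"].items():
            candidates.append((state + prefix + token, math.exp(logprob)))
        prefix += step["token"]  # continue along the sampled path
    return candidates


# The "The dog ran" example: two steps, three alternatives each (logprobs made up).
steps = [
    {"token": " very", "top": {" very": -0.2, " extremely": -1.5, " slightly": -2.0}},
    {"token": " fast", "top": {" fast": -0.3, " slow": -1.1, " randomly": -3.0}},
]
cands = step_level_candidates("The dog ran", steps)
# Six candidate states: three one-token extensions, three two-token extensions.
```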
Evaluation
- Conduct a beam search starting from the selected node. This has the effect of a “simulation” where we try to complete the program, translation, or statement from the current node.
Evaluation is done on the currently selected node, not the candidate nodes created in the expansion step. "Beam search" is used loosely here, since the beam width is set to 1 (also known as hill climbing).
- Compute a reward using the completed evaluation. This can either be a deterministic scoring (compiler passes, math proof is correct) or a score from an LLM judge.
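A minimal sketch of the LLM-judge case: parse the judge's JSON response and aggregate the per-criterion scores. Averaging is an assumption here; the actual aggregation in the repo may differ.

```python
import json
import statistics


def parse_judge_reward(judge_response: str) -> float:
    """Turn an LLM judge's JSON response into a scalar reward.

    Assumes the judge replies with a flat JSON object of scores in [0, 1],
    as requested by the reward prompt; the mean is an illustrative choice.
    """
    scores = json.loads(judge_response)
    return statistics.mean(scores.values())


reward = parse_judge_reward('{"completeness": 1.0, "correctness": 0.6, "elegance": 0.5}')
# → 0.7
```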
Backpropagation
- The computed reward $r$ is backpropagated recursively back to the root node using the update:
\[Q(\tilde s, \tilde a) \leftarrow \max(Q(\tilde s, \tilde a), r)\]applied to every state-action pair $(\tilde s, \tilde a)$ on the path from the root to the evaluated node.
Each iteration leaves the algorithm in a state where $Q(\tilde s, \tilde a)$ represents the best reward found so far starting from state $\tilde s$ and taking action $\tilde a$.
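The backup can be sketched as follows. The node fields (`visits`, a `Q` dict) are illustrative, not the repo's actual `MCTSNode` API:

```python
def backpropagate(path, reward):
    """Max-reward backup along the selected path.

    `path` is the list of (node, action) pairs from the root to the evaluated
    node; each node carries a visit counter and a Q cache keyed by action.
    """
    for node, action in path:
        node["visits"] += 1
        node["Q"][action] = max(node["Q"].get(action, float("-inf")), reward)


root = {"visits": 0, "Q": {}}
child = {"visits": 0, "Q": {}}
backpropagate([(root, "very"), (child, "fast")], reward=0.95)
backpropagate([(root, "very"), (child, "fast")], reward=0.7)
# root["Q"]["very"] stays 0.95: Q keeps the best reward seen through each edge,
# while the visit counts (now 2) feed back into the P-UCB exploration term.
```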
Translation Example
In this demo notebook, I show how MCTS can be used to improve translation from Chinese text to English:
```python
import numpy as np
import pandas as pd
from src.mcts import mcts
from src.node import MCTSNode
from src.open_ai import get_candidates_fn, get_simulation_fn, get_reward_fn
from IPython.display import display, Markdown
from src.utils import create_graph_html

np.random.seed(1)
```
```python
# Chinese text we want to translate (from Twenty Thousand Leagues Under the Sea).
CHINESE = "这事大体是这样:不久以前,好些大船在海上碰见了一一个“庞然大物”,一个很长的物体,形状很像纺锤,有时发出磷光,它的体积比鲸鱼大得多,行动起来也比鲸鱼快得多。"
# Official translation (as given by http://bilinguis.com/book/verne20k/en/zh/p1c1/).
ORIGINAL = 'In essence, over a period of time several ships had encountered "an enormous thing" at sea, a long spindle-shaped object, sometimes giving off a phosphorescent glow, infinitely bigger and faster than any whale.'
# The translation given by google translate.
GOOGLE_TRANSLATE = 'The story goes something like this: Not long ago, a number of large ships encountered a "monster" at sea, a very long object, shaped like a spindle, sometimes emitting phosphorescence. It was much larger than a whale and moved much faster than a whale.'
```
```python
# This few shot prompt will be used for expansion and simulation.
generation_prompt = f"""Chinese text needs to be translated into English.
- Do not provide any context or description, just the translation.
- A user will start the translation. Complete the translation without repeating what has already been translated.
Translate the following:
{CHINESE}
"""
```
```python
# This prompt will be used for calculating the reward with an LLM judge.
example_1 = '{"completeness": 1.0, "correctness": 0.6, "elegance": 0.5}'
example_2 = '{"completeness": 1.0, "correctness": 0.95, "elegance": 1.0}'
reward_prompt = f"""Provide scores between 0 and 1 of how well the english has been translated from Chinese. Respond in json format for the following keys:
- 'correctness' value between 0 and 1 - if each pinyin token correctly translates into english tokens.
- 'brevity' value between 0 and 1 - if there is no redundancy in the translation.
- 'elegance' value between 0 and 1 - if the translation matches the original prose and is pleasurable to read.
Example:
Pinyin: shuǐ dī shí chuān.
English: Dropping water can penetrate the stone, sometimes.
Response: {example_1}
Chinese: 學而時習之,不亦悅乎?
English: To learn and to practice what is learned time and again is pleasure, is it not?
Response: {example_2}
Translate the following:
{CHINESE}
"""
```
```python
# Run MCTS and visualize the algorithm's history.
root, node, history = mcts(
    get_candidates_fn=get_candidates_fn(
        prompt=generation_prompt,
        # Consider candidates which add at most 5 tokens.
        max_completion_tokens=5,
        # Consider 3 alternatives at each step.
        top_logprobs=3,
        # Consider candidates with at least 3 tokens.
        minimum_candidate_token_length=3,
    ),
    get_simulation_fn=get_simulation_fn(
        prompt=generation_prompt,
        # Do not limit how far we simulate.
        max_completion_tokens=None,
    ),
    get_rewards_fn=get_reward_fn(prompt=reward_prompt),
    # Number of total MCTS iterations. Each iteration makes an expansion, simulation, and reward API call.
    max_rollouts=16,
    # Exploration constant.
    c=5.0,
    # Print out the logging.
    verbose=True,
)
```
```
{'actions': [['root']]}
{'step': 0, 'actions': [['expansion'], ['simulation'], ['reward']], 'reward': 0.917}
{'step': 1, 'actions': [['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.95}
{'step': 2, 'actions': [['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.95}
{'step': 3, 'actions': [['selection'], ['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.933}
{'step': 4, 'actions': [['selection'], ['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.95}
{'step': 5, 'actions': [['selection'], ['selection'], ['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.917}
{'step': 6, 'actions': [['selection'], ['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.933}
{'step': 7, 'actions': [['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.967}
{'step': 8, 'actions': [['selection'], ['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.95}
{'step': 9, 'actions': [['selection'], ['selection'], ['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.95}
{'step': 10, 'actions': [['selection'], ['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.95}
{'step': 11, 'actions': [['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.95}
{'step': 12, 'actions': [['selection'], ['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.95}
{'step': 13, 'actions': [['selection'], ['selection'], ['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.917}
{'step': 14, 'actions': [['selection'], ['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.95}
{'step': 15, 'actions': [['selection'], ['selection'], ['selection'], ['expansion'], ['simulation'], ['reward']], 'reward': 0.95}
```
```python
# Find the simulations ordered by their reward value.
simulations = sorted(
    [v for i in history for k, v in i.items() if k == "reward"], key=lambda x: -x[-1]
)
```
```python
# See what a one-shot generation without MCTS would give.
one_shot = get_simulation_fn(prompt=generation_prompt)(
    MCTSNode(prob=0.0, state="", type="one_shot")
)
```
```python
# Compare the various translations.
top = 3
display(
    Markdown(
        pd.DataFrame(
            {
                "type": [
                    "Chinese",
                    "Author's Translation",
                    "Google Translate",
                    "One Shot",
                ]
                + [f"MCTS #{i + 1}" for i in range(top)],
                "text": [
                    CHINESE,
                    ORIGINAL,
                    GOOGLE_TRANSLATE,
                    one_shot,
                ]
                + [i[0] + " " + i[1] for i in simulations[:top]],
            }
        ).to_markdown()
    )
)
```
| | type | text |
|---|---|---|
| 0 | Chinese | 这事大体是这样:不久以前,好些大船在海上碰见了一一个“庞然大物”,一个很长的物体,形状很像纺锤,有时发出磷光,它的体积比鲸鱼大得多,行动起来也比鲸鱼快得多。 |
| 1 | Author's Translation | In essence, over a period of time several ships had encountered "an enormous thing" at sea, a long spindle-shaped object, sometimes giving off a phosphorescent glow, infinitely bigger and faster than any whale. |
| 2 | Google Translate | The story goes something like this: Not long ago, a number of large ships encountered a "monster" at sea, a very long object, shaped like a spindle, sometimes emitting phosphorescence. It was much larger than a whale and moved much faster than a whale. |
| 3 | One Shot | This matter is generally as follows: Not long ago, several large ships encountered a "colossal being" in the sea, a very long object that was spindle-shaped, sometimes emitting phosphorescence. Its size was much larger than that of a whale, and it moved much faster than a whale. |
| 4 | MCTS #1 | This matter is generally as follows: Not long ago, several large ships encountered a "colossal creature" at sea, a very long object that resembled a spindle, sometimes emitting phosphorescent light. Its size was much larger than that of a whale, and it moved much faster than a whale. |
| 5 | MCTS #2 | This matter is generally like this: not long ago, several large ships encountered a "colossal creature" in the sea, a very long object that was spindle-shaped, sometimes emitting phosphorescence. Its size was much larger than that of a whale, and it moved much faster than a whale. |
| 6 | MCTS #3 | This matter is generally like this: not long ago, several large ships encountered a "colossal creature" in the sea, a very long object that was spindle-shaped, sometimes emitting phosphorescence. Its size was much larger than that of a whale, and it moved much faster than a whale. |
```python
# Visualize the trees with pyvis.
create_graph_html(root=root, filename="graph.html", height="300px")
```
Conclusion
MCTS is a tree search algorithm capable of searching large action spaces. Its select-expand-simulate-backpropagate pattern is highly customizable for different use cases. As LLM performance improves over time, MCTS can still be applied at inference time to search over a diverse set of candidate generations and keep the best one.