A connection between contextual bandit exploration in infinite action spaces and LLM-RL algorithms leveraging the Gibbs policy representation.
Infinite action contextual bandits
Contextual bandits are a special case of RL where you receive a context $x$, play an action $a$, and receive a reward $R(x, a)$. The distribution of $R$ depends only upon $(x, a)$, so the universe is memoryless and the entire context is $x$: you don’t have to worry about the more difficult aspects of RL, such as taking actions now to improve prospects for future rewards, or remembering things from the past that could affect the current reward.
It turns out a simple exploration algorithm works well in infinite action spaces: you maintain a reward estimator $\hat{r}(x, a)$, and then you rejection sample an action as follows
- Observe context $x$.
- Compute $\hat{a} = \arg \max_a \hat{r}(x, a)$.
- Sample $a \sim \mu$. Here $\mu$ is a prior distribution.
- Accept $a$ with probability $\left(1 + h \gamma \left(\hat{r}(x, \hat{a}) - \hat{r}(x, a)\right)\right)^{-1}$. This is a type of “inverse gap weighting” strategy. Here $\gamma$ is like a learning rate, and $h$ specifies a type of trust region (see below). In practice $h$ and $\gamma$ are combined into one scalar hyperparameter controlling exploration, but they are kept distinct for analysis.
- If $a$ is rejected, play $\hat{a}$.
- Observe reward $r$ and update $\hat{r}$ using $x$, played action, and observed reward.
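The steps above can be sketched in a few lines. This is a minimal illustration, not a reference implementation: the reward estimator `r_hat`, the argmax helper `argmax_a`, and the prior sampler `sample_prior` are hypothetical stand-ins you would supply, and $h$ and $\gamma$ enter only through the combined scalar `h_gamma`.

```python
import random

def igw_action(x, r_hat, argmax_a, sample_prior, h_gamma, rng=random):
    """One round of the rejection-sampling exploration described above.

    r_hat(x, a)    -- current reward estimate (hypothetical callable)
    argmax_a(x)    -- computes a_hat = argmax_a r_hat(x, a)
    sample_prior() -- draws a ~ mu
    h_gamma        -- the combined exploration scalar h * gamma
    """
    a_hat = argmax_a(x)
    a = sample_prior()
    gap = r_hat(x, a_hat) - r_hat(x, a)   # nonnegative by definition of a_hat
    if rng.random() < 1.0 / (1.0 + h_gamma * gap):
        return a                           # accept the exploratory draw
    return a_hat                           # otherwise fall back to the greedy action
```

Note the two limiting behaviors: when the gap is zero the prior draw is always accepted, and as `h_gamma * gap` grows the policy concentrates on the greedy action $\hat{a}$.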
Regarding the meaning of $h$: this exploration ensures low regret against an adversary that can pick any reward distribution $Q$ which satisfies $|dQ/d\mu| \leq h^{-1}$. This is a kind of trust region around the prior distribution.
The Gibbs policy
Recall the Gibbs policy results from pretending to run the following optimization $$
\begin{align*}
\pi^* &= \arg \max_{\pi} \mathbb{E}_{x}\left[ \mathbb{E}_{\substack{a \sim \pi(\cdot | x) \\ r \sim P(r|x,a)}}\left[r\right] - \lambda^{-1} \text{KL}\left(\pi(\cdot|x) \| \mu(\cdot|x)\right) \right].
\end{align*}
$$ which results in $$ \begin{align*} \pi^*(a|x) &= Z(x)^{-1} \mu(a|x) \exp\left( \lambda \mathbb{E}\left[r | x, a\right]\right), \\ Z(x) &= \mathbb{E}_{a \sim \mu(\cdot|x)}\left[ \exp\left( \lambda \mathbb{E}\left[r | x, a\right]\right) \right]. \end{align*} $$ Thus the log density ratio to the prior encodes the expected reward $$ \begin{align*} \log \frac{\pi^*(a|x)}{\mu(a|x)} = \lambda \mathbb{E}\left[r | x, a \right] - \log Z(x). \end{align*} $$ This is used in papers like iStar as a value function estimate, with the current policy substituted for $\pi^*$.
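The identity above is easy to verify numerically for a discrete action set. This toy check uses a made-up four-action problem (uniform prior, arbitrary expected rewards, $\lambda = 2$); all the numbers are illustrative.

```python
import math

# Hypothetical setup: four discrete actions, a uniform prior mu,
# and made-up expected rewards E[r | x, a] for a fixed context x.
mu = [0.25, 0.25, 0.25, 0.25]
r = [0.1, 0.9, 0.5, 0.3]
lam = 2.0  # lambda, the inverse of the KL penalty weight

# Gibbs policy: pi*(a|x) = mu(a|x) exp(lambda * E[r|x,a]) / Z(x)
Z = sum(m * math.exp(lam * ri) for m, ri in zip(mu, r))
pi_star = [m * math.exp(lam * ri) / Z for m, ri in zip(mu, r)]

# The log density ratio recovers lambda * E[r|x,a] - log Z(x) exactly:
for p, m, ri in zip(pi_star, mu, r):
    assert abs(math.log(p / m) - (lam * ri - math.log(Z))) < 1e-12
```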
Combining the ideas
Consider an LLM problem where the LLM receives a prompt $x$, generates a chain-of-thought $z$ and final answer $a$ with no intermediate tool calls that interact with the environment, and then receives a reward that depends only upon $(x, z, a)$. This problem is well-modeled as a contextual bandit in a huge action space. Note a typical reward is something like “did you get the math problem right,” with a penalty for long CoT. Also note we’ve changed notation slightly: the action is now $(z, a)$, i.e., the combination of the CoT and the final answer.
Using the log-density-ratio to the prior as our current value function, the rejection sampling acceptance probability becomes
$$ \begin{align*} & \left(1 + h \gamma \left( \hat{r}(x, \hat{z}, \hat{a}) - \hat{r}(x, z, a) \right)\right)^{-1} \\ &= \left(1 + \lambda^{-1} h \gamma \left( \log \frac{\pi_\theta(\hat{z},\hat{a} \mid x)}{\mu(\hat{z},\hat{a} \mid x)} - \log \frac{\pi_\theta(z,a \mid x)}{\mu(z,a \mid x)}\right) \right)^{-1}. \end{align*} $$ Here $(z, a)$ is a sample from the prior distribution $\mu$; and $(\hat{z}, \hat{a})$ is the chain-of-thought and final answer that maximizes the likelihood ratio between the current policy $\pi_\theta$ and the prior $\mu$. Because the models are autoregressive, we can exactly compute $\hat{z}$ and $\hat{a}$ by sequentially maximizing the conditional logit difference, but for numerical stability we should avoid tokens with very small $\mu$. Alternatively, the follow-up work Capped IGW describes how to avoid computing the argmax action by using multiple samples from the prior.
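Given summed token log-probabilities of each completion under $\pi_\theta$ and $\mu$, the acceptance probability is a one-liner. This is a sketch under the assumption that the four log-probabilities have already been computed elsewhere; the function name and arguments are hypothetical.

```python
def accept_prob(logp_pi, logp_mu, logp_pi_hat, logp_mu_hat, h_gamma, lam):
    """Acceptance probability for a prior draw (z, a), from summed token
    log-probabilities under the current policy pi_theta and the prior mu.

    logp_pi / logp_mu         -- log pi_theta(z, a | x) and log mu(z, a | x) for the draw
    logp_pi_hat / logp_mu_hat -- the same quantities for the argmax completion (z_hat, a_hat)
    h_gamma, lam              -- combined exploration scalar h * gamma and the Gibbs lambda
    """
    # Gap between the argmax log-density-ratio and the draw's log-density-ratio.
    gap = (logp_pi_hat - logp_mu_hat) - (logp_pi - logp_mu)
    return 1.0 / (1.0 + (h_gamma / lam) * gap)
```

When the prior draw has the same density ratio as the argmax completion, the gap is zero and the draw is accepted with probability one, matching the bandit algorithm above.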
So what’s this useful for?
Smoothed IGW is designed to minimize online regret, i.e., for the situation where we are trying to maximize reward while learning. This is not usually our situation: typically we are trying to minimize PAC regret, i.e., to maximize the performance of the final trained model; as long as the result is good, we don’t care strongly if the model does badly during training. Smoothed IGW bounds PAC regret but doesn’t minimize it, so for training models in the lab for fixed deployment, better strategies exist. However, if an LLM needs to learn as it goes, Smoothed IGW is pretty simple to implement and provides a strong guarantee. Perhaps continual learning or test-time training will provide an opportunity to leverage this.