Steepish Descent

Unleash the chains of thought.

Pertinacity and the Puke

My opinions on why importance-weighted estimators make bad objective functions.

An academic punching bag for off-policy RL is the importance-weighted policy gradient, driven by the value estimator $$
\begin{align*}
V(\theta) &= \mathbb{E}_{\substack{x \sim P(x) \\ y \sim h(y|x) \\ r \sim P(r|y,x)}}\left[ \frac{\pi(y|x;\theta)}{h(y|x)} r\right],
\end{align*}
$$ where $x$ is a prompt, $y$ is a completion, $h$ is the policy that generated the data, and $\pi$ is the policy that we want to optimize, parameterized by $\theta$. The gradient is $$
\begin{align*}
\nabla_\theta V(\theta) &= \mathbb{E}_{\substack{x \sim P(x) \\ y \sim h(y|x) \\ r \sim P(r|y,x)}}\left[ \frac{\pi(y|x;\theta)}{h(y|x)} r \nabla_\theta \log \pi(y|x; \theta) \right].
\end{align*}
$$ The finite-sample version of that gradient is what everybody hates. Given a dataset $\mathcal{D} = \{ (x_t, y_t, h_t, r_t) \}$ with $h_t = h(y_t|x_t)$, $$
\begin{align*}
\nabla_\theta \hat{V}(\theta) &\approx \frac{1}{T} \sum_{t=1}^T \underbrace{\frac{\pi(y_t|x_t;\theta)}{h(y_t|x_t)}}_{w_t(\theta)} r_t \nabla_\theta \log \pi(y_t|x_t; \theta).
\end{align*}
$$ Numerically this is unstable if $w_t(\theta)$ gets large, and statistically the variance of this estimator scales as $\mathbb{E}\left[w(\theta)^2\right]$. But if those were the only problems, you could easily fix this via the self-normalized version $$
\begin{align*}
\nabla_\theta \hat{V}(\theta) &\approx \sum_{t=1}^T \frac{w_t(\theta)}{\sum_s w_s(\theta)} r_t \nabla_\theta \log \pi(y_t|x_t; \theta),
\end{align*}
$$ but, spoiler alert, this is still meh, and so is the self-normalized variant I derived from a martingale lower bound. Basically, the magnitude is under control but the direction of the update sucks.
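Self-normalization can only fix the magnitude. Within a batch, the self-normalized estimate equals the plain one multiplied by the positive scalar $T / \sum_s w_s(\theta)$, so both estimates point in exactly the same direction. The minimal sketch below checks this numerically. It assumes a tabular softmax policy over four completions of a single prompt. The helper names, logits, logging probabilities, and rewards are all made up for illustration.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def grad_log_pi(probs, y):
    # Tabular softmax (assumed): d log pi(y) / d logits = onehot(y) - pi
    return [(1.0 if k == y else 0.0) - p for k, p in enumerate(probs)]


def iw_policy_gradient(logits, data, self_normalize=False):
    """Hypothetical helper. data: list of (y_t, h_t, r_t), h_t = h(y_t|x)."""
    pi = softmax(logits)
    w = [pi[y] / h_t for y, h_t, _ in data]
    # Plain IW divides by T; self-normalized IW divides by sum_s w_s.
    denom = sum(w) if self_normalize else len(data)
    grad = [0.0] * len(logits)
    for (y, _, r), w_t in zip(data, w):
        for k, g_k in enumerate(grad_log_pi(pi, y)):
            grad[k] += (w_t / denom) * r * g_k
    return grad, w


def norm(v):
    return math.sqrt(sum(x * x for x in v))


logits = [0.0, 0.0, 0.0, 0.0]          # current pi: uniform over 4 completions
data = [(0, 0.4, 1.0), (1, 0.1, 0.3)]  # made-up (y_t, h(y_t|x), r_t)

g_iw, w = iw_policy_gradient(logits, data)
g_sn, _ = iw_policy_gradient(logits, data, self_normalize=True)
cos = sum(a * b for a, b in zip(g_iw, g_sn)) / (norm(g_iw) * norm(g_sn))
print(cos)                      # ≈ 1.0: same direction
print(norm(g_sn) / norm(g_iw))  # ≈ T / sum(w) = 2 / 3.125 = 0.64: only the length changes
```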

For simplicity, let’s assume we have a two-point dataset, with both completions rolled out from the same prompt $x$.

| Datapoint | Importance weight | Reward |
|---|---|---|
| $(x, y_0)$ | $w_0$ is smaller | $r_0$ is big |
| $(x, y_1)$ | $w_1$ is bigger | $r_1$ is small |

Pertinacity

In this case the gradient (dropping the constant $1/T$ factor) is $$
\begin{align*}
\nabla_\theta \hat{V}(\theta) &= \underbrace{w_0 r_0}_{\text{small}} \nabla_\theta \log \pi(y_0|x;\theta) + \underbrace{w_1 r_1}_{\text{big}} \nabla_\theta \log \pi(y_1|x;\theta)
\end{align*}
$$ where I’ve chosen values such that the ordering of $w_i r_i$ is opposite the ordering of $r_i$. This is perverse: in a world where rewards are deterministic but typically small, we want to move a lot of probability to a big reward when we find one. Instead we are more aggressively moving probability onto a meh outcome ($r_1$) just because the ratio of our probability to the logging probability is much bigger for the meh outcome ($w_1$). You might say this is because we are viewing things in log space, but an equivalent linear space formulation is $$
\begin{align*}
\nabla_\theta \hat{V}(\theta) &= \frac{1}{h(y_0|x)} r_0 \nabla_\theta \pi(y_0|x;\theta) + \frac{1}{h(y_1|x)} r_1 \nabla_\theta \pi(y_1|x;\theta)
\end{align*}
$$ and this still sucks because now we might transfer probability to the meh zone more aggressively just because the logging policy is less likely to go there.
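The next sketch makes the pertinacity effect numerical. It assumes a tabular softmax policy that is currently uniform over four completions. The logging probabilities and rewards are made up to match the table: $h(y_0|x)=0.4$, $h(y_1|x)=0.1$, $r_0=1$, $r_1=0.3$. The code computes the per-sample coefficients in two forms, the log-space form ($w_t r_t$ on $\nabla_\theta \log \pi$) and the linear-space form ($r_t / h(y_t|x)$ on $\nabla_\theta \pi$). It then checks that the two forms give the same gradient.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


pi = softmax([0.0, 0.0, 0.0, 0.0])    # current pi: uniform over 4 completions (assumed)
data = [(0, 0.4, 1.0), (1, 0.1, 0.3)]  # made-up (y_t, h(y_t|x), r_t): y_0 good, y_1 meh
n = len(pi)

log_coef, lin_coef = [], []  # w_t r_t (on grad log pi) and r_t / h_t (on grad pi)
g_log, g_lin = [0.0] * n, [0.0] * n
for y, h_t, r in data:
    w_t = pi[y] / h_t
    log_coef.append(w_t * r)
    lin_coef.append(r / h_t)
    for k in range(n):
        d_log_pi = (1.0 if k == y else 0.0) - pi[k]  # softmax: d log pi(y) / d logit_k
        g_log[k] += w_t * r * d_log_pi
        g_lin[k] += (r / h_t) * pi[y] * d_log_pi      # grad pi(y) = pi(y) * grad log pi(y)

print([round(c, 3) for c in log_coef])  # [0.625, 0.75]
print([round(c, 3) for c in lin_coef])  # [2.5, 3.0]
print(max(abs(a - b) for a, b in zip(g_log, g_lin)))  # ≈ 0: same gradient
```

Both parameterizations put the larger coefficient on the worse completion, even though $r_0 > r_1$.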

Thus, just because we happen to be doing something with higher probability than the logging policy (large $w$), we might encourage poor outcomes even with evidence of better outcomes available. Self-normalization doesn’t fix this. Clipping can mitigate this to some degree if the clipped range of the importance weights is dominated by the range of the rewards, but we’re still encouraging the meh outcome, which draws probability from the rest of the action space.
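The clipping caveat can be illustrated with a plain upper clip $\min(w_t, c)$. This is a simplification, not any specific algorithm's clipped objective, and the thresholds $c$ are hypothetical. The sketch assumes a tabular softmax policy that is currently uniform over four completions, with made-up logging probabilities and rewards. It also computes the first-order change in $\pi$ for a small step along the resulting gradient $g$ (for softmax logits, $d\pi = \pi \odot (g - \pi^\top g)$), which shows where the probability comes from.

```python
def prob_change(pi, coefs):
    """First-order change in pi per unit step along g = sum_t a_t (onehot(y_t) - pi).

    Assumes a tabular softmax over logits, whose Jacobian gives d pi = pi * (g - pi . g).
    coefs: list of (y_t, a_t).
    """
    n = len(pi)
    g = [0.0] * n
    for y, a in coefs:
        for k in range(n):
            g[k] += a * ((1.0 if k == y else 0.0) - pi[k])
    pi_dot_g = sum(p * g_k for p, g_k in zip(pi, g))
    return [p * (g_k - pi_dot_g) for p, g_k in zip(pi, g)]


def clipped_coefs(pi, data, c):
    # Coefficient min(w_t, c) * r_t on grad log pi(y_t); c is a hypothetical upper clip.
    return [(y, min(pi[y] / h_t, c) * r) for y, h_t, r in data]


pi = [0.25, 0.25, 0.25, 0.25]          # uniform current policy (assumed)
data = [(0, 0.4, 1.0), (1, 0.1, 0.3)]  # made-up (y_t, h(y_t|x), r_t); w = (0.625, 2.5)

for c in (3.0, 1.2):
    coefs = clipped_coefs(pi, data, c)
    others = sum(prob_change(pi, coefs)[2:])  # indices 2, 3: completions not in the batch
    print(c, [round(a, 3) for _, a in coefs], round(others, 4))
# c=3.0: coefficients ≈ (0.625, 0.75), still flipped; others ≈ -0.1719
# c=1.2: coefficients ≈ (0.625, 0.36), ordering restored; others ≈ -0.1231
```

With $c = 3$ the clip never binds and the ordering stays flipped. With $c = 1.2$ only $w_1$ is clipped, and $y_0$ gets the larger coefficient again. In both cases the coefficients are positive, so the completions outside the batch lose probability.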

The Puke

One thing that is popular in practical RL is using advantages instead of raw rewards, but this doesn’t fix it. Using the average reward of the two samples as the baseline, our update becomes $$
\begin{align*}
\nabla_\theta \hat{V}(\theta) &= \underbrace{w_0 \frac{r_0 - r_1}{2}}_{\text{small positive}} \nabla_\theta \log \pi(y_0|x;\theta) + \underbrace{w_1 \frac{r_1 - r_0}{2}}_{\text{large negative}} \nabla_\theta \log \pi(y_1|x;\theta)
\end{align*}
$$ which is really bad, because we are ejecting a lot of probability from the meh outcome, but we’re only absorbing a little probability into the good outcome. Where does the rest of the probability go? It gets distributed arbitrarily among all other actions. If most actions are meh, this is wasteful. What we want to do is move probability from the meh zone to the good zone, but instead we are moving probability from the meh zone to probably another meh zone.[1]

“The Puke” mostly distributes meh to meh, rather than transporting meh to good.
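The puke can be measured directly. The sketch below assumes a tabular softmax policy, uniform over four completions, with made-up logging probabilities and rewards, and analyzes a small step to first order. It uses the two-sample average reward as the baseline. It then splits the resulting probability change into three parts: mass ejected from $y_1$, mass absorbed by $y_0$, and mass leaked onto the two completions that never appeared in the batch.

```python
def prob_change(pi, coefs):
    """First-order change in pi per unit step along g = sum_t a_t (onehot(y_t) - pi).

    Assumes a tabular softmax over logits: d pi = pi * (g - pi . g). coefs: list of (y_t, a_t).
    """
    n = len(pi)
    g = [0.0] * n
    for y, a in coefs:
        for k in range(n):
            g[k] += a * ((1.0 if k == y else 0.0) - pi[k])
    pi_dot_g = sum(p * g_k for p, g_k in zip(pi, g))
    return [p * (g_k - pi_dot_g) for p, g_k in zip(pi, g)]


pi = [0.25, 0.25, 0.25, 0.25]          # uniform current policy over 4 completions (assumed)
data = [(0, 0.4, 1.0), (1, 0.1, 0.3)]  # made-up (y_t, h(y_t|x), r_t)

baseline = sum(r for _, _, r in data) / len(data)                    # average reward of the batch
coefs = [(y, (pi[y] / h_t) * (r - baseline)) for y, h_t, r in data]  # w_t * advantage_t
dpi = prob_change(pi, coefs)

ejected = -dpi[1]      # mass leaving the meh outcome y_1
absorbed = dpi[0]      # mass arriving at the good outcome y_0
leaked = sum(dpi[2:])  # mass sprayed onto completions we never observed
print(round(ejected, 4), round(absorbed, 4), round(leaked, 4))  # ≈ 0.1777 0.0957 0.082
```

In this example, nearly half of what leaves $y_1$ ends up on completions we have no evidence about.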

Once again, self-normalization doesn’t fix this. Clipping can mitigate this by limiting the amount of puke from datapoint $(x, y_1)$. If we clip just the right amount, then all the probability puked by $(x, y_1)$ gets absorbed by $(x, y_0)$. However, clipping is usually static, so there’s no reason to believe we have chosen this magic value. Maybe we should solve for a good clipping value in each batch? That could work. But there are also update rules that automatically guarantee balance.
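Below is one way to "solve for a good clipping value in each batch". It is a hypothetical rule for illustration, not an established method. The sketch assumes a tabular softmax policy, made-up data, and a first-order analysis. Under these assumptions, the mass leaked onto completions outside the batch is linear in each sample's coefficient. That means we can solve for the weight $c$ on $y_1$ that makes the total leak zero, so everything puked by $(x, y_1)$ is absorbed by $(x, y_0)$.

```python
def prob_change(pi, coefs):
    """First-order change in pi per unit step along g = sum_t a_t (onehot(y_t) - pi).

    Assumes a tabular softmax over logits: d pi = pi * (g - pi . g). coefs: list of (y_t, a_t).
    """
    n = len(pi)
    g = [0.0] * n
    for y, a in coefs:
        for k in range(n):
            g[k] += a * ((1.0 if k == y else 0.0) - pi[k])
    pi_dot_g = sum(p * g_k for p, g_k in zip(pi, g))
    return [p * (g_k - pi_dot_g) for p, g_k in zip(pi, g)]


def balancing_weight(pi, data, baseline):
    """Hypothetical per-batch rule: the weight c on the second sample such that, to first
    order, all the probability it ejects is absorbed by the first sample."""
    (y0, h0, r0), (y1, _, r1) = data
    a0 = (pi[y0] / h0) * (r0 - baseline)

    def leak(y, a):
        # Mass moved onto completions outside the batch by coefficient a on y.
        return sum(d for k, d in enumerate(prob_change(pi, [(y, a)])) if k not in (y0, y1))

    # leak is linear in the coefficient, so solve leak(y0, a0) + c * leak(y1, adv1) = 0.
    per_unit = leak(y1, r1 - baseline)
    if per_unit == 0.0:
        raise ValueError("no leak to cancel (zero advantage or no other completions)")
    return -leak(y0, a0) / per_unit


pi = [0.25, 0.25, 0.25, 0.25]          # uniform current policy (assumed)
data = [(0, 0.4, 1.0), (1, 0.1, 0.3)]  # made-up (y_t, h(y_t|x), r_t); unclipped w_1 = 2.5
baseline = sum(r for _, _, r in data) / len(data)
(y0, h0, r0), (y1, h1, r1) = data

c = balancing_weight(pi, data, baseline)
coefs = [(y0, (pi[y0] / h0) * (r0 - baseline)), (y1, c * (r1 - baseline))]
leak_total = sum(prob_change(pi, coefs)[2:])  # indices 2, 3: completions not in the batch
print(round(c, 6), c <= pi[y1] / h1)  # 0.625 True (c is a genuine clip of w_1)
print(leak_total)                     # ≈ 0
```

In this uniform-$\pi$ example the balancing weight comes out to exactly $w_0 = 0.625$, which means equal and opposite coefficients. When $\pi(y_0|x) \neq \pi(y_1|x)$, equal and opposite coefficients leak mass onto the other completions, so the solved weight differs. If the solved weight ever exceeded $w_1$, clipping alone could not reach it.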

  1. In the SAPO paper, the authors write “negative updates tend to increase the logits of many inappropriate tokens and are therefore more prone to introduce instability than positive updates”. The puke sounds cooler.
