Steepish Descent

Unleash the chains of thought.

Softmax Muon

Vibe-research is really productive: this idea would have taken me weeks to formalize myself, but instead took hours. Since most ideas are bad, minimizing time-to-empirical-verification is a huge win.

Something has bothered me about the GPT speedrun for a while: most of the matrix parameters are optimized via Muon, but the (un)embedding matrices are still optimized with Adam. I’ve actually tried several ideas to mitigate this, and all of them have failed so far. But if at first you don’t succeed, try, try again! So here’s yet another idea.

Update: the diagonal Hessian in gauge coordinates indicates Adam is a good fit; see the next post.

Muon background

Muon is motivated by noting that if a linear layer $y = W x$ with unit L2-norm input $x$ is perturbed by $\Delta W$, the L2-norm of the output change is bounded by the spectral norm, $\|\Delta W x\|_2 \le \sigma_{\max}(\Delta W)$. This leads to a steepest descent direction $\Delta W$ obtained by orthogonalizing the raw gradient $G$, i.e., by maximizing alignment with $G$ over the unit spectral-norm ball,
$$
\max_{\Delta W} \text{tr}\left(\Delta W^\top G \right) \text{ s.t. } \Delta W^\top \Delta W \preceq I,
$$
which has solution $\Delta W = \left(G\right)_{\text{polar}}$. In other words, if $G = U S V^\top$ then $\Delta W = U V^\top$.
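To make the polar-factor solution concrete, here is a minimal NumPy sketch (NumPy and the function names are my assumptions, not code from the speedrun). It computes $(G)_{\text{polar}} = U V^\top$ from the SVD and checks that it attains the maximum of $\text{tr}(\Delta W^\top G)$, which equals the nuclear norm of $G$. Muon itself uses a tuned quintic Newton-Schulz iteration rather than an SVD; for simplicity this sketch swaps in the classical cubic Newton-Schulz iteration, which converges to the same polar factor when $G$ has full column rank.

```python
import numpy as np

def polar_svd(G):
    """Orthogonal polar factor U V^T of G, from the thin SVD G = U S V^T."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def polar_newton_schulz(G, steps=60):
    """Classical cubic Newton-Schulz iteration for the polar factor.

    Dividing by the Frobenius norm puts every singular value in (0, 1],
    where X <- 1.5 X - 0.5 X X^T X drives each of them toward 1.
    """
    X = G / np.linalg.norm(G)
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ (X.T @ X)
    return X

rng = np.random.default_rng(0)
G = rng.standard_normal((8, 4))  # toy "gradient": 8 outputs, 4 inputs

dW = polar_svd(G)
# The maximum of tr(dW^T G) over the spectral-norm ball is the nuclear norm of G.
nuclear = np.linalg.svd(G, compute_uv=False).sum()

print(np.isclose(np.trace(dW.T @ G), nuclear))             # True
print(np.allclose(dW.T @ dW, np.eye(4)))                   # True
print(np.allclose(polar_newton_schulz(G), dW, atol=1e-8))  # True
```

The cubic iteration is chosen only because it is the simplest Newton-Schulz variant to state; it needs more steps than Muon's tuned quintic, but it converges to the exact polar factor rather than an approximation.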

However, for a softmax layer the L2-norm of the logits is not the proper notion of size (e.g., adding a constant to all the pre-logits does not change the output at all). Therefore the spectral-norm constraint $\Delta W^\top \Delta W \preceq I$ is “wrong” in the sense that it does not control the magnitude of the change in the output.

Softmax geometry

Park et al. analyze the softmax distribution
$$
P_\lambda(y) = \exp\left(\lambda^\top y - A(\lambda)\right)
$$
where $y$ is a one-hot vector, $\lambda = W x$ is the pre-logits vector given an (un)embedding matrix $W \in \mathbb{R}^{n \times d}$ with a vocabulary of size $n$ and a hidden state $x$ of dimension $d$, and $A(\lambda)$ is the log-normalizer
$$
A(\lambda) = \log \sum_y \exp\left(\lambda^\top y\right).
$$
They note that the KL-divergence between the distributions induced by two different pre-logits is
$$
D(P_\lambda \| P_{\lambda'}) = A(\lambda') - A(\lambda) - \nabla A(\lambda)^\top \left(\lambda' - \lambda\right),
$$
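where $\nabla A(\lambda) = P_\lambda$ is the gradient with respect to $\lambda$. This is a Bregman divergence induced by the convex function $A$, which leads to the local quadratic approximation $D \approx \tfrac{1}{2} (\lambda' - \lambda)^\top \nabla^2 A(\lambda) (\lambda' - \lambda)$ to the KL-divergence, with Hessian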
$$
\nabla^2 A(\lambda) = \text{Cov}_{y \sim P_{\lambda}}\left[y\right] = \text{Diag}(P_{\lambda}) - P_{\lambda} P_{\lambda}^\top.
$$
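A quick numerical check of these identities, as a minimal NumPy sketch with a small random pre-logit vector (the sizes, seed, and helper names are arbitrary choices of mine). It shows that the KL divergence computed directly matches the Bregman form, that $\text{Diag}(P_\lambda) - P_\lambda P_\lambda^\top$ equals the covariance of the one-hot $y$, and that for a small perturbation the quadratic model agrees with the exact KL up to a relative error that shrinks with $\|\lambda' - \lambda\|$.

```python
import numpy as np

def log_normalizer(lam):
    """A(lambda) = log sum_y exp(lambda^T y) over one-hot y, computed stably."""
    m = lam.max()
    return m + np.log(np.exp(lam - m).sum())

def softmax(lam):
    """P_lambda = grad A(lambda)."""
    e = np.exp(lam - lam.max())
    return e / e.sum()

rng = np.random.default_rng(0)
n = 6                                    # toy vocabulary size
lam = rng.standard_normal(n)
lam2 = lam + 1e-3 * rng.standard_normal(n)
p, q = softmax(lam), softmax(lam2)

# KL(P_lam || P_lam2) directly vs. as the Bregman divergence of A.
kl_direct = np.sum(p * (np.log(p) - np.log(q)))
kl_bregman = log_normalizer(lam2) - log_normalizer(lam) - p @ (lam2 - lam)

# Hessian of A vs. the covariance E[(y - p)(y - p)^T] of the one-hot outcome.
H = np.diag(p) - np.outer(p, p)
Y = np.eye(n)                            # row i is the one-hot vector e_i
cov = (Y - p).T @ np.diag(p) @ (Y - p)

# Local quadratic model of the KL divergence.
delta = lam2 - lam
kl_quadratic = 0.5 * delta @ H @ delta

print(np.isclose(kl_direct, kl_bregman, rtol=1e-6, atol=0.0))  # True
print(np.allclose(H, cov))                                     # True
print(abs(kl_direct - kl_quadratic) / kl_direct)  # relative error of the quadratic model
```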
This local Hessian $H_{\lambda,\lambda} = \nabla^2 A(\lambda)$ suggests a curvature-aware variant of Muon,
$$
\max_{\Delta W} \text{tr}\left(\Delta W^\top G \right) \text{ s.t. } \Delta W^\top H_{\lambda,\lambda} \Delta W \preceq I,
$$
which has solution
$$
\Delta W = H_{\lambda,\lambda}^{-1/2} \left(H_{\lambda,\lambda}^{-1/2} G\right)_{\text{polar}}
$$
i.e., a three-step process:

  1. Whiten gradient with $H_{\lambda,\lambda}^{-1/2}$.
  2. Find orthogonal polar factor.
  3. Recolor with $H_{\lambda,\lambda}^{-1/2}$.
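The three steps can be sketched directly for a toy vocabulary, assuming NumPy and hypothetical helper names. Because $H_{\lambda,\lambda}$ is singular (the second issue discussed below), this sketch swaps $H^{-1/2}$ for the pseudo-inverse square root, inverting only the nonzero eigenvalues. It then checks that the constraint is tight, i.e., $\Delta W^\top H_{\lambda,\lambda} \Delta W = I$.

```python
import numpy as np

def softmax(lam):
    e = np.exp(lam - lam.max())
    return e / e.sum()

def polar(X):
    """Orthogonal polar factor U V^T from the thin SVD X = U S V^T."""
    U, _, Vt = np.linalg.svd(X, full_matrices=False)
    return U @ Vt

def pinv_sqrt(H, tol=1e-10):
    """Pseudo-inverse square root of a symmetric PSD matrix.

    Eigenvalues above tol are mapped to w^{-1/2}; the rest (here the
    all-ones null direction of H) are mapped to 0.
    """
    w, V = np.linalg.eigh(H)
    scale = np.where(w > tol, 1.0 / np.sqrt(np.maximum(w, tol)), 0.0)
    return (V * scale) @ V.T

rng = np.random.default_rng(0)
n, d = 6, 3                                   # toy vocabulary and hidden sizes
p = softmax(rng.standard_normal(n))
H = np.diag(p) - np.outer(p, p)               # H_{lambda,lambda}
G = rng.standard_normal((n, d))

S = pinv_sqrt(H)
Z = polar(S @ G)                              # steps 1-2: whiten, orthogonalize
dW = S @ Z                                    # step 3: recolor with H^{-1/2}

print(np.allclose(dW.T @ H @ dW, np.eye(d)))  # True: the constraint is tight
```

This needs a full eigendecomposition of the $n \times n$ matrix $H$, which is exactly what becomes infeasible at realistic vocabulary sizes.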

Computational Considerations

In practice, the naive three-step approach has (at least) two issues.

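First issue: $H$ is $n \times n$, where $n$ is the vocabulary size, so forming $H^{-1/2}$ and applying it as above is not computationally viable. Fortunately, using $(X)_{\text{polar}} = X (X^\top X)^{-1/2}$ for full-column-rank $X$, we can rewrite the optimal direction in terms of the pseudoinverse $H^\dagger$ as¹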
$$
\Delta W = (H^\dagger G)(G^\top H^\dagger G)^{-1/2}.
$$
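A small NumPy check (names, sizes, and seed are illustrative assumptions) that this rewrite gives the same direction as the naive three-step formula. Both sides use pseudo-inverses computed from an eigendecomposition, since $H$ is singular; the matrix inverse square root on the right-hand side acts only on the $d \times d$ matrix $G^\top H^\dagger G$. In this check $H^\dagger$ is still formed densely, so it demonstrates the algebra rather than the savings.

```python
import numpy as np

def sym_power(A, power, tol=1e-10):
    """A**power for symmetric PSD A via eigh; eigenvalues <= tol are sent to 0."""
    w, V = np.linalg.eigh(A)
    wp = np.zeros_like(w)
    keep = w > tol
    wp[keep] = w[keep] ** power
    return (V * wp) @ V.T

rng = np.random.default_rng(1)
n, d = 7, 3
lam = rng.standard_normal(n)
p = np.exp(lam - lam.max())
p /= p.sum()
H = np.diag(p) - np.outer(p, p)
G = rng.standard_normal((n, d))

# Naive: H^{-1/2} (H^{-1/2} G)_polar, with pseudo-inverse square roots.
S = sym_power(H, -0.5)
U, _, Vt = np.linalg.svd(S @ G, full_matrices=False)
dW_naive = S @ (U @ Vt)

# Rewritten: (H^+ G)(G^T H^+ G)^{-1/2}; the square root acts on a d x d matrix.
B = sym_power(H, -1.0) @ G
dW_fast = B @ sym_power(G.T @ B, -0.5)

print(np.allclose(dW_naive, dW_fast))  # True
```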

Second issue: $H$ is guaranteed to have a null space, because the distribution induced by pre-logits is invariant to constant shifts ($H \mathbf{1} = 0$). We can fix this by using gauge coordinates defined by $Q \in \mathbb{R}^{n \times (n-1)}$ where $Q^\top Q = I$ and $Q^\top \mathbf{1} = 0$. The corresponding projection operator in the ambient space is
$$
M = Q Q^\top = I - n^{-1} \mathbf{1} \mathbf{1}^\top,
$$
which subtracts the mean. This results in a convenient representation for $H^\dagger$ in the ambient space (proof below):
$$
H^\dagger = M\, \text{Diag}\left(P_{\lambda}^{-1}\right) M.
$$
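To illustrate the gauge coordinates, here is a NumPy sketch that builds one valid $Q$ (an orthonormal basis of the vectors orthogonal to $\mathbf{1}$, taken here from the eigenvectors of $M$ with eigenvalue 1; any other orthonormal basis of that subspace works equally well). It checks $Q Q^\top = M$, then compares the inverse of the Hessian in gauge coordinates, mapped back as $Q (Q^\top H Q)^{-1} Q^\top$, against the closed form $M\,\text{Diag}(P_\lambda^{-1})\,M$. The choice of $Q$, the sizes, and the seed are my assumptions.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 5
lam = rng.standard_normal(n)
p = np.exp(lam - lam.max())
p /= p.sum()
H = np.diag(p) - np.outer(p, p)
ones = np.ones(n)

# Ambient-space projector that subtracts the mean.
M = np.eye(n) - np.outer(ones, ones) / n

# One valid gauge basis: the eigenvectors of M with eigenvalue 1.
w, V = np.linalg.eigh(M)
Q = V[:, w > 0.5]                                    # shape (n, n - 1)

H_gauge = Q.T @ H @ Q                                # invertible in gauge coordinates
H_dagger_gauge = Q @ np.linalg.inv(H_gauge) @ Q.T
H_dagger_closed = M @ np.diag(1.0 / p) @ M

print(Q.shape)                                       # (5, 4)
print(np.allclose(Q @ Q.T, M))                       # True
print(np.allclose(H_dagger_gauge, H_dagger_closed))  # True
```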
This closed form for $H^\dagger$ suggests the following five-step process:

  1. Mean-center the gradient: $\tilde{G} = M\, G$.
  2. Compute the product $B = H^\dagger G \in \mathbb{R}^{n \times d}$ via $B = M\, \text{Diag}(P_{\lambda}^{-1})\, \tilde{G}$.
    • In practice, dividing by $P_{\lambda}$ is dangerous because rare tokens have tiny probabilities, so damp it with a stochastic mixture with the uniform distribution, e.g., $(1 - \alpha) P_{\lambda} + \alpha n^{-1} \mathbf{1}$ for some $\alpha \in (0, 1)$.
  3. Compute the “small” matrix $K = \tilde{G}^\top B \in \mathbb{R}^{d \times d}$.
  4. Compute $K^{-1/2}$ (in practice, symmetrize and optionally damp $K$ first).
    • It turns out this can be done with a different Newton-Schulz iteration.
  5. Compute $\Delta W = B K^{-1/2}$.
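Here is a minimal NumPy sketch of the five steps that never forms an $n \times n$ matrix inside the update itself. Several details are my assumptions rather than the post’s: the damping is the mixture $(1-\alpha) P_\lambda + \alpha n^{-1} \mathbf{1}$ with an arbitrary $\alpha = 0.1$, and $K^{-1/2}$ is computed with the coupled Newton-Schulz iteration for the matrix inverse square root (one standard choice; I don’t know which variant is meant above). Because the mixed probabilities still sum to one, $M\,\text{Diag}(p_{\text{mix}}^{-1})\,M$ is the exact pseudoinverse of the damped Hessian, so with no damping of $K$ the result should satisfy $\Delta W^\top H \Delta W = I$ for that Hessian.

```python
import numpy as np

def inv_sqrt_newton_schulz(K, steps=50):
    """K^{-1/2} for symmetric positive-definite K (coupled Newton-Schulz iteration).

    Dividing by c = ||K||_F puts the eigenvalues of A = K / c in (0, 1], where
    Y_k -> A^{1/2} and Z_k -> A^{-1/2}; hence K^{-1/2} = Z_k / sqrt(c).
    """
    eye = np.eye(K.shape[0])
    c = np.linalg.norm(K)
    Y, Z = K / c, eye
    for _ in range(steps):
        T = 0.5 * (3.0 * eye - Z @ Y)
        Y, Z = Y @ T, T @ Z
    return Z / np.sqrt(c)

def damped_probs(p, alpha=0.1):
    """Mix P_lambda with the uniform distribution (alpha is an assumed value)."""
    return (1.0 - alpha) * p + alpha / p.shape[0]

def softmax_muon_direction(G, p, k_damping=0.0):
    """Delta W = B K^{-1/2} using only n x d and d x d matrices."""
    G_t = G - G.mean(axis=0, keepdims=True)                # 1. G~ = M G
    B = G_t / p[:, None]                                   # 2. Diag(P^{-1}) G~ ...
    B = B - B.mean(axis=0, keepdims=True)                  #    ... then apply M
    K = G_t.T @ B                                          # 3. K = G~^T B, d x d
    K = 0.5 * (K + K.T) + k_damping * np.eye(K.shape[0])   # 4. symmetrize, damp
    return B @ inv_sqrt_newton_schulz(K)                   # 5. Delta W = B K^{-1/2}

rng = np.random.default_rng(0)
n, d = 50, 4
lam = 3.0 * rng.standard_normal(n)
p = np.exp(lam - lam.max())
p /= p.sum()
G = rng.standard_normal((n, d))

p_mix = damped_probs(p)
dW = softmax_muon_direction(G, p_mix)

# Dense n x n Hessian of the damped distribution, used only for this check.
H_mix = np.diag(p_mix) - np.outer(p_mix, p_mix)
print(np.allclose(dW.T @ H_mix @ dW, np.eye(d)))  # True
```

Setting `k_damping > 0` trades exact tightness of the constraint for robustness when $K$ is nearly singular.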

Does it work?

I’m not sure; I have to try it out on the speedrun. On the plus side, vibe-coding is also really productive, so it shouldn’t take too long to test this idea.

Verifying the Pseudoinverse

The bots gave me the formula, so I forced them to give me the proof.²

Our goal is to solve $H y = M x$ with $\mathbf{1}^\top y = 0$, where $H = D - P_{\lambda} P_{\lambda}^\top$ and $D = \text{Diag}(P_{\lambda})$. Let’s try the solution $y = M D^{-1} M x$. Expanding both terms of $H y$, and using $D \mathbf{1} = P_{\lambda}$, $P_{\lambda}^\top D^{-1} = \mathbf{1}^\top$, $P_{\lambda}^\top \mathbf{1} = 1$, and $\mathbf{1}^\top M = 0$,
$$
\begin{align*} H y &= (D - P_{\lambda} P_{\lambda}^\top) M D^{-1} M x, \\
D M D^{-1} M x &= \left(I - n^{-1} D \mathbf{1} \mathbf{1}^\top D^{-1} \right) M x \\
&= M x - n^{-1} P_{\lambda} \mathbf{1}^\top D^{-1} M x, \\
P_{\lambda} P_{\lambda}^\top M D^{-1} M x &= P_{\lambda} P_{\lambda}^\top D^{-1} M x - n^{-1} P_{\lambda} P_{\lambda}^\top \mathbf{1} \mathbf{1}^\top D^{-1} M x \\
&= P_{\lambda} \mathbf{1}^\top M x - n^{-1} P_{\lambda} \mathbf{1}^\top D^{-1} M x \\
&= - n^{-1} P_{\lambda} \mathbf{1}^\top D^{-1} M x.
\end{align*}
$$
Subtracting the second expansion from the first, the $n^{-1} P_{\lambda} \mathbf{1}^\top D^{-1} M x$ terms cancel, which indicates $H y = M x$. Next note $\mathbf{1}^\top y = 0$ because $\mathbf{1}^\top M = 0$. Thus $y$ is in gauge: it is orthogonal to the null space of $H$ and maps to the projection of $x$ onto the range of $H$, so $H^\dagger = M\,\text{Diag}(P_{\lambda}^{-1})\,M$ in the ambient coordinates.
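The algebra above can also be spot-checked numerically. This NumPy sketch (sizes and seed are arbitrary choices of mine) evaluates each intermediate expression from the derivation for a random $x$, and then confirms that $Y = M D^{-1} M$ satisfies the four Moore-Penrose conditions ($H Y H = H$, $Y H Y = Y$, and symmetry of $H Y$ and $Y H$), which characterize $H^\dagger$ uniquely.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 6
lam = rng.standard_normal(n)
p = np.exp(lam - lam.max())
p /= p.sum()

ones = np.ones(n)
D, D_inv = np.diag(p), np.diag(1.0 / p)
M = np.eye(n) - np.outer(ones, ones) / n
H = D - np.outer(p, p)
Y = M @ D_inv @ M                             # claimed pseudoinverse of H

x = rng.standard_normal(n)
y = Y @ x
c = ones @ D_inv @ M @ x / n                  # the scalar n^{-1} 1^T D^{-1} M x
first = D @ M @ D_inv @ M @ x                 # D M D^{-1} M x
second = np.outer(p, p) @ M @ D_inv @ M @ x   # P P^T M D^{-1} M x

print(np.allclose(first, M @ x - c * p))      # True
print(np.allclose(second, -c * p))            # True
print(np.allclose(H @ y, M @ x))              # True
print(np.isclose(ones @ y, 0.0))              # True

# Moore-Penrose conditions; together they pin down H^dagger uniquely.
print(np.allclose(H @ Y @ H, H), np.allclose(Y @ H @ Y, Y),
      np.allclose(H @ Y, (H @ Y).T), np.allclose(Y @ H, (Y @ H).T))  # True True True True
```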

  1. I might never have figured out this part on my own. The bots suggested this. Vibe research ftw.
  2. Clearly the bots are way better at math than me. Not a new thing actually: historically I would use Reduce in Mathematica to prove stuff for me; but the bots are way more convenient.
