In the previous four articles, we solved several specific steepest descent problems with equality constraints on parameters. For the problems in the third and fourth articles, no closed-form solutions could be found, so the author proposed corresponding fixed-point iteration methods. In particular, the "Muon on Stiefel manifold" problem studied in the third article "Steepest Descent on Manifolds: 3. Muon + Stiefel" originated from Jeremy Bernstein's article "Orthogonal manifold".
For this problem, Jeremy Bernstein eventually presented his own solution, which the author terms "Dual Gradient Descent". This approach is quite valuable to study and will be the focus of this article.
Basic Concepts#
Jeremy Bernstein's solution was finally published in the Thinking Machines Lab blog "Modular Manifolds", which is the second blog post from that laboratory. In the article, it is referred to as "Dual Ascent", but following the context of the previous four articles, the author here will call it "Dual Gradient Descent".
In fact, dual gradient descent can be considered a natural consequence of the method of Lagrange multipliers. However, rigorous discussion of the method of Lagrange multipliers is quite involved, requiring concepts like the Minimax theorem. Therefore, in this series, to avoid these complexities, we adopted the "method of undetermined coefficients" as our derivation approach. This makes dual gradient descent appear less natural. Nevertheless, we can still derive it following our line of reasoning, albeit with some additional exposition.
First, let us review the mathematical notation. Let $\boldsymbol{W}\in\mathbb{R}^{n\times m}$ be a matrix parameter. Without loss of generality, assume $n\geq m$. Let $\boldsymbol{G}\in\mathbb{R}^{n\times m}$ be its gradient. $\Vert\boldsymbol{G}\Vert_2$ denotes the spectral norm of matrix $\boldsymbol{G}$, equal to its largest singular value, and $\Vert\boldsymbol{G}\Vert_*$ denotes the nuclear norm, equal to the sum of all singular values. In particular, according to the results from the article "Derivatives of SVD", we have:

$$\nabla_{\boldsymbol{G}}\Vert\boldsymbol{G}\Vert_* = \boldsymbol{U}\boldsymbol{V}^{\top} = \msign(\boldsymbol{G})\tag{1}$$
where $\boldsymbol{G}=\sum_i \sigma_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top} = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$ is the SVD of $\boldsymbol{G}$. That is, the gradient of the nuclear norm is exactly the $\msign$ operator, which forms an important foundation for the subsequent derivation.
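As a numerical sanity check of the nuclear norm gradient (1), the sketch below (a hypothetical NumPy helper `msign` built from the thin SVD, not the author's code) compares a central finite difference of $\Vert\boldsymbol{G}\Vert_*$ along a random direction $\boldsymbol{E}$ against $\tr(\boldsymbol{E}^{\top}\msign(\boldsymbol{G}))$:

```python
import numpy as np

def msign(G):
    # Hypothetical helper: msign(G) = U V^T from the thin SVD G = U diag(s) V^T.
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(0)
G = rng.standard_normal((5, 3))
E = rng.standard_normal((5, 3))   # random perturbation direction

# Central finite difference of the nuclear norm along E.
eps = 1e-6
nuc = lambda M: np.linalg.norm(M, ord="nuc")
fd = (nuc(G + eps * E) - nuc(G - eps * E)) / (2 * eps)
analytic = float(np.trace(E.T @ msign(G)))
print(fd, analytic)  # the two values should agree closely
```

For a generic full-rank $\boldsymbol{G}$ (distinct nonzero singular values) the two numbers agree to many digits; at rank-deficient points the nuclear norm is not differentiable and the check would fail.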
Problem Statement#
We will introduce dual gradient descent following our previous derivation approach, so this section will first restate the problems and existing results.
In "Steepest Descent on Manifolds: 3. Muon + Stiefel", the problem we need to solve is:

$$\boldsymbol{\Phi}^* = \mathop{\text{argmax}}_{\boldsymbol{\Phi}\in\mathbb{R}^{n\times m}} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi})\quad\text{s.t.}\quad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\quad \boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0}\tag{2}$$
The solution is $\boldsymbol{\Phi} = \msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})$, where $\boldsymbol{X}\in\mathbb{R}^{m\times m}$ is an undetermined symmetric matrix such that $\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0}$.
In "Steepest Descent on Manifolds: 4. Muon + Spectral Sphere", the problem we need to solve is:

$$\boldsymbol{\Phi}^* = \mathop{\text{argmax}}_{\boldsymbol{\Phi}\in\mathbb{R}^{n\times m}} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi})\quad\text{s.t.}\quad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\quad \tr(\boldsymbol{\Theta}^{\top}\boldsymbol{\Phi}) = 0\tag{3}$$
The answer is $\boldsymbol{\Phi} = \msign(\boldsymbol{G} + \lambda\boldsymbol{\Theta})$, where $\lambda$ is an undetermined coefficient such that $\tr(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0$.
It can be seen that our ultimate task in both cases becomes finding undetermined coefficients that satisfy the additionally introduced equality constraints. This essentially reduces to solving nonlinear equation(s). Dual gradient descent transforms the equation solving into minimizing a certain objective function, thereby allowing solution via gradient descent.
Dual Objective#
The key to the transformation is the nuclear norm gradient equality (1). For simplicity, let us first examine the "Muon + Spectral Sphere" problem (3), where the undetermined coefficient is just a scalar, making it easier to observe. It is not difficult to verify:

$$\frac{\partial}{\partial\lambda}\Vert\boldsymbol{G} + \lambda\boldsymbol{\Theta}\Vert_* = \tr\big(\boldsymbol{\Theta}^{\top}\msign(\boldsymbol{G}+\lambda\boldsymbol{\Theta})\big) = \tr(\boldsymbol{\Theta}^{\top}\boldsymbol{\Phi})\tag{4}$$
This means that solving the equation $\tr(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0$ is equivalent to finding points where the derivative of $\Vert\boldsymbol{G} + \lambda\boldsymbol{\Theta}\Vert_*$ with respect to $\lambda$ vanishes, i.e., its (local) minimum/maximum points. Since $\Vert\boldsymbol{G} + \lambda\boldsymbol{\Theta}\Vert_*$ is convex in $\lambda$ and grows without bound as $|\lambda|\to\infty$, it has no maximum, so we transform the problem into finding its minimum point:

$$\lambda^* = \mathop{\text{argmin}}_{\lambda}\ \Vert\boldsymbol{G} + \lambda\boldsymbol{\Theta}\Vert_*\tag{5}$$
Let us summarize the steps here:
- Our goal is to solve the equation $\tr(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0$, finding any solution suffices.
- $\tr(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})$ happens to be the gradient of $\Vert\boldsymbol{G} + \lambda\boldsymbol{\Theta}\Vert_*$ with respect to $\lambda$.
- This transforms the problem into finding (local) minimum/maximum points, because the gradient vanishes at such points.
- It is easy to see that no maximum exists, so we look for a minimum instead.
Gradient Descent#
After determining the objective (5), we can solve it using gradient descent. The gradient is readily available as $\tr(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})$, so the gradient descent update rule is:

$$\lambda \leftarrow \lambda - \eta\,\tr(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}),\qquad \boldsymbol{\Phi} = \msign(\boldsymbol{G} + \lambda\boldsymbol{\Theta})\tag{6}$$
Of course, we could also consider applying $\newcommand{\sign}{\mathop{\text{sign}}}\sign$ to $\tr(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})$, i.e., SignSGD; such variations can be freely explored. In terms of iteration format, dual gradient descent looks much simpler than the fixed-point iteration we proposed earlier. However, in many cases it requires far more iteration steps, and may need careful learning-rate tuning, momentum mechanisms, etc., to converge at all.
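The scalar iteration can be sketched in NumPy. Everything below is illustrative: the step size, iteration counts, and the step-halving safeguard against overshooting are assumptions of this sketch, not part of the original method:

```python
import numpy as np

def msign(G):
    # Hypothetical helper: msign(G) = U V^T from the thin SVD.
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def solve_lambda(G, Theta, eta=0.1, steps=1000, tol=1e-8):
    """Minimize ||G + lam*Theta||_* over the scalar lam by dual gradient descent."""
    nuc = lambda lam: np.linalg.norm(G + lam * Theta, ord="nuc")
    lam = 0.0
    for _ in range(steps):
        Phi = msign(G + lam * Theta)
        g = float(np.trace(Theta.T @ Phi))  # dual gradient d/dlam ||G + lam*Theta||_*
        if abs(g) < tol:
            break
        # step-halving safeguard: shrink eta until the dual objective decreases
        while eta > 1e-12 and nuc(lam - eta * g) > nuc(lam):
            eta *= 0.5
        lam -= eta * g
    return lam, Phi

rng = np.random.default_rng(1)
G = rng.standard_normal((6, 4))
Theta = rng.standard_normal((6, 4))
lam, Phi = solve_lambda(G, Theta)
residual = abs(float(np.trace(Theta.T @ Phi)))
print(lam, residual)  # residual should be driven toward zero
```

Since the dual objective is convex in $\lambda$, the safeguarded descent decreases it monotonically; without the safeguard, a fixed step size can oscillate, which is exactly the tuning issue noted above.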
Therefore, for the purpose of solving the equation $\tr(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0$ alone, dual gradient descent is not a particularly ideal solution. However, our ultimate goal is not merely to solve this equation, but to compute $\boldsymbol{\Phi}$ as the optimization direction for the model. Model optimization is itself an iterative process, so we can cache the historical $\lambda$ value and adopt an approximate strategy that updates $\lambda$ synchronously with the model parameters:

$$\boldsymbol{\Phi} = \msign(\boldsymbol{G} + \lambda\boldsymbol{\Theta}),\qquad \boldsymbol{W} \leftarrow \boldsymbol{W} - \eta_1\boldsymbol{\Phi},\qquad \lambda \leftarrow \lambda - \eta_2\,\tr(\boldsymbol{\Theta}^{\top}\boldsymbol{\Phi})\tag{7}$$
In this way, each training step only requires one additional, nearly free computation of $\lambda - \eta_2 \tr(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})$, yielding an approximate implementation of the original objective (3). In form, it is equivalent to an adaptive Weight Decay for Muon.
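A toy version of this synchronous scheme, purely as a sketch: the quadratic loss $\tfrac{1}{2}\Vert\boldsymbol{W}-\boldsymbol{A}\Vert_F^2$, the fixed direction $\boldsymbol{\Theta}$, and the step sizes are all stand-in assumptions (in the actual spectral-sphere problem, $\boldsymbol{\Theta}$ depends on $\boldsymbol{W}$ and changes every step):

```python
import numpy as np

def msign(G):
    # Hypothetical helper: msign(G) = U V^T from the thin SVD.
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(2)
n, m = 6, 4
A = rng.standard_normal((n, m))       # target of the toy loss 0.5*||W - A||_F^2
W = rng.standard_normal((n, m))
Theta = rng.standard_normal((n, m))   # stand-in constraint direction, fixed for illustration
lam, eta1, eta2 = 0.0, 0.02, 0.02

loss0 = 0.5 * np.linalg.norm(W - A) ** 2
for step in range(500):
    G = W - A                          # gradient of the toy loss
    Phi = msign(G + lam * Theta)       # Muon-style direction with the dual term
    W = W - eta1 * Phi                 # parameter update
    lam = lam - eta2 * float(np.trace(Theta.T @ Phi))  # nearly free dual update
loss1 = 0.5 * np.linalg.norm(W - A) ** 2
print(loss0, loss1)  # the toy loss decreases over training
```

The point of the sketch is the cost profile: per training step, the dual update adds only one trace on top of quantities the optimizer already computes.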
Stiefel Manifold#
After discussing the relatively simple "Muon + Spectral Sphere", let us now consider "Muon + Stiefel", i.e., objective (2). Here the undetermined matrix $\boldsymbol{X}$ has the constraint $\boldsymbol{X}=\boldsymbol{X}^{\top}$. We remove this constraint by setting $\boldsymbol{X}=\boldsymbol{\Lambda}+\boldsymbol{\Lambda}^{\top}$, where $\boldsymbol{\Lambda}\in\mathbb{R}^{m\times m}$ is an arbitrary matrix. Then we can find:

$$\nabla_{\boldsymbol{\Lambda}}\Vert\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}\Vert_* = \boldsymbol{W}^{\top}\boldsymbol{\Phi} + \boldsymbol{\Phi}^{\top}\boldsymbol{W}\tag{8}$$
where $\boldsymbol{\Phi} = \msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})$. Therefore, solving the equation system $\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W}=\boldsymbol{0}$ can similarly be transformed into finding the minimum point of the function $\Vert\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}\Vert_*$, and then solved via gradient descent:

$$\boldsymbol{\Lambda} \leftarrow \boldsymbol{\Lambda} - \eta\,(\boldsymbol{W}^{\top}\boldsymbol{\Phi} + \boldsymbol{\Phi}^{\top}\boldsymbol{W})\tag{9}$$
Since $\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W}$ must be symmetric, directly updating $\boldsymbol{X} \leftarrow\boldsymbol{X} - \eta(\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W})$ is also feasible. Combining it with $\boldsymbol{W}$ for synchronous iteration, we obtain:

$$\boldsymbol{\Phi} = \msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}),\qquad \boldsymbol{W} \leftarrow \boldsymbol{W} - \eta_1\boldsymbol{\Phi},\qquad \boldsymbol{X} \leftarrow \boldsymbol{X} - \eta_2\,(\boldsymbol{W}^{\top}\boldsymbol{\Phi} + \boldsymbol{\Phi}^{\top}\boldsymbol{W})\tag{10}$$
This achieves an approximation of objective (2), with the additional $\boldsymbol{X} - \eta_2(\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W})$ also being nearly free at each step.
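The matrix-multiplier case can be sketched the same way, here solving for $\boldsymbol{X}$ at a fixed $(\boldsymbol{G}, \boldsymbol{W})$; the shapes, step size, and step-halving safeguard are again assumptions of this sketch:

```python
import numpy as np

def msign(G):
    # Hypothetical helper: msign(G) = U V^T from the thin SVD.
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def solve_X(G, W, eta=0.1, steps=2000, tol=1e-8):
    """Minimize ||G + W X||_* over symmetric X by dual gradient descent."""
    nuc = lambda X: np.linalg.norm(G + W @ X, ord="nuc")
    m = W.shape[1]
    X = np.zeros((m, m))
    for _ in range(steps):
        Phi = msign(G + W @ X)
        R = W.T @ Phi + Phi.T @ W      # dual gradient; symmetric, so X stays symmetric
        if np.linalg.norm(R) < tol:
            break
        # step-halving safeguard: shrink eta until the dual objective decreases
        while eta > 1e-12 and nuc(X - eta * R) > nuc(X):
            eta *= 0.5
        X -= eta * R
    return X, Phi

rng = np.random.default_rng(3)
n, m = 6, 4
W, _ = np.linalg.qr(rng.standard_normal((n, m)))   # a point on the Stiefel manifold
G = rng.standard_normal((n, m))
X, Phi = solve_X(G, W)
residual = np.linalg.norm(W.T @ Phi + Phi.T @ W)
print(residual)  # should be close to zero
```

Because the residual $\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W}$ is itself the dual gradient, the same loop doubles as a convergence check for the original equation system.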
Lagrangian Formulation#
In both examples, the equations that need to be solved happen to equal the gradient of some nuclear norm objective. Is this merely coincidence? Certainly not. As mentioned in the "Basic Concepts" section, this is a natural result of the method of Lagrange multipliers. This section will elaborate on this discussion.
For ease of understanding, let us again take the relatively simple objective (3) as an example. It can be equivalently written as:

$$\max_{\Vert\boldsymbol{\Phi}\Vert_2\leq 1}\,\min_{\lambda\in\mathbb{R}}\ \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) + \lambda\,\tr(\boldsymbol{\Theta}^{\top}\boldsymbol{\Phi})\tag{11}$$
To understand this transformation, note that any $\boldsymbol{\Phi}$ attaining the outer $\max$ must satisfy $\tr(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0$; otherwise, the inner $\min$ over $\lambda$ can be driven to negative infinity, making the final $\max$ result negative infinity as well. As for relaxing $\Vert\boldsymbol{\Phi}\Vert_2 = 1$ to $\Vert\boldsymbol{\Phi}\Vert_2\leq 1$, this does not change the maximum value (the maximum is always attained on the boundary), but it makes the feasible region for $\boldsymbol{\Phi}$ a convex set.
With this equivalent form, we can utilize the Minimax theorem to exchange the positions of $\min$ and $\max$:

$$\min_{\lambda\in\mathbb{R}}\,\max_{\Vert\boldsymbol{\Phi}\Vert_2\leq 1}\ \tr\big((\boldsymbol{G}+\lambda\boldsymbol{\Theta})^{\top}\boldsymbol{\Phi}\big) = \min_{\lambda\in\mathbb{R}}\ \Vert\boldsymbol{G}+\lambda\boldsymbol{\Theta}\Vert_*\tag{12}$$
Taking the $\max$ over $\Vert\boldsymbol{\Phi}\Vert_2\leq 1$ first is straightforward, since it is a basic result from the derivation of Muon: the inner maximum equals the nuclear norm. This gives us the dual objective $\Vert\boldsymbol{G} + \lambda \boldsymbol{\Theta}\Vert_*$ for the original problem (3).
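The "basic result from the derivation of Muon" invoked here is $\max_{\Vert\boldsymbol{\Phi}\Vert_2\leq 1}\tr(\boldsymbol{M}^{\top}\boldsymbol{\Phi}) = \Vert\boldsymbol{M}\Vert_*$, attained at $\boldsymbol{\Phi}=\msign(\boldsymbol{M})$. It can be checked numerically (hypothetical helper, not the author's code):

```python
import numpy as np

def msign(M):
    # Hypothetical helper: msign(M) = U V^T from the thin SVD.
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(4)
M = rng.standard_normal((5, 3))
Phi = msign(M)

inner = float(np.trace(M.T @ Phi))           # value attained by Phi = msign(M)
nuc = float(np.linalg.norm(M, ord="nuc"))    # nuclear norm of M
spec = float(np.linalg.norm(Phi, ord=2))     # spectral norm of the maximizer

print(inner, nuc, spec)  # inner equals nuc up to rounding, and spec equals 1
```

Indeed $\tr(\boldsymbol{M}^{\top}\boldsymbol{U}\boldsymbol{V}^{\top}) = \tr(\boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}) = \tr(\boldsymbol{\Sigma})$, so the bound is tight and the maximizer is feasible.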
Some readers might wonder why this Lagrange multiplier method looks different from the one they learned. The reason is that here the method is generalized to general convex sets, and the exchangeability of $\min$ and $\max$ is discussed rigorously to ensure the final result is what we want. The version generally taught is a heuristic solution procedure for constrained optimization problems in $\mathbb{R}^n$, with little discussion of the details of its theoretical guarantees.
Article Summary#
This article introduced the idea of finding the steepest descent direction on manifolds via dual gradient descent. This is also the method used by the Thinking Machines Lab blog "Modular Manifolds" to solve Muon on the Stiefel manifold.
Key contributions of this article:
- Dual Formulation: Presented a systematic approach to transform constraint satisfaction problems into minimization problems through dual objectives.
- Unified Framework: Demonstrated how both spectral sphere and Stiefel manifold constraints can be handled within the same dual gradient descent framework.
- Theoretical Foundation: Connected the dual approach to classical Lagrangian duality and the Minimax theorem, providing theoretical justification.
- Practical Algorithms: Derived concrete gradient descent update rules for both scalar and matrix dual variables.
- Implementation Strategy: Proposed synchronous update schemes that integrate dual variable optimization with model parameter updates, minimizing computational overhead.
The dual gradient descent approach provides an alternative perspective on constrained optimization on manifolds, complementing the fixed-point iteration methods presented earlier in the series. While potentially requiring more iterations, it offers simpler update rules and a clear connection to established optimization theory.
Original Article: Su Jianlin. Steepest Descent on Manifolds: 5. Dual Gradient Descent. Scientific Spaces.