It is well known that fully training a large language model (LLM) is prohibitively expensive, which rules out repeatedly testing hyperparameters directly at the target scale. A natural idea is to carefully search for hyperparameters on smaller models of the same architecture and then transfer the optimal combination directly to the larger model. Although the idea is straightforward, implementing it is non-trivial: it requires understanding how the common hyperparameters scale with model size, and MuP is one practical implementation of this idea.
MuP, sometimes written as $\mu P$, stands for Maximal Update Parametrization and originates from the paper "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer". With the proliferation of LLM training, it has gradually become one of the de facto standards for training models in a principled, scientific way.
Method Overview#
MuP (Maximal Update Parametrization) investigates the transfer laws of hyperparameters across model scales. The key insight is that optimal hyperparameter combinations found on small models can be directly transferred to larger models with proper scaling.
Before delving into the topic, I must first remark that the original MuP paper is written in an extremely obscure way, and its conclusions are not stated clearly, which makes it unnecessarily hard to follow. I will therefore try to reproduce MuP's conclusions in what I consider a more concise manner.
First, the conclusion: MuP primarily investigates the transfer laws of hyperparameters across model scales. There are several key terms here:
1. Hyperparameters: Currently mainly refers to learning rate;
2. Model scale: Currently mainly refers to model width;
3. The core here is "transfer".
Please note that MuP does not study what the optimal hyperparameters are, only how optimal hyperparameters change with model scale. Therefore, we need to search for the optimal hyperparameter combination on some small model, then transfer it to larger models. This is the use case and method of MuP.
The principle behind deriving MuP is to ensure that the model's forward propagation, backward propagation, loss increment, and feature changes do not significantly vary with model scale:
1. The concrete approach is to analyze orders of magnitude at initialization, and then assume that these conclusions are representative of the subsequent optimization process;
2. Essentially, it assumes that with proper initialization, subsequent optimization will automatically follow the correct trajectory (a good start is half the battle?);
3. One can also justify this assumption with stories about the Law of Large Numbers or Central Limit Theorem, but personally I don't find this necessary.
Forward Propagation#
We begin with forward propagation because it is relatively simple and well established. First, consider the linear layer $\boldsymbol{Y}=\boldsymbol{X}\boldsymbol{W}$, where $\boldsymbol{X}\in\mathbb{R}^{b\times d_{in}},\boldsymbol{W}\in\mathbb{R}^{d_{in}\times d_{out}}$. We use RMS (Root Mean Square) as a measure of matrix scale, for example:
$$\text{RMS}(\boldsymbol{W}) = \sqrt{\frac{1}{d_{in} d_{out}}\sum_{i=1}^{d_{in}}\sum_{j=1}^{d_{out}} W_{i,j}^2}\tag{1}$$
We know that to make the RMS of $\boldsymbol{X}$ approximately equal to the RMS of $\boldsymbol{Y}$ during initialization (referred to as "stability"), $\boldsymbol{W}$ should use:
LeCun Initialization: Random initialization with "mean 0, variance $1/d_{in}$".
This is already one of the fundamental conclusions in deep learning, so I won't expand on the derivation. Readers unfamiliar with this can refer to previous blog posts such as "Geometric Perspective on Model Parameter Initialization Strategies" and "On Transformer Initialization, Parameterization, and Normalization".
Next, consider the nonlinear layer $\boldsymbol{Y}=\phi(\boldsymbol{X}\boldsymbol{W})$, where $\phi$ is an element-wise activation function. If we still want to maintain approximate equality between the RMS of $\boldsymbol{X}$ and the RMS of $\boldsymbol{Y}$, the result differs slightly. For example, with $\text{relu}$ activation we obtain:
Kaiming Initialization: Random initialization with "mean 0, variance $2/d_{in}$".
It's easy to see that Kaiming Initialization differs from LeCun Initialization only by a constant factor 2 (independent of model scale). One can prove similar results for other activation functions. So we can conclude:
fan_in initialization: To ensure forward propagation stability, one should use random initialization with "mean 0, variance proportional to $1/d_{in}$".
This conclusion can also be understood as "the effect of activation functions is independent of model scale." So if we only want to analyze the effect of model scale, we can ignore the presence of (element-wise) activation functions and directly obtain the scaling law $\propto 1/d_{in}$ from LeCun initialization.
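To make the fan_in rule concrete, here is a minimal NumPy sketch (my own illustration, not code from the original post) checking that variance-$1/d_{in}$ initialization keeps the RMS of the output close to the RMS of the input, regardless of width:

```python
# Minimal check that fan_in initialization (variance 1/d_in) keeps RMS(Y) ~ RMS(X)
# across widths. Illustration only, under the assumptions stated in the text.
import numpy as np

def rms(a):
    return np.sqrt(np.mean(a**2))

rng = np.random.default_rng(0)
b = 32
for d_in, d_out in [(128, 128), (1024, 1024), (8192, 8192)]:
    X = rng.standard_normal((b, d_in))                        # RMS(X) ~ 1
    W = rng.normal(0.0, np.sqrt(1.0 / d_in), (d_in, d_out))   # LeCun / fan_in init
    Y = X @ W
    print(d_in, round(rms(X), 3), round(rms(Y), 3))           # both stay ~1 as width grows
```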
Backward Propagation#
Now let's continue with backward propagation (the gradients). Note that we adopt the convention that a variable and its gradient have the same shape. For $\boldsymbol{Y}=\phi(\boldsymbol{X}\boldsymbol{W})$ we can compute:
$$\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}} = \boldsymbol{X}^{\top}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\otimes \phi'\right),\qquad \frac{\partial\mathcal{L}}{\partial \boldsymbol{X}} = \left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\otimes \phi'\right)\boldsymbol{W}^{\top}\tag{2}$$
The first formula is the gradient of the current layer's parameters, the second is the gradient passed back to the preceding layer, $\otimes$ denotes the Hadamard product, and $\phi'$ is the (element-wise) derivative of $\phi$ evaluated at $\boldsymbol{X}\boldsymbol{W}$.
Note a fact: for commonly used activation functions, the derivative is bounded by a (scale-independent) constant. So, at least at the level of orders of magnitude, we can write:
$$\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}} \sim \boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}},\qquad \frac{\partial\mathcal{L}}{\partial \boldsymbol{X}} \sim \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\boldsymbol{W}^{\top}\tag{3}$$
Let's look at the second formula first. Compared with $\boldsymbol{Y}=\boldsymbol{X}\boldsymbol{W}$, the matrix multiplied on the right becomes $\boldsymbol{W}^{\top}$. Then according to the previous section's conclusion, to maintain RMS stability in backward propagation, $\boldsymbol{W}$'s initialization should be:
fan_out initialization: Random initialization with "mean 0, variance $1/d_{out}$".
When $d_{in}\neq d_{out}$, the requirements for forward propagation and backward propagation conflict. Someone proposed a compromise strategy:
Xavier initialization: Random initialization with "mean 0, variance $2/(d_{in} + d_{out})$".
This is also called "fan_avg initialization", because it simply replaces the fan with the arithmetic mean of $d_{in}$ and $d_{out}$. Other averaging schemes can also be considered; refer to "Thoughts on Dimension Averaging Strategies for Non-Square Matrices in Initialization Methods". Xavier initialization appears to balance forward and backward propagation, but one could equally say it fully satisfies neither. A better approach is to design the model so that most parameters are square matrices, as done later with the model family in equation (6).
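The forward/backward tension for non-square matrices is easy to see numerically. The NumPy sketch below (again my own illustration) compares fan_in, fan_out, and fan_avg initialization on the forward signal $\boldsymbol{X}\boldsymbol{W}$ and on the backward signal $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\boldsymbol{W}^{\top}$:

```python
# fan_in stabilizes the forward pass X @ W, fan_out stabilizes the backward pass
# G @ W.T (G stands in for dL/dY), and fan_avg is a compromise that fully satisfies
# neither when d_in != d_out. Illustration only.
import numpy as np

def rms(a):
    return np.sqrt(np.mean(a**2))

rng = np.random.default_rng(0)
b, d_in, d_out = 32, 256, 4096
X = rng.standard_normal((b, d_in))
G = rng.standard_normal((b, d_out))          # stand-in for dL/dY with RMS ~ 1

for name, var in [("fan_in", 1 / d_in), ("fan_out", 1 / d_out),
                  ("fan_avg", 2 / (d_in + d_out))]:
    W = rng.normal(0.0, np.sqrt(var), (d_in, d_out))
    print(name, "RMS(XW) =", round(rms(X @ W), 2),
          "RMS(G W^T) =", round(rms(G @ W.T), 2))
```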
Loss Increment#
With the groundwork on forward and backward propagation in place, we can attempt to analyze the increment of the loss function. Consider the change in loss when $\boldsymbol{W}\to \boldsymbol{W} + \Delta\boldsymbol{W}$:
$$\Delta\mathcal{L} = \mathcal{L}(\boldsymbol{W}+\Delta\boldsymbol{W}) - \mathcal{L}(\boldsymbol{W}) \approx \left\langle\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}},\, \Delta\boldsymbol{W}\right\rangle_F\tag{4}$$
Here $\langle\cdot,\cdot\rangle_F$ is the Frobenius inner product, i.e., flatten the matrices into vectors and take the ordinary inner product. Consider gradient descent $\Delta\boldsymbol{W} = -\eta \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}$, where $\eta$ is of course the learning rate. Combining with equation (3), we have:
$$\Delta\mathcal{L} \approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_F^2 \approx -\eta \left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2\tag{5}$$
This formula already tells us why the same learning rate $\eta$ cannot be used across model scales:
- $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is a $d_{in}\times d_{out}$ matrix;
- $\left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2$ is the sum of squares of $d_{in}\times d_{out}$ numbers;
- $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is exactly the product of forward and backward propagation;
- If both forward and backward propagation are stable, then each element of $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is $\mathcal{\Theta}(1)$ ($\mathcal{\Theta}$ is "Big Theta Notation");
- Therefore $\left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2$ is $\mathcal{\Theta}(d_{in} d_{out})$.
Point 4 may need additional commentary. $\boldsymbol{X}^{\top}$ is a $d_{in}\times b$ matrix, $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is a $b\times d_{out}$ matrix, their multiplication yields $d_{in} d_{out}$ inner products of $b$-dimensional vector pairs. The inner product sums $b$ terms, and the loss $\mathcal{L}$ typically averages over samples (i.e., includes division by $b$). So if $\boldsymbol{X}^{\top}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ are both scale-independent, their product is essentially also scale-independent [i.e., RMS is $\mathcal{\Theta}(1)$].
The final conclusion shows that if we directly use a small model's learning rate for a large model, then for sufficiently large models, each step's loss increment will explode as parameter scale (i.e., $d_{in} d_{out}$) increases. This means we cannot replicate the small model's convergence process, and may even fail to converge due to steps that are too large.
At this point, one might think of scaling $\eta\propto 1/(d_{in} d_{out})$ to adjust $\Delta\mathcal{L}$. This idea actually aligns with MuP's approach. However, in practical scenarios, due to the incompatibility between forward and backward propagation mentioned earlier, point 4 "if both forward and backward propagation are stable, then each element of $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is $\mathcal{\Theta}(1)$" cannot always hold. So the actual situation is more complex.
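The blow-up is easy to reproduce numerically. In the sketch below (mine; the random matrices merely stand in for stable forward and backward signals), the first-order decrement $\eta\left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2$ at a fixed $\eta$ grows like $d_{in} d_{out}$:

```python
# With stable forward/backward signals and a fixed learning rate, the first-order
# loss decrement eta * ||X^T G||_F^2 grows like d_in * d_out. Illustration only.
import numpy as np

rng = np.random.default_rng(0)
b, eta = 32, 1e-2
for d in [128, 512, 2048]:                 # play the role of d_in = d_out = d
    X = rng.standard_normal((b, d))        # stable forward signal
    G = rng.standard_normal((b, d)) / b    # stand-in for dL/dY of a b-averaged loss
    delta_L = eta * np.linalg.norm(X.T @ G, "fro") ** 2
    print(d, round(delta_L, 2), delta_L / d**2)   # delta_L explodes; delta_L / d^2 stays flat
```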
Model Assumption#
Now let's consider a scenario closer to practice. Our task is to train a model $\mathbb{R}^{d_{in}}\mapsto \mathbb{R}^{d_{out}}$, where $d_{in},d_{out}$ are determined by data and cannot be changed. As mentioned at the beginning, MuP aims to study how hyperparameters scale with model scale, so all fixed quantities are essentially constants or $\mathcal{\Theta}(1)$. For example, initialization variance $1/d_{in}$ is equivalent to saying initialization variance is $\mathcal{\Theta}(1)$.
We can change the model's architecture, parameter count, etc., but MuP mainly considers how things scale with width. So let's pin down the model architecture. The primary model family considered is:
$$\boldsymbol{Y}_{in} = \boldsymbol{X}\boldsymbol{W}_{in},\qquad \boldsymbol{Y}_{out} = \text{NN}(\boldsymbol{Y}_{in};\boldsymbol{\Omega}),\qquad \boldsymbol{Z} = \boldsymbol{Y}_{out}\boldsymbol{W}_{out}\tag{6}$$
Where:
1. $\boldsymbol{X}\in\mathbb{R}^{b\times d_{in}}$ (where $b$ is the batch size);
2. $\boldsymbol{W}_{in} \in \mathbb{R}^{d_{in}\times d}, \boldsymbol{W}_{out} \in \mathbb{R}^{d\times d_{out}}$;
3. $\text{NN}$ is any neural network $\mathbb{R}^d\mapsto \mathbb{R}^d$;
4. Here $d$ is what we commonly call hidden size;
5. We can arbitrarily increase $d$ to enhance model parameter count and capacity;
6. MuP wants to study how hyperparameters change with $d$.
More specifically, here we take $\text{NN}$ to be a $K$-layer MLP:
$$\boldsymbol{Y}_k = \phi(\boldsymbol{Y}_{k-1}\boldsymbol{W}_k),\quad k=1,2,\cdots,K,\qquad \boldsymbol{Y}_0 = \boldsymbol{Y}_{in},\quad \boldsymbol{Y}_{out} = \boldsymbol{Y}_K\tag{7}$$
Here $\boldsymbol{\Omega}=\{\boldsymbol{W}_1,\boldsymbol{W}_2,\cdots,\boldsymbol{W}_K\}$ and $\boldsymbol{W}_k\in\mathbb{R}^{d\times d}$, i.e., all of them are $d\times d$ square matrices, all using fan_in initialization (which, for square matrices, is the same as fan_out initialization).
Making all parameter matrices $d\times d$ square matrices is purely to simplify the analysis; it is not a hard requirement. The real purpose is to assume that the parameters inside $\text{NN}$ have no scale-independent dimensions: a shape like $d\times 64$ is not allowed, because $64$ is a constant, whereas a shape like $d\times 4d$ is fine, because then fan_in, fan_out, and fan_avg initialization all give a variance proportional to $1/d$.
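For concreteness, here is a minimal NumPy sketch (my own code and naming, not from the post) of the model family in equations (6) and (7), using fan_in initialization everywhere for now; the variance of $\boldsymbol{W}_{out}$ is revisited in the sections below:

```python
# Sketch of the model family (6)/(7): W_in is d_in x d, NN is a K-layer MLP of
# d x d square matrices, W_out is d x d_out. Illustration only.
import numpy as np

def fan_in_init(fan_in, fan_out, rng):
    # "mean 0, variance 1/fan_in"
    return rng.normal(0.0, np.sqrt(1.0 / fan_in), (fan_in, fan_out))

def build_model(d_in, d, d_out, K, rng):
    W_in = fan_in_init(d_in, d, rng)
    Omega = [fan_in_init(d, d, rng) for _ in range(K)]
    W_out = fan_in_init(d, d_out, rng)    # MuP will later change this variance to ~1/d^2
    return W_in, Omega, W_out

def forward(X, W_in, Omega, W_out, phi=np.tanh):
    Y = X @ W_in                  # Y_in
    for W_k in Omega:
        Y = phi(Y @ W_k)          # the K-layer MLP NN(.; Omega)
    return Y @ W_out              # Z
```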
Assembly#
With a concrete model in hand, we can now assemble all the previous conclusions. The parameters to update fall into three parts: $\boldsymbol{W}_{in},\boldsymbol{\Omega},\boldsymbol{W}_{out}$. Computing their gradients separately:
$$\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}} = \boldsymbol{Y}_{out}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}},\qquad \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k} = \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}\cdot\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}\right),\qquad \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}} = \boldsymbol{X}^{\top}\left[\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}\cdot\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}\right)\right]\tag{8}$$
The $\cdot$ operation needs slight explanation: $\boldsymbol{Y}_{in},\boldsymbol{Y}_{out}$ are both matrices, so $\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}$ is in principle a fourth-order tensor. The chain rule $\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}\cdot\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{out}}$ is actually multiplication of higher-order tensors, but we won't expand on that here. We simply use $\cdot$ as a placeholder; readers only need to know it's a generalization of matrix multiplication.
Now observe patterns:
1. All three expressions contain $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$;
2. The latter two contain $\boldsymbol{W}_{out}^{\top}$;
3. $\boldsymbol{W}_k$ are all square matrices, $\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}$ and $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ are both stable [RMS is $\mathcal{\Theta}(1)$];
4. If $\boldsymbol{W}_{in}$ also uses fan_in initialization, then $\boldsymbol{Y}_{out}$ is also stable;
5. For $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}$ to be stable, initialization variance should be $1/d_{out}$, but $d_{out}$ is scale-independent, essentially a constant.
Thus:
1. $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}$'s RMS is $\mathcal{\Theta}(1)$, $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right\Vert_F^2$ is the sum of squares of $d\times d_{out}$ numbers, so its magnitude is $\mathcal{\Theta}(d\times d_{out})$. Remember $d_{out}$ is constant, so actually $\mathcal{\Theta}(d)$. Therefore, to obtain $\mathcal{\Theta}(1)$ $\Delta\mathcal{L}$, its learning rate must satisfy $\eta_{out}\propto 1/d$;
2. $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right\Vert_F^2$ sums $d^2$ numbers. $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$'s RMS are both $\mathcal{\Theta}(1)$. If we directly set $\boldsymbol{W}_{out}$'s initialization variance as $\propto 1/d^2$, then $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$'s RMS is $\mathcal{\Theta}(1/d)$, and after squaring and summing, it becomes exactly $\mathcal{\Theta}(1)$. Therefore, learning rate doesn't need to change;
3. Then $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}$'s RMS is also $\mathcal{\Theta}(1/d)$, but $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right\Vert_F^2$ only sums $d_{in}\times d$ squared numbers, so the result is $\mathcal{\Theta}(1/d)$. To obtain $\mathcal{\Theta}(1)$ $\Delta\mathcal{L}$, the learning rate actually needs to be amplified by $d$ to counteract this effect, i.e., $\eta_{in}\propto d$.
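Written as code, these SGD rules amount to width-dependent multipliers on the hyperparameters tuned at some base width $d_0$. The helper below is a sketch with my own naming, under the assumptions of this post:

```python
# MuP scaling rules for SGD, expressed as factors relative to a base width d0.
# W_in keeps its 1/d_in variance, hidden W_k use variance ~1/d, W_out uses ~1/d^2;
# learning rates scale as eta_in ~ d, eta_k ~ 1, eta_out ~ 1/d. Sketch only.
def mup_sgd_scales(d, d0):
    r = d / d0                       # width ratio
    return {
        "W_in":  {"init_var_factor": 1.0,        "lr_factor": r},        # eta_in  ∝ d
        "W_k":   {"init_var_factor": 1.0 / r,    "lr_factor": 1.0},      # eta_k   ∝ 1
        "W_out": {"init_var_factor": 1.0 / r**2, "lr_factor": 1.0 / r},  # eta_out ∝ 1/d
    }

# Doubling the width keeps eta_k, doubles eta_in, halves eta_out,
# and divides W_out's initialization variance by 4:
print(mup_sgd_scales(d=2048, d0=1024))
```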
Feature Variation#
The above results are correct, but careful thought reveals an issue in the derivation: points 2 and 3 above are based on the setting "we directly set $\boldsymbol{W}_{out}$'s initialization variance as $\propto 1/d^2$", which currently lacks direct justification. Without further explanation, the derivation remains incomplete.
Indeed, considering only the requirement $\Delta \mathcal{L}=\mathcal{\Theta}(1)$, other choices cannot be excluded. For example, if $\boldsymbol{W}_{out}$'s initialization variance is set as $\propto 1/d$, then $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$'s RMS is $\mathcal{\Theta}(1/\sqrt{d})$, after squaring and summing becomes $\mathcal{\Theta}(d)$. Then as long as learning rate $\eta\propto 1/d$, $\Delta \mathcal{L}=\mathcal{\Theta}(1)$ can also be achieved. Therefore, to explain the necessity of "$\boldsymbol{W}_{out}$'s initialization variance set as $\propto 1/d^2$", new conditions need to be introduced.
The loss function $\mathcal{L}$ is a macroscopic indicator of the model, or an external metric. Looking only at its changes is insufficient to explain all results, so we need to examine the model's internal details. Specifically, we hope that the variation of each layer's output (often called features, sometimes activations) also has scale invariance.
For example, for the linear layer $\boldsymbol{Y}_k = \boldsymbol{Y}_{k-1} \boldsymbol{W}_k$, the parameter change $\boldsymbol{W}_k\to \boldsymbol{W}_k + \Delta \boldsymbol{W}_k$ causes the output change:
$$\Delta\boldsymbol{Y}_k = \boldsymbol{Y}_{k-1}\,\Delta\boldsymbol{W}_k\tag{9}$$
Note that $\boldsymbol{Y}_{k-1}\in\mathbb{R}^{b\times d}$ and $\Delta\boldsymbol{W}_k\in\mathbb{R}^{d\times d}$, so $\boldsymbol{Y}_{k-1} \Delta\boldsymbol{W}_k$ consists of $b\times d$ inner products of pairs of $d$-dimensional vectors. Crucially, $\Delta\boldsymbol{W}_k$ is a carefully constructed update; unlike at initialization, it is unlikely to be independent of $\boldsymbol{Y}_{k-1}$. So each "inner product of a pair of $d$-dimensional vectors" is more likely $\mathcal{\Theta}(d)$ times the element scale (a $d$-dimensional inner product sums $d$ terms, which now add coherently). Therefore, if $\boldsymbol{Y}_{k-1}$'s RMS is $\mathcal{\Theta}(1)$, then $\Delta\boldsymbol{Y}_k$'s RMS is $\mathcal{\Theta}(d\times \text{RMS}(\Delta \boldsymbol{W}_k))$.
Thus, to make $\Delta\boldsymbol{Y}_k$'s RMS $\mathcal{\Theta}(1)$, we obtain an additional requirement on $\Delta \boldsymbol{W}_k$:
$$\text{RMS}(\Delta \boldsymbol{W}_k) = \mathcal{\Theta}(1/d)\tag{10}$$
Combining $\Delta \boldsymbol{W}_k = -\eta\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$ and $\Delta\mathcal{L}=\mathcal{\Theta}(1)$, we can obtain the result "$\boldsymbol{W}_{out}$'s initialization variance set as $\propto 1/d^2$".
(Note: This section relies on insights from @Chenyu Zheng, many thanks!)
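The coherent-summation effect is easy to check numerically. In the sketch below (my own illustration), an update of RMS $\mathcal{\Theta}(1/d)$ that is aligned with the incoming features keeps the feature change at $\mathcal{\Theta}(1)$, whereas an independent random update of the same RMS would shrink like $1/\sqrt{d}$:

```python
# When Delta W is correlated with the incoming features Y (as a gradient-based
# update is), the d-term inner products in Y @ Delta W add coherently, so the
# feature change scales like d * RMS(Delta W) instead of sqrt(d) * RMS(Delta W).
import numpy as np

def rms(a):
    return np.sqrt(np.mean(a**2))

rng = np.random.default_rng(0)
b = 32
for d in [256, 1024, 4096]:
    Y = rng.standard_normal((b, d))
    G = rng.standard_normal((b, d))
    dW_indep = rng.standard_normal((d, d)) / d     # RMS ~ 1/d, independent of Y
    dW_corr = (Y.T @ G) / (np.sqrt(b) * d)         # RMS ~ 1/d, but aligned with Y
    print(d, round(rms(Y @ dW_indep), 3), round(rms(Y @ dW_corr), 3))
# the independent update's effect shrinks like 1/sqrt(d); the correlated one stays Theta(1)
```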
Adam Version#
The above is MuP for SGD. For Adam, we typically use SignSGD approximation for order-of-magnitude analysis:
1. $\Delta \boldsymbol{W} = -\eta \mathop{\text{sign}}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right)$;
2. $\Delta \mathcal{L} \approx -\eta \left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right|_1$;
3. Here $|\cdot|_1$ means taking absolute value of each element then summing.
About the SignSGD approximation itself, readers can refer to "How Should Learning Rate Scale with Batch Size?" and "How Does Adam's epsilon Affect Learning Rate Scaling Laws?". In short, SignSGD is a commonly used approximation when analyzing scaling laws related to Adam.
Now we can mimic the SGD process for analysis:
1. $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}$'s RMS is $\mathcal{\Theta}(1)$, $\left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right|_1$ sums $d\times d_{out}$ numbers, magnitude is $\mathcal{\Theta}(d\times d_{out}) = \mathcal{\Theta}(d)$, so its learning rate must satisfy $\eta_{out}\propto 1/d$ to counteract scale effects;
2. $\left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right|_1$ sums $d^2$ numbers. $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$'s RMS are both $\mathcal{\Theta}(1)$. If we set $\boldsymbol{W}_{out}$'s initial variance as $\propto 1/d^2$, then $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$'s RMS is $\mathcal{\Theta}(1/d)$, summing $d^2$ numbers gives $\mathcal{\Theta}(d)$, so learning rate should change as $\eta_k\propto 1/d$ to counteract scale effects;
3. Then $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}$'s RMS is also $\mathcal{\Theta}(1/d)$, but $\left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right|_1$ only sums $d_{in}\times d$ numbers, so it's already $\mathcal{\Theta}(1)$. Thus learning rate doesn't need to change with scale.
(Note: Readers can verify that equation (10) is satisfied.)
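A quick numeric sanity check (mine) of the three $|\cdot|_1$ magnitudes used above, with matrices whose RMS follows the stated orders and with $d_{in}, d_{out}$ held fixed:

```python
# |dL/dW_out|_1 and |dL/dW_k|_1 grow like d (hence eta ~ 1/d under SignSGD/Adam),
# while |dL/dW_in|_1 stays Theta(1) (hence eta_in unchanged). Illustration only.
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 64, 64                                 # fixed, scale-independent
for d in [256, 1024, 4096]:
    g_out = rng.standard_normal((d, d_out))          # RMS Theta(1)
    g_k   = rng.standard_normal((d, d)) / d          # RMS Theta(1/d)
    g_in  = rng.standard_normal((d_in, d)) / d       # RMS Theta(1/d)
    print(d,
          round(np.abs(g_out).sum() / d, 1),         # grows like d -> flat after dividing by d
          round(np.abs(g_k).sum() / d, 1),           # grows like d -> flat after dividing by d
          round(np.abs(g_in).sum(), 1))              # already Theta(1)
```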
Muon Version#
Naturally, the Muon analysis comes next. Muon itself has already been introduced in detail in "Appreciating Muon Optimizer: The Essential Leap from Vectors to Matrices" and "Muon Sequel: Why Did We Choose to Try Muon?", so we won't repeat that here. Just as Adam is approximated by SignSGD, we approximate Muon by MSignSGD:
1. $\Delta \boldsymbol{W} = -\eta \mathop{\text{msign}}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right)$;
2. $\Delta \mathcal{L} \approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_*$ (proof see "Appreciating Muon Optimizer: The Essential Leap from Vectors to Matrices");
3. Here $\Vert\cdot\Vert_*$ refers to the Nuclear norm, which is the sum of all singular values of a matrix;
4. Nuclear norm is not easy to compute, but $F$ norm is easier; it equals the square root of the sum of squared singular values;
5. We use $F$ norm as an approximation for Nuclear norm, thus $\Delta \mathcal{L} \approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_*\approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_F$;
6. $F$ norm also equals the square root of the sum of squared elements.
Now we can begin the analysis:
1. $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}$'s RMS is $\mathcal{\Theta}(1)$, so $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right\Vert_*\approx\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right\Vert_F$ has magnitude $\mathcal{\Theta}(\sqrt{d\times d_{out}}) = \mathcal{\Theta}(\sqrt{d})$. To cancel the scale effect, its learning rate must satisfy $\eta_{out}\propto 1/\sqrt{d}$;
2. $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right\Vert_F$ is the square root of the sum of squares of $d^2$ numbers. $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$ both have RMS $\mathcal{\Theta}(1)$. If we set $\boldsymbol{W}_{out}$'s initial variance $\propto 1/d^2$, then $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$'s RMS is $\mathcal{\Theta}(1/d)$; summing the $d^2$ squares and taking the square root gives $\mathcal{\Theta}(1)$, so the learning rate does not need to change;
3. Then $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}$'s RMS is also $\mathcal{\Theta}(1/d)$, but $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right\Vert_F$ is the square root of the sum of squares of $d_{in}\times d$ numbers, so it's $\mathcal{\Theta}(1/\sqrt{d})$. The learning rate actually needs to be amplified by $\sqrt{d}$ to counteract this effect, i.e., $\eta_{in}\propto \sqrt{d}$.
(Note: Muon's conclusions here are correct, but they do not follow from condition (10): the derivation of equation (10) implicitly assumes that the update is element-wise, which Muon's update is not, so (10) is not actually applicable here. We have not expanded on that discussion; instead we directly adopted the conclusion "$\boldsymbol{W}_{out}$'s initialization variance $\propto 1/d^2$", which bypasses equation (10).)
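For readers who want to poke at the MSignSGD approximation itself, the NumPy snippet below (my own illustration) computes $\mathop{\text{msign}}$ via the SVD and verifies that $\left\langle\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}, \mathop{\text{msign}}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right)\right\rangle_F$ equals the nuclear norm:

```python
# msign(G) is the orthogonal factor U V^T from the SVD of G; the first-order
# decrement under Delta W = -eta * msign(G) is eta * <G, msign(G)>_F, which
# equals eta times the sum of the singular values (the nuclear norm).
import numpy as np

rng = np.random.default_rng(0)
G = rng.standard_normal((512, 512))                  # stand-in for dL/dW

U, S, Vt = np.linalg.svd(G, full_matrices=False)
msign_G = U @ Vt

print(np.allclose(msign_G.T @ msign_G, np.eye(512)))   # msign(G) is orthogonal
print(np.allclose(np.trace(G.T @ msign_G), S.sum()))   # <G, msign(G)>_F == nuclear norm
```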
Conclusion Summary#
Summarizing the above conclusions:
| Optimizer | $\boldsymbol{W}_{in}$ Variance | $\boldsymbol{W}_{in}$ Learning Rate | $\boldsymbol{W}_k$ Variance | $\boldsymbol{W}_k$ Learning Rate | $\boldsymbol{W}_{out}$ Variance | $\boldsymbol{W}_{out}$ Learning Rate |
|---|---|---|---|---|---|---|
| SGD | $1/d_{in}$ | $d$ | $1/d$ | $1$ | $1/d^2$ | $1/d$ |
| Adam | $1/d_{in}$ | $1$ | $1/d$ | $1/d$ | $1/d^2$ | $1/d$ |
| Muon | $1/d_{in}$ | $\sqrt{d}$ | $1/d$ | $1$ | $1/d^2$ | $1/\sqrt{d}$ |
Here $\boldsymbol{W}_k$ stands for all parameters other than $\boldsymbol{W}_{in}$ and $\boldsymbol{W}_{out}$. It must be emphasized that these relationships are "proportional to", not "equal to"; in practice they can be adjusted to specific needs. For example, when we actually use Muon, $\boldsymbol{W}_{in}$ and $\boldsymbol{W}_{out}$ are typically optimized with Adam rather than Muon, which leads to two changes:
1. $\eta_{out}\propto 1/d$;
2. $\eta_{in}$ unchanged.
If combined with Adjust LR proposed in our paper "Muon is Scalable for LLM Training", then learning rate should be multiplied by $\sqrt{\max(n, m)}$, where $n\times m$ is the parameter matrix shape. We have already assumed parameters in $\text{NN}$ part always scale proportionally, so $\sqrt{\max(n, m)}\propto \sqrt{d}$. Therefore, to counteract the scale effect brought by Adjust LR, we need:
3. $\eta_k\propto 1/\sqrt{d}$.
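As a consolidated reference, the helper below (my own code and naming, purely illustrative) encodes the summary table as exponents of the width ratio $d/d_0$; the practical Muon-plus-Adam adjustments listed above simply correspond to swapping in the modified exponents:

```python
# The summary table as (init-variance exponent of d, learning-rate exponent of d).
# "hidden" covers all W_k, i.e. everything except W_in and W_out. The entries are
# proportionalities, so only the ratio d/d0 to the tuned base width matters.
import math

TABLE = {
    "sgd":  {"w_in": (0, +1),   "hidden": (-1, 0),  "w_out": (-2, -1)},
    "adam": {"w_in": (0, 0),    "hidden": (-1, -1), "w_out": (-2, -1)},
    "muon": {"w_in": (0, +0.5), "hidden": (-1, 0),  "w_out": (-2, -0.5)},
}

def transfer(optimizer, group, d, d0, base_std, base_lr):
    var_exp, lr_exp = TABLE[optimizer][group]
    r = d / d0
    return base_std * math.sqrt(r ** var_exp), base_lr * r ** lr_exp

# e.g. a hidden-matrix Adam learning rate tuned at d0=1024, transferred to d=8192:
print(transfer("adam", "hidden", d=8192, d0=1024, base_std=0.02, base_lr=3e-4))
```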
Article Summary#
This article introduces MuP (Maximal Update Parametrization) in as clear and concise a manner as possible. This work aims to study the transfer laws of hyperparameters across model scales. Based on MuP, we can carefully search for hyperparameters (mainly learning rate and initialization here) on small models at relatively low cost, then transfer them to large models, reducing the cost of large model experimentation.
Objectively speaking, this introduction and analysis are still preliminary. For example, we did not consider bias terms, did not assess how general the conclusions are beyond MLP architectures, and did not carefully account for the roles of Normalization and residual connections. Ignoring bias terms is pure laziness and is left as an exercise for readers. As for MuP under different architectures, the analysis is generally more involved, but because neural network architectures are broadly similar, the conclusions are roughly the same and we can use them without re-deriving them. Personally, I think the more critical refinements concern Normalization and residuals, especially Normalization, which keeps forward propagation stable without relying on special initialization and thus offers more freedom and more possibilities.
Of course, these are left for subsequent analysis.
Original Article: Su Jianlin. A Preliminary Exploration of MuP: Cross-Model Scaling Laws for Hyperparameters. Scientific Spaces.