Muon: An optimizer for the hidden layers of neural networks

This repo contains an implementation of the Muon optimizer originally described in this thread and this writeup.

Installation

pip install git+https://github.com/KellerJordan/Muon

or

pip install muon_optimizer

Usage

Muon is an optimizer for the hidden weights of a neural network. Other parameters, such as embeddings, classifier heads, and hidden gains/biases, should be optimized using standard AdamW. Muon should be used as follows:

# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)

# To replace the above, do the following:

from muon import MuonWithAuxAdam
hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
nonhidden_params = [*model.head.parameters(), *model.embed.parameters()]
param_groups = [
    dict(params=hidden_weights, use_muon=True,
         lr=0.02, weight_decay=0.01),
    dict(params=hidden_gains_biases+nonhidden_params, use_muon=False,
         lr=3e-4, betas=(0.9, 0.95), weight_decay=0.01),
]
optimizer = MuonWithAuxAdam(param_groups)

You'll have to replace model.body, model.head, and model.embed with whatever is appropriate for your model. E.g., for a ConvNet, you should use Muon to optimize all the convolutional filters except the first one, and AdamW to optimize everything else.
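As a concrete illustration, here is a minimal sketch of that ConvNet split. It assumes the first convolution is stored at model.conv1 (a hypothetical attribute; adjust the names to your model):

# Sketch only: model.conv1 is an assumed attribute for the first convolutional layer.
from muon import MuonWithAuxAdam

first_conv_weight = model.conv1.weight
# 4-D conv filters (except the first layer's) go to Muon...
muon_params = [p for p in model.parameters()
               if p.ndim == 4 and p is not first_conv_weight]
# ...and everything else (first conv, linear head, biases, gains) goes to AdamW.
adamw_params = [p for p in model.parameters()
                if p.ndim != 4 or p is first_conv_weight]
param_groups = [
    dict(params=muon_params, use_muon=True, lr=0.02, weight_decay=0.01),
    dict(params=adamw_params, use_muon=False, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.01),
]
optimizer = MuonWithAuxAdam(param_groups)

The identity check (p is not first_conv_weight) keeps the first layer's filters in the AdamW group while every other convolutional filter is handled by Muon.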

Example usage

Example use in the NanoGPT speedrun

Example use in the CIFAR-10 speedrun

Hyperparameter tuning

Typically, the default values of momentum (0.95), nesterov (True), and ns_steps (5) work well, so only the learning rate and weight decay need to be tuned. The learning rate follows constant muP scaling: as you scale up the model size, you shouldn't need to retune it.
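To make that advice concrete, here is a minimal sketch of a sweep that varies only the Muon learning rate and weight decay, reusing the param-group setup from the Usage section; build_model() and train_and_evaluate() are hypothetical stand-ins for your own training code:

# Sketch only: sweep lr and weight decay for the Muon group, leave other defaults untouched.
from muon import MuonWithAuxAdam

for muon_lr in (0.01, 0.02, 0.05):
    for wd in (0.0, 0.01, 0.1):
        model = build_model()  # hypothetical helper that constructs a fresh model
        hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
        other_params = [p for p in model.body.parameters() if p.ndim < 2]
        other_params += [*model.head.parameters(), *model.embed.parameters()]
        optimizer = MuonWithAuxAdam([
            dict(params=hidden_weights, use_muon=True, lr=muon_lr, weight_decay=wd),
            dict(params=other_params, use_muon=False, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.01),
        ])
        train_and_evaluate(model, optimizer)  # hypothetical training/eval loop

Because of the muP-style scaling, a learning rate found this way on a small model should carry over to larger variants of the same architecture without retuning.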

Benchmarks

For a comparison between AdamW, Shampoo, SOAP, and Muon for training a 124M-parameter transformer, see here.

Accomplishments

More learning resources and results about Muon

Citation

@misc{jordan2024muon,
  author       = {Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and
                  Franz Cesista and Laker Newhouse and Jeremy Bernstein},
  title        = {Muon: An optimizer for hidden layers in neural networks},
  year         = {2024},
  url          = {https://kellerjordan.github.io/posts/muon/}
}

