Optimisers.jl

Add `setup(::Function, model)`

Open mcabbott opened this issue 1 year ago • 2 comments

A quick sketch of one way to make it easy to use different rules for different arrays, by letting `setup` accept a function in place of a rule; see the docstring.
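A minimal sketch of the idea in plain Julia with no packages. It is not the code from this PR, whose docstring isn't reproduced here. The helper `rules_for`, the rule function `pick`, and the symbols `:muon` / `:adam` (standing in for real optimiser rules) are all hypothetical. The sketch walks a nested NamedTuple/Tuple model and asks the function for a rule at every numeric array. Arrays are looked up by identity, so a tied array is only passed to the function once.

```julia
# Hypothetical sketch, not the PR's implementation: map a per-array rule
# function over a nested NamedTuple/Tuple model, preserving its shape.
function rules_for(rulefun, model)
    seen = IdDict{Any,Any}()  # array object => rule chosen for it (identity-keyed)
    # Numeric arrays: ask rulefun once per distinct array object.
    walk(x::AbstractArray{<:Number}) = get!(() -> rulefun(x), seen, x)
    # Containers: recurse, keeping field names / positions.
    walk(x::Union{NamedTuple,Tuple}) = map(walk, x)
    # Anything else (functions, numbers, ...) gets no rule.
    walk(x) = nothing
    return walk(model)
end

# Hypothetical rule function; symbols stand in for Adam() / Muon().
pick(x::AbstractVector) = :adam   # vectors, e.g. biases
pick(x::AbstractArray) = :muon    # matrices and higher-dimensional arrays

W = rand(2, 2)
model = (; encode = (; weight = W),
           other  = (; weight = rand(2, 2), bias = rand(2)),
           decode = (; weight = W),   # tied to encode.weight (same object)
           act    = tanh)             # not an array, so no rule

rules = rules_for(pick, model)
```

In the actual proposal the function would presumably return an Optimisers.jl rule such as `Adam()`, and `setup` would build each array's state from it; the sketch only shows the per-array dispatch.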

PR Checklist

  • [ ] Tests are added
  • [ ] Documentation, if applicable

mcabbott · Dec 21 '24 19:12

I think this is elegant and useful. I'm working on some improvements to #203. Muon is meant for linear layers, but it makes less sense for e.g. Flux.Embedding, even though that layer is linear-like. And since the output (decoder) layer in LLMs is often tied to the input embedding (encoder) layer, it's preferable to disable Muon for that layer as well. I imagine the cleanest way to tell these layers apart from ordinary linear layers is an IdDict used inside the rule function passed to setup. For example, you'd build the function from the arrays recorded in some IdDict, and in the same function pick rules based on array shape.
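Why identity matters here, as a small plain-Julia illustration (the variable names are made up for the example). An IdDict keys on the array object itself, so marking the embedding weight automatically marks a decoder tied to it, while an equal-valued but separate array is left alone. A regular `Dict` hashes arrays by value and would conflate the two.

```julia
# Identity vs. value lookup for marking arrays to skip.
W = rand(2, 2)
model = (; encode = (; weight = W),         # input embedding
           decode = (; weight = W),         # tied output layer: same array object
           other  = (; weight = copy(W)))   # equal values, distinct object

skipped = IdDict{Any,Nothing}()
skipped[model.encode.weight] = nothing

haskey(skipped, model.decode.weight)  # true: same object as encode.weight
haskey(skipped, model.other.weight)   # false: == holds, but it is a different object

byvalue = Dict{Any,Nothing}(model.encode.weight => nothing)
haskey(byvalue, model.other.weight)   # true: Dict compares arrays by value
```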

AntonOresten · Oct 21 '25 11:10

One could do something like:

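function fun_rule(model, rule=Muon(), fallback=Adam())
    # Arrays that never get `rule`; IdSet compares by identity, so tied weights match too
    skipped = Base.IdSet{Any}([model.encode.weight, model.decode.weight])
    fun(x::AbstractVector) = fallback                       # vectors (e.g. biases) always use the fallback
    fun(x::AbstractArray) = x in skipped ? fallback : rule  # other arrays: fallback only if skipped
    return fun                                              # closure passed to setup as the per-array rule
end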

opt_state = Optimisers.setup(fun_rule(model), model)

such that:

julia> model = (;
           encode=(; weight=rand(2,2)),
           other=(; weight=rand(2,2), bias=rand(2)),
           decode=(; weight=rand(2,2)));

julia> fun_rule(model)(model.encode.weight)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)

julia> fun_rule(model)(model.other.weight)
Muon(0.02, 0.95, 0.01, 1.0e-7, true)

julia> fun_rule(model)(model.other.bias)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)

julia> fun_rule(model)(model.decode.weight)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)

I generally avoid closures, but this has a certain elegance to it. Base.IdSet is private, but the alternative is slightly cursed:

skipped = keys(IdDict([model.encode.weight, model.decode.weight] .=> nothing))
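Another option that avoids both the internal `Base.IdSet` and the IdDict trick is a plain vector checked with `===`. This is a sketch with a hypothetical name (`fun_rule_noidset`), using `:muon` / `:adam` as stand-ins for the real rules. Note that `x in skipped` on a `Vector` would compare with `==` (by value), which is not what we want, hence the explicit identity test.

```julia
# Hypothetical variant of fun_rule without Base.IdSet: identity check by linear scan.
function fun_rule_noidset(model, rule, fallback)
    skipped = Any[model.encode.weight, model.decode.weight]
    isskipped(x) = any(y -> y === x, skipped)   # === on arrays means "same object"
    fun(x::AbstractVector) = fallback           # vectors (e.g. biases) use the fallback
    fun(x::AbstractArray) = isskipped(x) ? fallback : rule
    return fun
end

model = (; encode = (; weight = rand(2, 2)),
           other  = (; weight = rand(2, 2), bias = rand(2)),
           decode = (; weight = rand(2, 2)))

f = fun_rule_noidset(model, :muon, :adam)
f(model.encode.weight)        # :adam, skipped by identity
f(copy(model.decode.weight))  # :muon, equal values but a different object
```

The scan is linear in the number of skipped arrays, which should be fine for the handful of embedding/decoder weights involved.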

AntonOresten · Oct 22 '25 12:10