Add `setup(::Function, model)`
Quick sketch of one way to allow different rules for different arrays, by modifying `setup` -- see the docstring.
PR Checklist
- [ ] Tests are added
- [ ] Documentation, if applicable
I think this is elegant and useful. I'm working on some improvements to #203. Muon is optimal for linear layers, but makes less sense for e.g. `Flux.Embedding`, even though it is linear-like; and since the linear decoder layer in LLMs is often tied to the input encoder layer, it's preferable to disable Muon for that layer as well. I imagine the cleanest way of differentiating between linear layers is with an identity-based collection (an `IdDict` or `IdSet`) inside the rule-selecting function: you'd create the function based on which layers are present, and in the same function embed rules for different array shapes.
One could do something like:
```julia
function fun_rule(model, rule=Muon(), fallback=Adam())
    # Identity-based set: membership is tested with `===`, so only these
    # exact arrays (including tied copies of them) are skipped.
    skipped = Base.IdSet{Any}([model.encode.weight, model.decode.weight])
    fun(x::AbstractVector) = fallback  # vectors (biases etc.) get the fallback
    fun(x::AbstractArray) = x in skipped ? fallback : rule
    return fun
end

opt_state = Optimisers.setup(fun_rule(model), model)
```
such that:
```julia
julia> model = (;
           encode=(; weight=rand(2,2)),
           other=(; weight=rand(2,2), bias=rand(2)),
           decode=(; weight=rand(2,2)));

julia> fun_rule(model)(model.encode.weight)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)

julia> fun_rule(model)(model.other.weight)
Muon(0.02, 0.95, 0.01, 1.0e-7, true)

julia> fun_rule(model)(model.other.bias)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)

julia> fun_rule(model)(model.decode.weight)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)
```
I generally avoid closures, but this has a certain elegance to it. `Base.IdSet` is private, but the alternative is slightly cursed:

```julia
skipped = keys(IdDict([model.encode.weight, model.decode.weight] .=> nothing))
```
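For context, the reason to reach for `Base.IdSet` (or `IdDict` keys) rather than a plain `Set` is that membership is tested by object identity (`===`), not by value (`==`): an array with equal contents but a different identity is not considered skipped, while the exact tied array always is. A minimal sketch (variable names here are illustrative):

```julia
# Two arrays with equal contents are `==` but not `===`.
w  = [1.0 2.0; 3.0 4.0]
w2 = copy(w)  # same values, different object

skipped = Base.IdSet{Any}([w])

@show w2 == w        # true:  equal by value
@show w2 in skipped  # false: IdSet compares with ===
@show w in skipped   # true:  the exact same object
```

A value-based `Set` would wrongly match `w2` here, which matters when distinct layers happen to be initialised with identical values.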