Pass the model initializer into `Trainer.__init__`
I think it could be helpful to add a `model_initializer: Optional[Callable[[torch.nn.Module], None]]` argument to the trainer.
If this argument is provided, the trainer would call `model.apply(model_initializer)` after the random seed is set, after deterministic mode is configured, and after `Event.INIT` fires (i.e. after surgery occurs), but before checkpoints are loaded.
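To make the proposed ordering concrete, here is a minimal sketch of where the hook would run. It assumes a stripped-down trainer: `MiniTrainer` and its `surgery` and `checkpoint` arguments are hypothetical stand-ins (not Composer's actual `Trainer` signature), and seeding and deterministic mode use plain PyTorch calls rather than Composer's reproducibility utilities.

```python
from typing import Callable, Dict, Optional

import torch
import torch.nn as nn

Initializer = Callable[[nn.Module], None]


class MiniTrainer:
    """Hypothetical trainer showing where `model_initializer` would run."""

    def __init__(
        self,
        model: nn.Module,
        seed: int = 42,
        deterministic_mode: bool = False,
        surgery: Optional[Callable[[nn.Module], None]] = None,
        model_initializer: Optional[Initializer] = None,
        checkpoint: Optional[Dict[str, torch.Tensor]] = None,
    ):
        # 1. Seed the global RNG so the initializer's random draws are reproducible.
        torch.manual_seed(seed)
        # 2. Optionally restrict PyTorch to deterministic kernels.
        if deterministic_mode:
            torch.use_deterministic_algorithms(True)
        # 3. Stand-in for Event.INIT: algorithms may swap layers in via surgery.
        if surgery is not None:
            surgery(model)
        # 4. Initialize every submodule, including layers added by surgery.
        if model_initializer is not None:
            model.apply(model_initializer)
        # 5. Load checkpoint weights last, so they override the initializer.
        if checkpoint is not None:
            model.load_state_dict(checkpoint)
        self.model = model
```

Because step 5 runs after step 4, a checkpoint fully overwrites whatever the initializer produced.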
The advantages of initializing the model here (rather than relying on the user to initialize the model before passing it into the trainer) are that:
- the initialization would be deterministic w.r.t. the random seed (so the user no longer needs to call `composer.utils.reproducibility.set_random_seed()` or `composer.utils.reproducibility.configure_deterministic_mode()` before creating the model and constructing the trainer), and
- model layers replaced via surgery would be initialized.
This would not affect checkpoints, since checkpoint weights would override any initialized weights.
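The determinism argument can be checked with plain PyTorch: even if the model is built before any seed is set, re-initializing it after seeding makes the final weights depend only on the seed. `init_linear` and `build_and_init` are illustrative names, not part of Composer.

```python
import torch
import torch.nn as nn


def init_linear(module: nn.Module) -> None:
    # Example scheme: normal(0, 0.02) weights and zero biases for Linear layers.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=0.02)
        nn.init.zeros_(module.bias)


def build_and_init(seed: int) -> nn.Module:
    # Construction draws from the RNG before any seed is set...
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    # ...but re-initializing after seeding overwrites every parameter deterministically.
    torch.manual_seed(seed)
    model.apply(init_linear)
    return model
```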
Thoughts?
cc: @jbloxham @ajaysaini725 @anisehsani @hanlint @A-Jacobson
I see what you're getting at, and I believe this would remove the need to explicitly set initializers or wrap models so they take initializers on init, but I'm not in love with it for a few reasons.
- Our current initializers aren't general. They're most relevant to the original ResNet family, but even then they don't cover all bases. Layer-specific initializations are sort of general (there's a reasonable combination of these), but full initialization schemes, to me, are often model-specific.
- The interface to our initializers is not something I'd want to call directly. It's not useful without looking at the code (to see the settings and which layer types each one applies to), and it's awkward to use (a string enum that returns functions).
- From a design/OOP perspective it makes sense for a model to have its own initialization scheme; that's also been the PyTorch convention. That said, I'll admit it's frustrating how many people rely on the defaults set by PyTorch layers (often without knowing what they are) rather than explicitly setting their own (explicit > implicit).
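One way to read the point about layer-specific initializations: they compose, while whole-model schemes stay model-specific. A sketch, assuming hypothetical helpers `for_layers` and `compose` (not Composer's initializer API); each piece names the layer types it touches, so the scope is visible at the call site rather than hidden behind a string enum. The conv/norm choices are similar to what torchvision's ResNet applies.

```python
from typing import Callable, Tuple, Type

import torch.nn as nn

Initializer = Callable[[nn.Module], None]


def for_layers(layer_types: Tuple[Type[nn.Module], ...], fn: Initializer) -> Initializer:
    """Restrict `fn` to the given layer types, making its scope explicit."""
    def init(module: nn.Module) -> None:
        if isinstance(module, layer_types):
            fn(module)
    return init


def compose(*inits: Initializer) -> Initializer:
    """Combine several layer-specific initializers into one scheme."""
    def init(module: nn.Module) -> None:
        for fn in inits:
            fn(module)
    return init


def kaiming_conv(module: nn.Module) -> None:
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")


def unit_norm(module: nn.Module) -> None:
    nn.init.ones_(module.weight)
    nn.init.zeros_(module.bias)


# A ResNet-style scheme assembled from layer-specific pieces.
resnet_like_init = compose(
    for_layers((nn.Conv2d,), kaiming_conv),
    for_layers((nn.BatchNorm2d, nn.GroupNorm), unit_norm),
)
```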
> - the initializing would be deterministic w.r.t the random seed (so the user no longer needs to call `composer.utils.reproducibility.set_random_seed()` or `composer.utils.reproducibility.configure_deterministic_mode()` before creating the model and constructing the trainer)
> - model layers replaced via surgery would be initialized.

These don't seem like benefits to me, because we've now created a scenario where reproducibility relies on an (optional) argument. Calling those functions should be standard practice for reproducibility, and layers replaced by model surgery could/should have their own initializations set in their respective algorithms.
TL;DR: I'm not happy with the current initializers, so I'd rather not double down on them. And having one clearly explained way to get deterministic behavior is better than having "or do this" interactions.
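For the surgery case, the alternative proposed here is that the algorithm initializes the layers it introduces. A minimal sketch, with hypothetical `XavierLinear` and `replace_linears` (not Composer code); it relies on `nn.Linear.__init__` calling `reset_parameters()`, so the subclass override runs on construction.

```python
import torch.nn as nn


class XavierLinear(nn.Linear):
    """Hypothetical replacement layer that owns its initialization."""

    def reset_parameters(self) -> None:
        # Called from nn.Linear.__init__, so every new instance is initialized here.
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)


def replace_linears(model: nn.Module) -> int:
    """Swap each plain nn.Linear for an XavierLinear; return how many were replaced."""
    count = 0
    for name, child in list(model.named_children()):
        if type(child) is nn.Linear:
            new = XavierLinear(child.in_features, child.out_features, bias=child.bias is not None)
            setattr(model, name, new)
            count += 1
        else:
            count += replace_linears(child)  # recurse into containers
    return count
```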
As an aside, I do agree that the conventional PyTorch model-defined initialization schemes aren't perfect. They are often not obvious (implicitly set, as I mentioned) and they aren't flexible. Here's timm's ViT weight-init scheme as a case study: there are different implementations of ViT in different frameworks, and to even try to replicate them all accurately, timm's author needs a mess of if statements to switch between them. He also initializes (architecture-specific) layers in different ways.
```python
import math
from functools import partial

import torch.nn as nn
# Import paths as in timm ~0.5; they may differ in other releases.
from timm.models.helpers import named_apply
from timm.models.layers import lecun_normal_, trunc_normal_


# Method of timm's VisionTransformer (excerpt).
def init_weights(self, mode=''):
    assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
    head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
    trunc_normal_(self.pos_embed, std=.02)
    if self.dist_token is not None:
        trunc_normal_(self.dist_token, std=.02)
    if mode.startswith('jax'):
        # leave cls token as zeros to match jax impl
        named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
    else:
        trunc_normal_(self.cls_token, std=.02)
        self.apply(_init_vit_weights)


def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
    """ ViT weight initialization
    * When called without n, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        elif name.startswith('pre_logits'):
            lecun_normal_(module.weight)
            nn.init.zeros_(module.bias)
        else:
            if jax_impl:
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
                        nn.init.normal_(module.bias, std=1e-6)
                    else:
                        nn.init.zeros_(module.bias)
            else:
                trunc_normal_(module.weight, std=.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    elif jax_impl and isinstance(module, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)
```
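The `jax` branch relies on `named_apply` passing each submodule's dotted name, which is what lets `name.startswith('head')` single out the classifier. A minimal stand-in, assuming that calling convention: `apply_with_names` and `init_by_name` are hypothetical, not timm's implementation, and timm's helper may visit modules in a different order, which doesn't matter here because each call only touches its own module's parameters.

```python
from typing import Callable

import torch.nn as nn


def apply_with_names(fn: Callable[..., None], model: nn.Module) -> nn.Module:
    """Call fn(module=..., name=...) on every submodule, passing its dotted name."""
    for name, module in model.named_modules():
        if name:  # skip the root module, whose name is ''
            fn(module=module, name=name)
    return model


def init_by_name(module: nn.Module, name: str = '') -> None:
    # Simplified name-based dispatch: zero the head, Xavier everywhere else.
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.zeros_(module.bias)
        else:
            nn.init.xavier_uniform_(module.weight)
```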
> I see what you're getting at and I believe this would alleviate having to explicitly set or wrap models to take initializers on init, but I'm not in love with it for a few reasons.
>
> - our current initializers aren't general. They're most relevant to the original resnet family, but even then they don't cover all bases. Layer specific initializations are sort of general (there's a reasonable combination of these) but full initialization schemes, to me, are often model specific.
> - the interface to our initializers is not something I'd want to call directly. it's not useful without looking at the code (to see setting and which layer types it applies to) and it's awkward to use (a string enum that returns functions).

If I understand correctly, I think these are limitations with the yahp codepath, since yahp requires initialization to be an enum. The proposal here would be to pass in a function rather than an enum, so it would be general for all use cases.
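Passing a callable instead of an enum also turns hyperparameters into ordinary arguments. A sketch, assuming a hypothetical `initialize` helper that does what the proposed trainer argument would do (`model.apply`); `trunc_normal_linear` is likewise illustrative.

```python
from functools import partial
from typing import Callable

import torch.nn as nn


def trunc_normal_linear(module: nn.Module, std: float = 0.02) -> None:
    """Illustrative initializer: truncated-normal Linear weights, zero biases."""
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def initialize(model: nn.Module, model_initializer: Callable[[nn.Module], None]) -> nn.Module:
    """Hypothetical stand-in for the proposed trainer argument."""
    # Any callable works here: a plain function, a functools.partial, or a lambda.
    return model.apply(model_initializer)


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
initialize(model, partial(trunc_normal_linear, std=0.01))
```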
> - from a design/oop perspective it makes sense for a Model to have an initialization scheme, that's also been pytorch convention. though I'll admit it's frustrating how many people rely on the defaults set by pytorch layers (often without knowing what they are) rather than explicitly setting their own. (explicit > implicit).

Yes, it would also work to put the initializer on the `ComposerModel` itself, e.g. something like this?
```python
class ComposerModel:
    ...

    def initialize_model(self):
        """Called by the trainer to initialize model parameters"""
        # User can optionally implement
        pass
```
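A sketch of how that hook could be used, assuming a hypothetical `ComposerModelSketch` base class (not Composer's actual `ComposerModel`) and a hypothetical `trainer_init` that seeds before calling the hook. The subclass keeps its model-specific scheme next to the architecture it targets.

```python
import torch
import torch.nn as nn


class ComposerModelSketch(nn.Module):
    """Hypothetical base class with an optional init hook; not Composer's real ComposerModel."""

    def initialize_model(self) -> None:
        """Called by the trainer to initialize model parameters; a no-op unless overridden."""


class TinyClassifier(ComposerModelSketch):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.backbone = nn.Linear(16, 32)
        self.head = nn.Linear(32, num_classes)

    def initialize_model(self) -> None:
        # Model-specific scheme, kept next to the architecture it targets.
        nn.init.kaiming_normal_(self.backbone.weight, nonlinearity="relu")
        nn.init.zeros_(self.backbone.bias)
        nn.init.zeros_(self.head.weight)
        nn.init.zeros_(self.head.bias)


def trainer_init(model: ComposerModelSketch, seed: int = 0) -> None:
    """Hypothetical trainer step: seed first, then let the model initialize itself."""
    torch.manual_seed(seed)
    model.initialize_model()
```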
> > - the initializing would be deterministic w.r.t the random seed (so the user no longer needs to call `composer.utils.reproducibility.set_random_seed()` or `composer.utils.reproducibility.configure_deterministic_mode()` before creating the model and constructing the trainer)
> > - model layers replaced via surgery would be initialized.
>
> These don't seem like benefits to me because we've now created a scenario where we're reliant on an (optional) argument for reproducibility. Calling those functions should be standard practice for reproducibility, and layers replaced by model surgery could/should have their own initializations set in their respective algorithms.
>
> TLDR: not happy with the current initialization so I'd rather not double down on it + having one clearly explained way to get deterministic behavior is better than having "or do this" interactions.

Alternatively, we could make the recommended approach be for the user to call the `composer.utils.reproducibility` functions themselves, and then remove these arguments from the trainer's `__init__`. However, I think this would make it slightly less intuitive, as the user would now need another import and would need to know where to call these functions.
The interface for `ComposerModel` is left to the user, but I do like the idea of passing initializers to the model as Python functions. The upside is that they're swappable again. The downside is that they aren't where people expect them to be defined (on the PyTorch model). Though if this is a user-defined function on an optional, user-defined API, do we actually need to do anything other than set a convention or use it internally?
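Passing initializers to the model itself could look like the following sketch; `MLP` and its `initializer` argument are hypothetical. The scheme is swappable per instance, and `None` leaves PyTorch's per-layer defaults in place.

```python
from typing import Callable, Optional

import torch
import torch.nn as nn

Initializer = Callable[[nn.Module], None]


class MLP(nn.Module):
    """Hypothetical model that takes a swappable initializer at construction time."""

    def __init__(self, width: int = 16, initializer: Optional[Initializer] = None):
        super().__init__()
        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, 1)
        # None keeps PyTorch's per-layer defaults; otherwise the given scheme overrides them.
        if initializer is not None:
            self.apply(initializer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))
```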
> Alternatively, we can have the recommended way be to have the user call `composer.utils.reproducibility`, and then remove these arguments from the trainer's `__init__`. However, I think this would make it slightly less intuitive, as now a user would need another import and know where to call these functions.

Forgot to address this. I'd like to keep it on the trainer init too if possible (I don't like the additional import), but only if it actually does what it says it does.
Closing because we don't plan on adding this