
Improve Error Message: Assigning a list of Modules to `self` incorrectly

Open marcvanzee opened this issue 5 years ago • 15 comments

In setup(), submodules are expected to "know their name" when they are assigned to self, which is implemented by overriding __setattr__. This can cause problems when appending modules to a list. Consider the following code.

```python
from flax import linen as nn

class Test(nn.Module):
  def setup(self):
    self.layers = []
    for i in range(3):
      self.layers.append(nn.Dense(5))

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)  # <--- ValueError: Can't call methods on orphaned modules
    return x
```

The code fails because self.layers.append() does not assign anything to self, so the overridden __setattr__ is never triggered. As a result, the Dense modules are never properly registered with their parent and are treated as "orphaned" (they have no parent scope). The canonical way to fix this code is:

```python
def setup(self):
  self.layers = [nn.Dense(5) for _ in range(3)]
```
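Why the append() version slips through can be reproduced in plain Python, without Flax. The sketch below uses a hypothetical `Recorder` class (a toy, not Linen's actual implementation) that logs every name passing through `__setattr__`; the in-place `append` calls never appear in the log.

```python
class Recorder:
  """Toy stand-in for a Module that logs every attribute assignment."""

  def __init__(self):
    # Bypass our own hook so the log itself is not recorded.
    object.__setattr__(self, "assigned", [])

  def __setattr__(self, name, value):
    # Per the issue, Linen overrides __setattr__ to name/register submodules;
    # this toy only records which names were assigned.
    self.assigned.append(name)
    object.__setattr__(self, name, value)


r = Recorder()
r.layers = []             # goes through __setattr__: recorded
r.layers.append("dense")  # in-place mutation: __setattr__ never runs
r.layers.append("dense")
print(r.assigned)  # ['layers']
```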

Requiring users to understand this distinction adds mental overhead, especially since the error message is cryptic.

Some suggestions for improvements:

  1. Improve the error message by explaining how modules should be assigned to self in setup.

  2. Create a new abstraction ModuleList, and only allow lists of Modules here (similar to PyTorch).

  3. Convert the list to some data structure that tracks mutation properly.

Option (1) is a quick fix, but it doesn't seem like a great long-term solution: this error can appear in other situations as well, so the error message would have to cover those cases too.

Option (2) adds little mental overhead, but users have to remember that this abstraction exists.

Option (3) would also work, but I fear that this kind of magic may introduce new special cases that we would have to address in the future.
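A minimal sketch of what option (3) could look like, assuming a hypothetical `TrackedList` that reports appends to a callback (here the callback just records names in a `layers_<i>` style). It also illustrates the special-case worry: only `append` is hooked, so `extend` silently bypasses tracking, and a real version would have to guard every mutating method.

```python
class TrackedList(list):
  """Toy list that reports each appended item to a callback (illustration only)."""

  def __init__(self, on_add):
    super().__init__()
    self._on_add = on_add

  def append(self, item):
    # Report the index the item is about to occupy, then store it.
    self._on_add(len(self), item)
    super().append(item)


registered = []
layers = TrackedList(lambda i, _item: registered.append(f"layers_{i}"))
for _ in range(3):
  layers.append(object())  # stands in for nn.Dense(5)
print(registered)  # ['layers_0', 'layers_1', 'layers_2']

layers.extend([object()])  # extend is not overridden, so this slips past tracking
print(len(layers), len(registered))  # 4 3
```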

I personally feel option (2) is the most consistent and robust in the long term. Since it doesn't have to be used in a particular way, it also seems easiest for users to understand.

marcvanzee • Oct 09 '20

I am sympathetic to option (2), but here are a few questions:

  1. Could/should we raise an error if you assign a "vanilla" list of Modules?
  2. What happens if you want to store a list that has both Modules and other things (e.g. the relu function, which is a common need for some Sequential patterns)?

Also curious to hear from @levskaya who wrote the original "module list logic" in Modules.

avital • Oct 22 '20

I just came across the (lack of) these module containers and searched my way here. With the setup() approach, they're definitely useful for more than just the setattr issue mentioned here.

As a PyTorch user, I find it useful to have both Sequential and ModuleList: Sequential has a call that by default iterates over the contained modules in order, saving the extra code in __call__, while ModuleList requires manual iteration by iterator or index. On occasion I also use ModuleDict, but definitely less than Sequential and ModuleList (in that order).

In PT, one can specify the module names in a Sequential container by passing (name, module) pairs via an OrderedDict. Otherwise, a sequence of numeric names starting from '0' is generated for you.

Being able to place lambdas/functions in the containers, especially Sequential, is important when you don't have module equivalents for activations.
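The PyTorch naming conventions described above (auto-generated names '0', '1', ... versus user-supplied names from an OrderedDict, with plain functions allowed as entries) can be mimicked with a small container. This is a hypothetical plain-Python toy written for illustration, not PyTorch or Flax code; callables on floats stand in for layers.

```python
from collections import OrderedDict


class Sequential:
  """Toy container mimicking PyTorch-style Sequential naming (illustration only)."""

  def __init__(self, *layers):
    if len(layers) == 1 and isinstance(layers[0], OrderedDict):
      # Explicit (name, layer) pairs keep the user's names.
      self.named = OrderedDict(layers[0])
    else:
      # Otherwise names are generated as '0', '1', ...
      self.named = OrderedDict((str(i), f) for i, f in enumerate(layers))

  def __call__(self, x):
    # Default call: apply each entry in insertion order.
    for layer in self.named.values():
      x = layer(x)
    return x


def relu(v):
  return max(v, 0.0)  # a plain function, not a module


auto = Sequential(lambda v: v * 2.0, relu, lambda v: v - 1.0)
named = Sequential(OrderedDict([("scale", lambda v: v * 2.0), ("act", relu)]))
print(list(auto.named), auto(-3.0))    # ['0', '1', '2'] -1.0
print(list(named.named), named(4.0))  # ['scale', 'act'] 8.0
```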

rwightman • Oct 25 '20

My high-level view: ModuleLists are much less useful here than in PyTorch, since, unlike PyTorch, we really don't want to be mucking about with submodule state after setup: we need to guarantee certain invariant functional behavior, and we use lazy initialization everywhere. I'm pretty sure supporting mutable ModuleLists is just going to open up a new nightmare dimension of bugs... and if you can't use them outside setup(), I'm not sure there's much point. I might be overly paranoid here, so feel free to push back, but the current approach follows the generally immutable design we take elsewhere in flax/linen.

@avital I'm definitely not a fan of forcing a ModuleList container, and I lean against (2) and (3) for the above reasons. I prefer (1), where we make it clearer in error messages, etc. that these structures in self assignments are immutable by design. We could support a very limited form of (3) safely inside setup(), but it's a lot of code for what seems like little benefit, though I could be convinced otherwise.

Regarding your question 2: you can assign mixed functions/Modules in setattr; only the Modules inside the pytree are treated specially and registered. Otherwise it's just a list or dict (or whatever) containing anything you'd like.

@rwightman - a pleasure to see you here! pytorch-image-models is a wonderful piece of work!

Below are two code sketches showing the two ways we'd approach setting up and using a sequential list of layers; both approaches would work fine with functions mixed in as well.

About the only thing we do now that might be annoying is that we force a particular naming scheme on lists, etc. assigned to, e.g., self.foo: the names are forced to be `foo_<tree_path>`, so for lists `foo_0`, `foo_1`, etc. We could relax that to just being the defaults if this is a pattern people want to use with custom names... (Part of the reason for forcing a canonical naming scheme is to keep all submodule names in a canonical relationship with the attribute names on self, but I'd need to think a bit about whether there are really strong benefits to doing that.)

Let me know if this illuminates anything about the current approach or raises other questions or complaints. ;) Also if you're looking around the code, always feel free to reach out to us directly with questions! We're trying to get a lot of docs and commentary out for the newer API soon, so apologies for the bare state of things at the moment - we've been focused on trials and refining things with a bunch of early users.

```python
from jax import random, numpy as jnp
from flax import linen as nn
from typing import List

class Example1(nn.Module):
  feature_sizes: List[int]
  def setup(self):
    self.layers = [nn.Dense(sz)
                   for idx, sz in enumerate(self.feature_sizes)]
  def __call__(self, x):
    for lyr in self.layers:
      x = lyr(x)
    return x

# NB: this only works inside Modules, not as a top-level Module
#     due to a restriction we can probably remove soon.
class Sequential(nn.Module):
  layers: List[nn.Module]
  @nn.compact
  def __call__(self, x):
    for lyr in self.layers:
      x = lyr(x)
    return x

class Example2(nn.Module):
  feature_sizes: List[int]
  @nn.compact
  def __call__(self, x):
    return Sequential([nn.Dense(sz, name=f'layers_{idx}')
                       for idx, sz in enumerate(self.feature_sizes)])(x)

x = jnp.ones((10,))
key = random.PRNGKey(0)

mdl1 = Example1([3, 5, 2])
mdl2 = Example2([3, 5, 2])

variables1 = mdl1.init(key, x)
y1 = mdl1.apply(variables1, x)

variables2 = mdl2.init(key, x)
y2 = mdl2.apply(variables2, x)

jnp.all(y1 == y2)  # True
```
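The `foo_<tree_path>` naming rule described above, together with the point that non-Module entries in the container are simply left alone, can be sketched as a small tree walk. This is an illustration only: `Module` is a hypothetical placeholder class, and joining nested path components with `_` is an assumption that goes beyond the list case stated in the comment.

```python
from typing import Any, Dict


class Module:
  """Hypothetical placeholder for a Linen Module; only its type matters here."""


def submodule_names(attr: str, tree: Any, path: str = "") -> Dict[str, Module]:
  """Names every Module in a nested list/tuple/dict after its tree path.

  Leaves that are not Modules (e.g. a plain activation function) are skipped.
  """
  if isinstance(tree, Module):
    return {f"{attr}{path}": tree}
  if isinstance(tree, dict):
    children = tree.items()
  elif isinstance(tree, (list, tuple)):
    children = enumerate(tree)
  else:
    return {}  # any other leaf is left untouched and gets no name
  names: Dict[str, Module] = {}
  for key, child in children:
    # Assumption: nested path components are joined with '_'.
    names.update(submodule_names(attr, child, f"{path}_{key}"))
  return names


dense_a, dense_b = Module(), Module()
print(sorted(submodule_names("foo", [dense_a, abs, dense_b])))
# ['foo_0', 'foo_2']
print(sorted(submodule_names("blocks", {"stem": dense_a, "body": [dense_b]})))
# ['blocks_body_0', 'blocks_stem']
```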

levskaya • Oct 25 '20

@levskaya thanks! I'm currently exploring JAX via Flax (Linen) and Objax, and still building up my mental model of how everything works by implementing my MBConvNet generator (EfficientNet, MixNet, MobileNet, etc.) in a way that Flax Linen setup(), Flax Linen compact, or Objax can be plugged into the same model config / layer builder. So far it's going well, and I have models running validation with weights ported from either TF or PyTorch. Hopefully I'll have something to share soon.

At this stage, when I run into issues, I tend to fall back to concepts I'm familiar with (PT). I wasn't thinking about mutability. The Sequential approach mentioned above seems like it'd work. I do like having fairly fine-grained control over naming; aside from the OCD aspect, it can be quite useful for keeping naming schemes similar across frameworks, e.g. for weight portability.

I ran into this issue after first implementing the network with the compact decorator, since that approach was better documented and featured in more examples. When I started fiddling with the setup() variant, I initially tried using append to fill out my stage/block repeats, where I have a list of stages, each containing a list of blocks. In PT and Objax I implement those as Sequential of Sequential containers. I don't flatten them, since feature map sizes change at stage boundaries and it's useful to keep that structure for feature extraction.
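The stage/block layout described here can be illustrated without any framework. A minimal sketch, assuming blocks are plain callables on floats (hypothetical stand-ins for MBConv blocks): keeping stages nested instead of flattening them makes it straightforward to collect each stage's output for feature extraction.

```python
from typing import Callable, List, Sequence, Tuple

Block = Callable[[float], float]


def run_stages(stages: Sequence[Sequence[Block]],
               x: float) -> Tuple[float, List[float]]:
  """Applies a 'Sequential of Sequentials', keeping each stage's output."""
  features = []
  for blocks in stages:
    for block in blocks:
      x = block(x)
    features.append(x)  # stage boundary: where feature map size changes in a CNN
  return x, features


stages = [
    [lambda v: v + 1.0, lambda v: v * 2.0],  # stage 0: two blocks
    [lambda v: v * 10.0],                    # stage 1: one block
]
out, feats = run_stages(stages, 1.0)
print(out, feats)  # 40.0 [4.0, 40.0]
```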

rwightman • Oct 25 '20

@levskaya I agree that we don't need a mutable ModuleList container (however you want to construct your module list, you can do so in setup code and then make the final list immutable when assigning it to an attribute). But raising an error in a case like Example1 below would be much more instructive than the current ValueError. Note that if we make the conversion to ModuleList implicit, users can still trigger the confusing error, as in the (hopefully less frequent) Example2 below.

```python
from typing import List
from flax import linen as nn

class Example1(nn.Module):
  feature_sizes: List[int]
  final_size: int
  def setup(self):
    self.layers = [nn.Dense(sz) for sz in self.feature_sizes]
    self.layers.append(nn.Dense(self.final_size))  # Raises some ImmutableModuleList error

class Example2(nn.Module):
  feature_sizes: List[int]
  final_size: int
  def setup(self):
    layers = [nn.Dense(sz) for sz in self.feature_sizes]
    self.layers = layers
    layers.append(nn.Dense(self.final_size))  # Would still fail with ValueError in __call__()?
```
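The difference between the two examples can be seen in a plain-Python toy model (a hypothetical `FreezingHolder`, not Flax code) in which lists assigned to `self` are converted to tuples, in the spirit of the freezing idea discussed in this thread. In this toy, Example1 fails loudly at the faulty line, while Example2 raises nothing at all: the final layer is silently dropped, so that case would still need its own diagnostic.

```python
class FreezingHolder:
  """Toy model of freeze-on-assign: lists stored on self become tuples."""

  def __setattr__(self, name, value):
    if isinstance(value, list):
      value = tuple(value)  # snapshot; later edits to the original list are not seen
    object.__setattr__(self, name, value)


# Example1: mutating through self fails immediately, at the line with the mistake.
h = FreezingHolder()
h.layers = ["dense_a", "dense_b"]
try:
  h.layers.append("dense_final")
  example1_raised = False
except AttributeError:  # 'tuple' object has no attribute 'append'
  example1_raised = True

# Example2: mutating a local alias raises nothing, but self never sees the change.
layers = ["dense_a", "dense_b"]
h.layers = layers
layers.append("dense_final")

print(example1_raised, len(layers), len(h.layers))  # True 3 2
```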

andsteing • Oct 26 '20

Thanks for the explanation @levskaya! Two questions:

> I'm pretty sure supporting mutable ModuleLists is just going to open up a new nightmare dimension of bugs

Why can't we introduce a ModuleList abstraction that is immutable? At least we would then have an abstraction with consistent behavior. I think the main problem is that we currently use plain Python lists, and it is not obvious how they should be used.

> I prefer 1 where we make it clearer in error messages, etc. these structures in self assignments are immutable by design.

By "these structures", do you mean lists containing Modules, or are there other structures you are referring to? I'm trying to understand how we can improve the error message, since it currently triggers in a somewhat odd place, and it would be better to catch the problem somewhere else.

marcvanzee • Oct 26 '20

@marcvanzee - curious, though: what is the utility of an immutable ModuleList? I thought the main point of them in PyTorch was capturing and processing updates like append, etc. Is it just a trivial wrapper to make things more familiar to PT users? If we mirror their API, the lack of mutability would be a big annoyance.

Re "these structures": we allow assigning any container (dict, list, struct.dataclass, any flax-registered container) containing submodules to self. The correct immutability guards certainly aren't set up yet, and we could definitely improve the error messages. I'm not 100% committed to this generality, though I think it could be handy for some advanced things in the future; in the wild, by far the most common use is just lists.

levskaya • Oct 27 '20

@andsteing - re (1): if we stick with the current behavior, at a minimum we should certainly raise an Immutable error of some sort. For (2) there's a real question whether we should allow setattr at all inside other functions. We could allow it inside "compact" functions and probably have everything make sense, but for modules with multiple "public" methods we should restrict setattr to setup(), just as we do for variables in general.
My general worry with setattr and anything dynamic is that users could easily forget that we're not actually working with durable Python state once things are jitted or pmap'd.

levskaya • Oct 27 '20

@levskaya: I don't have much experience with PyTorch's ModuleList, but I agree that if the behavior is different in Flax, it will probably only confuse users. So in that case, I think improving the error message seems like a good next step.

> For (2) there's a real question whether we should allow setattr at all inside other functions.

Can we generally "forbid" this? There seem to be some edge cases, like the one below. Do we ignore them and hope users don't use local variables to modify structures in self, or do you think there's a way to catch them?

```python
import inspect

class SimpleModule:
  def __setattr__(self, name, val):
    # inspect.stack()[1][3] is the name of the function doing the assignment.
    current_fn = inspect.stack()[1][3]
    if current_fn != 'setup' and current_fn != '__init__':
      raise ValueError('Can only assign to `self` in setup')
    object.__setattr__(self, name, val)

  def __init__(self):
    self.test = [0, 1, 2]
    self.foo()

  def foo(self):
    x = self.test
    x.append(3)   # No error is raised.
```
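One possible way to catch both kinds of misuse, sketched in plain Python under stated assumptions: replace the stack inspection with a hypothetical `_in_setup` flag set around the `setup()` call, and freeze lists to tuples on assignment, so mutating `self.test` through a local alias raises instead of silently succeeding. This is not how Flax implements it; it only shows that the example above can be caught.

```python
class GuardedModule:
  """Toy sketch: assignment to self is allowed only while setup() runs,
  and lists are frozen to tuples on assignment."""

  def __init__(self):
    object.__setattr__(self, "_in_setup", True)
    try:
      self.setup()
    finally:
      object.__setattr__(self, "_in_setup", False)

  def __setattr__(self, name, val):
    if not self._in_setup:
      raise AttributeError(
          f"Cannot assign {name!r} outside setup(); assign submodules and "
          "containers of submodules to self inside setup() instead.")
    if isinstance(val, list):
      val = tuple(val)  # makes `x = self.test; x.append(...)` fail loudly
    object.__setattr__(self, name, val)

  def setup(self):
    self.test = [0, 1, 2]

  def foo(self):
    x = self.test
    x.append(3)  # AttributeError: self.test is now a tuple


m = GuardedModule()
errors = []
for action in (m.foo, lambda: setattr(m, "extra", 1)):
  try:
    action()
  except AttributeError as err:
    errors.append(str(err))
print(len(errors), m.test)  # 2 (0, 1, 2)
```

A flag also avoids a pitfall of checking the caller's function name: helper methods called from setup() can still assign. Freezing does not help when a local list is mutated after it has been assigned (as in Example2 above), because the stored tuple is a snapshot.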

marcvanzee • Oct 27 '20

@levskaya since this issue came up again recently for Kevin Murphy, and it is quite an annoying sharp edge of Linen, I've increased the priority from P2 to P1.

Also, I think you mentioned in the chat that a good short-term solution would be to at least freeze the list/dict/... assigned to self so we can present a better error message. Do you think this is feasible, or would it be a lot of work?

marcvanzee • Jan 15 '21

Actually, the recent Frozen modules PR https://github.com/google/flax/pull/823 turns lists and dicts into tuples and frozen dicts. Not sure about a custom error message, though.

jheek • Jan 15 '21

Actually, we could of course make a ModuleDict and ModuleList that are mutable during setup and frozen afterwards (and track the modules correctly).
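One way to get the "mutable during setup, frozen afterwards" behavior is a container built on `collections.abc.MutableSequence`, where every mutating method funnels through a few primitives that can be guarded in one place. A minimal sketch, assuming a hypothetical `ModuleList` with a `freeze()` hook that the framework would call after `setup()`; it covers only the freezing part, not tracking or naming the submodules, and strings stand in for `nn.Dense` instances.

```python
from collections.abc import MutableSequence


class ModuleList(MutableSequence):
  """Hypothetical ModuleList: mutable until freeze(), then read-only."""

  def __init__(self, items=()):
    self._items = list(items)
    self._frozen = False

  def freeze(self):
    # A framework would call this once setup() returns.
    self._frozen = True
    return self

  def _check_mutable(self):
    if self._frozen:
      raise TypeError(
          "This ModuleList was frozen when setup() finished; build the "
          "complete list of submodules inside setup().")

  # These five methods are all MutableSequence requires; append, extend,
  # pop, remove and += are derived from them, so they are all guarded.
  def __getitem__(self, index):
    return self._items[index]

  def __len__(self):
    return len(self._items)

  def __setitem__(self, index, value):
    self._check_mutable()
    self._items[index] = value

  def __delitem__(self, index):
    self._check_mutable()
    del self._items[index]

  def insert(self, index, value):
    self._check_mutable()
    self._items.insert(index, value)


layers = ModuleList()
for width in (3, 5):
  layers.append(f"Dense({width})")  # allowed: still "inside setup()"
layers.freeze()
try:
  layers.append("Dense(2)")
  rejected = False
except TypeError:
  rejected = True
print(list(layers), rejected)  # ['Dense(3)', 'Dense(5)'] True
```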

jheek • Jan 15 '21

Today @salayatana66 ran into this issue again, and they did indeed hit the new error message referring to frozen dicts. Still, I think it took them a while to understand it, so improving the error message is definitely still a high priority.

marcvanzee • Feb 04 '21

After reading the whole thread, I have an alternative proposal:

  • No ModuleList/Dict abstraction; users just use regular Python lists/dicts.
  • The list -> tuple and dict -> FrozenDict conversion happens after setup finishes, so the user can mutate the lists/dicts inside setup.
  • __setattr__ only keeps track of the set of fields that need to be frozen instead of freezing them on the spot. Alternatively, just iterate over all fields after setup and check which ones need to be frozen.
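The proposal can be sketched in plain Python. In this toy version (a hypothetical `DeferredFreezeModule`, with strings standing in for submodules and `MappingProxyType` standing in for FrozenDict), `setup()` runs with ordinary lists and dicts, and every attribute is frozen once it returns, so the loop of `append` calls from the `Test` example at the top of the issue works as written. Registering and naming submodules would also have to move into that post-setup pass, which this sketch does not attempt.

```python
from types import MappingProxyType


def freeze(value):
  """Recursively turns lists into tuples and dicts into read-only mappings."""
  if isinstance(value, list):
    return tuple(freeze(v) for v in value)
  if isinstance(value, dict):
    # MappingProxyType is a stdlib stand-in for FrozenDict in this sketch.
    return MappingProxyType({k: freeze(v) for k, v in value.items()})
  return value


class DeferredFreezeModule:
  """Toy base class: attributes are mutable in setup() and frozen afterwards."""

  def __init__(self):
    self.setup()
    # The "iterate over all fields after setup" variant of the proposal.
    for name, value in list(vars(self).items()):
      setattr(self, name, freeze(value))

  def setup(self):
    pass


class Test(DeferredFreezeModule):
  def setup(self):
    self.layers = []
    for i in range(3):
      self.layers.append(f"Dense_{i}")  # fine: still inside setup()
    self.heads = {"cls": "Dense_cls"}


t = Test()
print(t.layers)  # ('Dense_0', 'Dense_1', 'Dense_2')
print(type(t.heads).__name__)  # mappingproxy
```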

cgarciae • May 11 '22