
The `parameters()` method of a Module can return duplicates

Open francescofarina opened this issue 2 years ago • 2 comments

Calling the `parameters()` method (or closely related ones such as `trainable_parameters()` or `children()`) can return duplicates when the same submodule is referenced from more than one attribute of a nested model.

Please ignore this if it's intended behavior.

Example

Define a simple MLP consisting of three submodules (`self.input`, `self.hidden`, and `self.output`) that are also chained together in `self.layers`:

```python
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int):
        super().__init__()

        self.input = nn.Linear(in_dims, 1, bias=False)
        self.hidden = nn.Sequential(
            nn.Linear(1, 1, bias=False), nn.Linear(1, 1, bias=False)
        )
        self.output = nn.Linear(1, out_dims, bias=False)
        self.layers = nn.Sequential(self.input, self.hidden, self.output)
```

Currently, by instantiating the model and printing the parameters

```python
model = MLP(1, 1)
print(model.parameters())
```

one gets

```
{
	'input': {
		'weight': array([[0.373375]], dtype=float32)
	}, 
	'hidden': {
		'layers': [
			{
				'weight': array([[0.171902]], dtype=float32)
			}, 
			{
				'weight': array([[-0.483425]], dtype=float32)
			}
		]
	}, 
	'output': {
		'weight': array([[0.819667]], dtype=float32)
	}, 
	'layers': {
		'layers': [
			{
				'weight': array([[0.373375]], dtype=float32)
			}, 
			{
				'layers': [
					{
						'weight': array([[0.171902]], dtype=float32)
					}, 
					{
						'weight': array([[-0.483425]], dtype=float32)
					}
				]
			}, 
			{
				'weight': array([[0.819667]], dtype=float32)
			}
		]
	}
}
```

which contains duplicates, because the `nn.Sequential` stored in `self.layers` holds references to `self.input`, `self.hidden`, and `self.output` (if one updates the parameters of `self.hidden`, `self.layers` also gets "updated").

I'm not sure whether this is intended, but it looks unexpected and potentially error-prone when iterating over the parameters of a model.

For reference, the same implementation in PyTorch returns each parameter only once:

```python
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int):
        super().__init__()

        self.input = nn.Linear(in_dims, 1, bias=False)
        self.hidden = nn.Sequential(
            nn.Linear(1, 1, bias=False), nn.Linear(1, 1, bias=False)
        )
        self.output = nn.Linear(1, out_dims, bias=False)
        self.layers = nn.Sequential(self.input, self.hidden, self.output)


model = MLP(1, 1)
for i in model.named_parameters():
    print(i)
```

```
('input.weight', Parameter containing:
tensor([[0.2972]], requires_grad=True))
('hidden.0.weight', Parameter containing:
tensor([[-0.1292]], requires_grad=True))
('hidden.1.weight', Parameter containing:
tensor([[-0.0626]], requires_grad=True))
('output.weight', Parameter containing:
tensor([[-0.0066]], requires_grad=True))
```

MLX-like behavior is not the default in PyTorch, but it can be obtained with `model.named_parameters(remove_duplicate=False)`.

francescofarina, Dec 17 '23 17:12

That's a very good point. `parameters()` can return duplicates, which feels kind of weird. The reason this is not as simple as removing duplicates from the returned dictionary is the question of what happens when one updates the parameters with new ones: should the "duplicate" parameter be updated or not?

Btw, I am not saying that what we do now is correct or better; we currently do nothing, which is probably confusing. I am wondering what the right way to deal with this would be. Perhaps we should disallow duplicates completely and make them an error?
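One way to picture the "make it an error" option is a check that walks the parameter tree and raises as soon as a single leaf object is reachable under two paths. This is a minimal plain-Python sketch, not MLX code: it assumes the tree is nested dicts/lists like the output of `parameters()`, and that aliased parameters appear as the very same Python object. `check_no_shared_leaves` and `Weight` are hypothetical names used only for illustration.

```python
from typing import Any, Dict


def check_no_shared_leaves(tree: Any) -> None:
    """Raise ValueError if one leaf object is reachable under two different paths.

    `tree` is assumed to be nested dicts/lists/tuples (like parameters());
    leaves are compared by identity (id), not by value.
    """
    first_path: Dict[int, str] = {}

    def walk(node: Any, path: str) -> None:
        if isinstance(node, dict):
            children = list(node.items())
        elif isinstance(node, (list, tuple)):
            children = list(enumerate(node))
        else:
            # Leaf: remember the first path it was seen under
            other = first_path.setdefault(id(node), path)
            if other != path:
                raise ValueError(
                    f"parameter at '{path}' is the same object as '{other}'"
                )
            return
        for key, child in children:
            walk(child, f"{path}.{key}" if path else str(key))

    walk(tree, "")


# Hypothetical stand-in for an array, so the sketch runs without MLX
class Weight:
    pass


w = Weight()
check_no_shared_leaves({"input": {"weight": w}, "output": {"weight": Weight()}})  # passes
try:
    check_no_shared_leaves({"input": {"weight": w}, "layers": {"layers": [{"weight": w}]}})
except ValueError as err:
    print(err)  # parameter at 'layers.layers.0.weight' is the same object as 'input.weight'
```

Raising on the second path keeps the first one as the "owner", which sidesteps the update question entirely, at the cost of rejecting models like the one in the original example.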

On the other hand, the situation above is quite easy to solve by making `layers` a "constant" member that is not considered when computing the parameters, simply by making it "protected" (starting its name with an underscore):

```python
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map

class MLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int):
        super().__init__()

        self.input = nn.Linear(in_dims, 1, bias=False)
        self.hidden = nn.Sequential(
            nn.Linear(1, 1, bias=False), nn.Linear(1, 1, bias=False)
        )
        self.output = nn.Linear(1, out_dims, bias=False)
        # The leading underscore matters: it keeps _layers out of parameters()
        self._layers = nn.Sequential(self.input, self.hidden, self.output)

model = MLP(1, 1)
print(tree_flatten(model.parameters()))
# [('input.weight', array([[0.96391]], dtype=float32)),
#  ('hidden.layers.0.weight', array([[0.884289]], dtype=float32)),
#  ('hidden.layers.1.weight', array([[0.220383]], dtype=float32)),
#  ('output.weight', array([[-0.0979315]], dtype=float32))]
print(model._layers.layers[0].weight)
# array([[0.96391]], dtype=float32)

# Doubling every parameter also changes _layers, which holds the same submodules
model.update(tree_map(lambda x: 2*x, model.parameters()))
print(tree_flatten(model.parameters()))
# [('input.weight', array([[1.92782]], dtype=float32)),
#  ('hidden.layers.0.weight', array([[1.76858]], dtype=float32)),
#  ('hidden.layers.1.weight', array([[0.440766]], dtype=float32)),
#  ('output.weight', array([[-0.195863]], dtype=float32))]
print(model._layers.layers[0].weight)
# array([[1.92782]], dtype=float32)
```

angeloskath, Dec 18 '23 00:12

That makes sense, but I would argue that having to use protected members may result in a design/behavioral pattern that's hard to enforce in a "soft" way. Throwing an error (in this case, if `layers` is not protected) is a possibility, but I'm not sure it would be the best solution.

The "design" problem I see is that the "duplicates" returned by `parameters()` are not really duplicates in memory. In my initial example, both `self.input` and `self.layers.layers[0]` point to the same object in memory:

```python
# Continues the first example (self.layers without an underscore)
from mlx.utils import tree_flatten, tree_map

print(tree_flatten(model.parameters()))
# [('input.weight', array([[0.373375]], dtype=float32)), 
# ('hidden.layers.0.weight', array([[0.171902]], dtype=float32)), 
# ('hidden.layers.1.weight', array([[-0.483425]], dtype=float32)), 
# ('output.weight', array([[0.819667]], dtype=float32)), 
# ('layers.layers.0.weight', array([[0.373375]], dtype=float32)), 
# ('layers.layers.1.layers.0.weight', array([[0.171902]], dtype=float32)), 
# ('layers.layers.1.layers.1.weight', array([[-0.483425]], dtype=float32)), 
# ('layers.layers.2.weight', array([[0.819667]], dtype=float32))]

# update the `input` submodule
model.input.update(tree_map(lambda x: 2 * x, model.input.parameters()))
print(tree_flatten(model.parameters()))
# [('input.weight', array([[0.746751]], dtype=float32)),  <-- I updated this
# ('hidden.layers.0.weight', array([[0.171902]], dtype=float32)), 
# ('hidden.layers.1.weight', array([[-0.483425]], dtype=float32)), 
# ('output.weight', array([[0.819667]], dtype=float32)), 
# ('layers.layers.0.weight', array([[0.746751]], dtype=float32)),  <-- this one also got "updated", as it points to the same array
# ('layers.layers.1.layers.0.weight', array([[0.171902]], dtype=float32)), 
# ('layers.layers.1.layers.1.weight', array([[-0.483425]], dtype=float32)), 
# ('layers.layers.2.weight', array([[0.819667]], dtype=float32))]
```

I'm not sure what the best solution would be, but probably `parameters()` should return each array in memory once, rather than every reference to it.
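A minimal sketch of that idea, in plain Python so it runs without MLX: flatten the nested dict/list tree into dotted paths (the same style `tree_flatten` prints above) and keep only the first path for each leaf object, comparing by identity (`id`) rather than by value. It assumes aliased parameters appear as the very same Python object in the tree; `flatten_tree`, `unique_parameters`, and `FakeArray` are hypothetical names, not MLX API.

```python
from typing import Any, List, Tuple


def flatten_tree(tree: Any, prefix: str = "") -> List[Tuple[str, Any]]:
    """Flatten nested dicts/lists into (dotted_path, leaf) pairs, in insertion order."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        return [(prefix, tree)]
    out: List[Tuple[str, Any]] = []
    for key, value in items:
        out.extend(flatten_tree(value, f"{prefix}.{key}" if prefix else str(key)))
    return out


def unique_parameters(tree: Any) -> List[Tuple[str, Any]]:
    """Keep only the first path under which each leaf object appears."""
    seen = set()
    result = []
    for path, leaf in flatten_tree(tree):
        if id(leaf) not in seen:
            seen.add(id(leaf))
            result.append((path, leaf))
    return result


class FakeArray:
    """Hypothetical stand-in for an array."""

    def __init__(self, value: float):
        self.value = value

    def __repr__(self) -> str:
        return f"FakeArray({self.value})"


# Mimics the structure of the first example: `layers` re-references the other submodules
w_in, w_h0, w_h1, w_out = (FakeArray(v) for v in (0.37, 0.17, -0.48, 0.82))
hidden = {"layers": [{"weight": w_h0}, {"weight": w_h1}]}
params = {
    "input": {"weight": w_in},
    "hidden": hidden,
    "output": {"weight": w_out},
    "layers": {"layers": [{"weight": w_in}, hidden, {"weight": w_out}]},
}

for path, leaf in unique_parameters(params):
    print(path, leaf)
# input.weight FakeArray(0.37)
# hidden.layers.0.weight FakeArray(0.17)
# hidden.layers.1.weight FakeArray(-0.48)
# output.weight FakeArray(0.82)
```

This is close in spirit to PyTorch's default `remove_duplicate=True`. Note that it only answers the read side: an `update()` driven by the deduplicated paths would still need to know that `layers.layers.0.weight` is an alias, which is the open question raised earlier in the thread.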

To add a bit more context on why I think this is an important point: imagine one wants to perform some per-parameter computation or create new parameter-specific variables. If `parameters()` returns "duplicates" and those are used to create the new variables, memory is wasted. In my case, I'm working on a PR to enable some uncertainty estimation features on arbitrary models, and I need to keep track of each parameter's moving average and variance; if the parameters (and their duplicates) are many, the memory overhead is significant.
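The memory point can be made concrete: if per-parameter state is keyed by object identity rather than by path, an aliased parameter gets a single slot and is updated once per step. The sketch below assumes an exponentially weighted mean/variance update (a common choice, not necessarily what the PR uses) and a hypothetical `Param` stand-in holding one float instead of an MLX array; `RunningStats` and `iter_leaves` are illustrative names only.

```python
from typing import Any, Dict, Iterator, List, Tuple


class Param:
    """Hypothetical stand-in for an array holding one float."""

    def __init__(self, value: float):
        self.value = value


def iter_leaves(tree: Any, prefix: str = "") -> Iterator[Tuple[str, Any]]:
    """Yield (dotted_path, leaf) pairs from nested dicts/lists/tuples."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        yield prefix, tree
        return
    for key, value in items:
        yield from iter_leaves(value, f"{prefix}.{key}" if prefix else str(key))


class RunningStats:
    """Exponentially weighted mean/variance, one slot per distinct parameter object."""

    def __init__(self, tree: Any, decay: float = 0.9):
        self.decay = decay
        # Keyed by id(): an aliased parameter reuses the slot created for its first path
        self.state: Dict[int, List[float]] = {}
        for _, leaf in iter_leaves(tree):
            self.state.setdefault(id(leaf), [leaf.value, 0.0])

    def update(self, tree: Any) -> None:
        alpha = 1.0 - self.decay
        seen = set()
        for _, leaf in iter_leaves(tree):
            key = id(leaf)
            if key in seen:
                continue  # update an aliased parameter once per step, not once per path
            seen.add(key)
            mean, var = self.state[key]
            delta = leaf.value - mean
            mean += alpha * delta
            var = self.decay * (var + alpha * delta * delta)
            self.state[key] = [mean, var]


w = Param(1.0)
params = {"input": {"weight": w}, "layers": {"layers": [{"weight": w}]}}
stats = RunningStats(params, decay=0.5)
print(len(stats.state))  # 1: a single slot, although w appears under two paths
w.value = 3.0
stats.update(params)
print(stats.state[id(w)])  # [2.0, 1.0]
```

Keyed by path instead, the same tree would allocate two slots for `w` and, without the `seen` check, would apply the update twice in one step.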

francescofarina, Dec 18 '23 09:12