[BUG] cifar example fails with ValueError when replacing LayerNorm with BatchNorm using default params
When replacing LayerNorm with BatchNorm (merged in https://github.com/ml-explore/mlx/pull/217) using its default parameters, the cifar example fails with a ValueError on mlx 0.0.7:
```
File "../mlx-examples/cifar/main.py", line 48, in train_epoch
  mx.eval(model.parameters(), optimizer.state)
ValueError: [eval] Illegal to eval an array during function transform without graph retention.
```
Constructing the layer as nn.BatchNorm(dims, track_running_stats=False) seems to work fine, though. Is that expected behavior?
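For reference, a minimal sketch of the two variants (the ConvBlock module and its names are hypothetical, standing in for the blocks in the example ResNet):

```python
import mlx.nn as nn

class ConvBlock(nn.Module):
    # Hypothetical conv -> norm -> ReLU block, standing in for the
    # blocks in the example ResNet.
    def __init__(self, in_dims, dims):
        super().__init__()
        self.conv = nn.Conv2d(in_dims, dims, kernel_size=3, padding=1)
        # The default (track_running_stats=True) triggers the ValueError
        # during training:
        # self.norm = nn.BatchNorm(dims)
        # Skipping the running mean/var updates avoids it:
        self.norm = nn.BatchNorm(dims, track_running_stats=False)

    def __call__(self, x):
        return nn.relu(self.norm(self.conv(x)))
```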
Reproduce the issue
The ResNet with BatchNorm is available on the resnet_batch_norm branch at https://github.com/menzHSE/mlx-examples. Run python main.py in mlx-examples/cifar.
That is not expected, it sounds like a bug. Thanks for reporting, I will take a look.
I was able to repro using the code @menzHSE provided. Passing retain_graph=True to the eval() method suppresses the error, but training is much slower, as expected. Here are the culprit lines.
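For context, the workaround amounts to a one-line change in the training step of cifar/main.py (a sketch against the mlx 0.0.7 API, in which eval() still accepted a retain_graph flag; model, optimizer, X, and y are assumed to be in scope):

```python
import mlx.core as mx
import mlx.nn as nn

def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, X, y)
optimizer.update(model, grads)
# retain_graph=True keeps the whole graph alive so the BatchNorm
# running-stat updates can be evaluated, at a large cost in speed
# and memory.
mx.eval(model.parameters(), optimizer.state, retain_graph=True)
```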
Yes, we are very much aware of this issue and are working with @angeloskath on a fix.
For now I recommend avoiding the running stats until we fix it.
Thanks! With retain_graph=True it runs out of memory on my 16 GB M1, though, versus roughly 840 MB with retain_graph=False.
This should fix it: https://github.com/ml-explore/mlx/pull/385