Error differentiating ResNet from `torchvision`
While trying to get an image classification example working for FastAI.jl, I tried training a pretrained ResNet model from torchvision. The forward pass works fine, but differentiating it throws an error.
I think this is actually a limitation of functorch, but figured I'd report it here.
Minimal working example (the last line fails on both CPU and GPU):
using CUDA, PyCall, PyCallChainRules, Zygote, Flux
using PyCallChainRules.Torch: TorchModuleWrapper
torchvision = pyimport("torchvision")
model = TorchModuleWrapper(torchvision.models.resnet18(pretrained=true).to("cuda:0"))
xs = randn(Float32, 128, 128, 3, 1) |> cu
ys = model(xs)
Zygote.gradient(() -> Flux.mse(model(xs), ys))
Stacktrace
julia> Zygote.gradient(() -> Flux.mae(model(xs), ys))
ERROR: PyError ($(Expr(:escape, :(ccall(#= /home/lorenz/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:43 =# @pysym(:PyObject_Call), PyPtr, (PyPtr, PyPtr, PyPtr), o, pyargsptr, kw)))))
RuntimeError('During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::add_.Tensor) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.')
File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 243, in vjp
try:
File "/home/lorenz/.julia/packages/PyCall/7a7w0/src/pyeval.jl", line 3, in newfn
const Py_eval_input = 258
File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/functorch/_src/make_functional.py", line 259, in forward
@staticmethod
File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torchvision/models/resnet.py", line 283, in forward
return self._forward_impl(x)
File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torchvision/models/resnet.py", line 267, in _forward_impl
x = self.bn1(x)
File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 148, in forward
self.num_batches_tracked.add_(1) # type: ignore[has-type]
Stacktrace:
[1] pyerr_check
@ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:62 [inlined]
[2] pyerr_check
@ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:66 [inlined]
[3] _handle_error(msg::String)
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/exception.jl:83
[4] macro expansion
@ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:97 [inlined]
[5] #107
@ ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:43 [inlined]
[6] disable_sigint
@ ./c.jl:458 [inlined]
[7] __pycall!
@ ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:42 [inlined]
[8] _pycall!(ret::PyObject, o::PyObject, args::Tuple{PyObject, NTuple{62, PyObject}, PyObject}, nargs::Int64, kw::Ptr{Nothing})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:29
[9] _pycall!(ret::PyObject, o::PyObject, args::Tuple{PyObject, NTuple{62, PyObject}, PyObject}, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:11
[10] (::PyObject)(::PyObject, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:86
[11] (::PyObject)(::PyObject, ::Vararg{Any})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:86
[12] rrule(wrap::TorchModuleWrapper, args::Array{Float32, 4}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PyCallChainRules.Torch ~/.julia/packages/PyCallChainRules/ebIKG/src/pytorch.jl:62
[13] rrule
@ ~/.julia/packages/PyCallChainRules/ebIKG/src/pytorch.jl:57 [inlined]
[14] rrule
@ ~/.julia/packages/ChainRulesCore/RbX5a/src/rules.jl:134 [inlined]
[15] chain_rrule
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:216 [inlined]
[16] macro expansion
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0 [inlined]
[17] _pullback(ctx::Zygote.Context, f::TorchModuleWrapper, args::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:9
[18] _pullback
@ ~/.julia/dev/_InteractiveSessions/22_03/03_25_pychain_fastai.jl:86 [inlined]
[19] _pullback(::Zygote.Context, ::var"#27#28")
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[20] _pullback(::Function)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:34
[21] pullback(::Function)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:40
[22] gradient(::Function)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:75
[23] top-level scope
@ ~/.julia/dev/_InteractiveSessions/22_03/03_25_pychain_fastai.jl:86
https://github.com/rejuvyesh/PyCallChainRules.jl/blob/main/test/test_pytorch_hub.jl might be of interest. functorch recommends replacing in-place batch norm with alternatives such as group norm, which works equally well.
I see! Thanks for sharing that, I'll try it out and get back here once I have a working FastAI.jl example. Feel free to close the issue, though.
I used the linked code to load a pretrained ResNet and the forward and backward passes work:

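Concretely, the check was roughly this (just a sketch, not the exact code; `model` here is the GroupNorm-converted `TorchModuleWrapper` returned by the `loadresnet` helper in the script I share below):
using CUDA, Zygote
# Sketch only: `model` is gpu(loadresnet(...)) from the script below.
xs = cu(randn(Float32, 224, 224, 3, 4))
ys = model(xs)                                # forward pass runs
gs = Zygote.gradient(m -> sum(m(xs)), model)  # backward pass runs as well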
I then started training it with the standard FastAI.jl image classification setup, which also works; however, after 50 or so steps I get a CUDA out-of-memory error thrown by PyTorch. Since the training ran fine for 50 batches and reducing the batch size didn't help, I am assuming there is a GPU memory leak somewhere.

Have you run into this and have any advice on pinpointing or alleviating the problem? Thanks for your help!
Would it be possible to share the script you are running? It's definitely possible that DLPack's memory pool is not freeing the tensors appropriately.
Sure, here it is (adapted lines from test_pytorch_hub.jl included for completeness):
using PyCall, PyCallChainRules, Zygote, Flux
using PyCallChainRules.Torch: TorchModuleWrapper, torch
using FastAI
using CUDA
py"""
import torch

def bn2group(module):
    num_groups = 16 # hyper_parameter of GroupNorm
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module_output = torch.nn.GroupNorm(num_groups,
                                           module.num_features,
                                           module.eps,
                                           module.affine,
                                           )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, bn2group(child))
    del module
    return module_output
"""
function loadresnet(c::Int)
    model = torch.hub.load("pytorch/vision", "resnet18")
    model.fc = torch.nn.Linear(model.fc.in_features, c) # change number of output classes
    model_gn = py"bn2group"(model)
    return TorchModuleWrapper(model_gn)
end
Flux.gpu(m::TorchModuleWrapper) = fmap(CUDA.cu, m)
# FastAI.jl part
data, blocks = loaddataset("imagenette2-320")
task = ImageClassificationSingle(blocks)
learner = tasklearner(
    task, data;
    callbacks=[ToGPU()],
    batchsize=4,
    model=gpu(loadresnet(length(blocks[2].classes)))) # model being loaded here
# Training
fitonecycle!(learner, 1)
Epoch 1 TrainingPhase(): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:03:50
┌───────────────┬───────┬─────────┐
│ Phase │ Epoch │ Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │ 1.0 │ 2.45761 │
└───────────────┴───────┴─────────┘
Epoch 1 ValidationPhase(): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:19
┌─────────────────┬───────┬───────┐
│ Phase │ Epoch │ Loss │
├─────────────────┼───────┼───────┤
│ ValidationPhase │ 1.0 │ 2.464 │
└─────────────────┴───────┴───────┘
Again, this seems to run correctly for me. I kept nvtop open on the side as well and never saw memory usage go above 70%; it stayed quite constant.
I think with this and #18 the common factor seems to be older CUDA (and possibly older NVIDIA drivers) on your end?
Edit:
Also:
julia> torch.__version__
"1.11.0"
julia> functorch.__version__
"0.1.0"
in case that matters.
I also let this run for a few more epochs. While I do see a slight uptick in memory usage with multiple epochs, it didn't become drastic enough to kill training.
By the way, I found this page in the functorch docs which gives some options for dealing with batch norm layers that may be a bit more convenient, e.g.
from functorch.experimental import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
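From the Julia side this should look roughly like the following (untested sketch; `loadresnet_fexp` is just an illustrative name, the rest mirrors the `loadresnet` helper above):
using PyCall
using PyCallChainRules.Torch: TorchModuleWrapper, torch

fexp = pyimport("functorch.experimental")

# Untested sketch: same as `loadresnet` above, but using the functorch helper
# instead of the hand-written bn2group converter.
function loadresnet_fexp(c::Int)
    model = torch.hub.load("pytorch/vision", "resnet18")
    model.fc = torch.nn.Linear(model.fc.in_features, c) # change number of output classes
    fexp.replace_all_batch_norm_modules_(model)         # in-place replacement
    return TorchModuleWrapper(model)
end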
Nice, I'll move to using this function then!
I updated my CUDA drivers to 11.6, but am still experiencing memory leaks as described above :cry:
I also tried using a vanilla training loop to take FluxTraining.jl out of the equation, so the training loop is not an issue.
My updated CUDA version info:
CUDA.versioninfo()
CUDA toolkit 11.6, artifact installation
NVIDIA driver 510.47.3, for CUDA 11.6
CUDA driver 11.6
Libraries:
- CUBLAS: 11.8.1
- CURAND: 10.2.9
- CUFFT: 10.7.0
- CUSOLVER: 11.3.2
- CUSPARSE: 11.7.1
- CUPTI: 16.0.0
- NVML: 11.0.0+510.47.3
- CUDNN: 8.30.2 (for CUDA 11.5.0)
- CUTENSOR: 1.4.0 (for CUDA 11.5.0)
Toolchain:
- Julia: 1.8.0-beta3
- LLVM: 13.0.1
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0, 7.1, 7.2
- Device capability support: sm_35, sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80, sm_86
1 device:
0: NVIDIA GeForce GTX 1080 Ti (sm_61, 23.625 MiB / 11.000 GiB available)
I put together a smaller MWE that reproduces the GPU OOM error:
using CUDA, PyCall, PyCallChainRules
using PyCallChainRules.Torch: TorchModuleWrapper, torch
fexp = pyimport("functorch.experimental")
model_py = torch.hub.load("pytorch/vision", "resnet18")
model_pygn = fexp.replace_all_batch_norm_modules_(model_py).to(device="cuda")
model = TorchModuleWrapper(model_pygn)
function memoryused()
    info = CUDA.MemoryInfo()
    return 1 - (info.free_bytes / info.total_bytes)
end

function oom()
    xs = cu(randn(Float32, 224, 224, 3, 16))
    usage = [memoryused()]
    try
        for _ in 1:1000
            model(xs)
            push!(usage, memoryused())
        end
    catch
    finally
        return usage
    end
end
oom()
This produces linearly growing utilization values before the error:

@rejuvyesh any idea where this leak may be coming from or how to get started debugging this?
Just wanted to comment that I can reproduce this, but haven't had the time to figure out the reason. We likely need to create a reproducer with just DLPack.jl, because this is just the forward pass, with no gradients. DLPack.jl keeps a memory pool of shared tensors at https://github.com/pabloferz/DLPack.jl/blob/2e491ac7e839a7428d817b652c4d525faa52ceac/src/DLPack.jl#L173, and we will need to track the state of this variable to figure out what's happening.
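For reference, a crude way to watch that alongside GPU memory (sketch only; I'm assuming the pool at the linked line is the `DLPack.SHARES_POOL` dictionary, substitute whatever the variable is actually named at that commit), reusing `model` and `memoryused()` from the MWE above:
using CUDA, DLPack
# Assumption: the pool at the linked source line is `DLPack.SHARES_POOL`;
# adjust the name if it differs at that commit.
poolsize() = length(DLPack.SHARES_POOL)

xs = cu(randn(Float32, 224, 224, 3, 16))
for i in 1:100
    model(xs)                              # forward pass only, as in the MWE
    @info "step $i" poolsize() memoryused()
end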
This is definitely a bug in the DLPack interaction:
using CUDA, PyCall, DLPack, Functors
dlpack = pyimport("torch.utils.dlpack")
torch = pyimport("torch")
fexp = pyimport("functorch.experimental")
pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject
struct TorchModel
    fn::PyObject
end

function (wrap::TorchModel)(args...; kwargs...)
    return wrap.fn(fmap(x -> DLPack.share(x, PyObject, pyfrom_dlpack), args)...; kwargs...)
end
model_py = torch.hub.load("pytorch/vision", "resnet18")
model_pygn = fexp.replace_all_batch_norm_modules_(model_py).to(device="cuda")
model = TorchModel(model_pygn)
function memoryused()
    info = CUDA.MemoryInfo()
    return 1 - (info.free_bytes / info.total_bytes)
end

function oom()
    batchsize = 128
    usage = [memoryused()]
    try
        for _ in 1:1000
            xs = cu(randn(Float32, 224, 224, 3, batchsize))
            model(xs)
            push!(usage, memoryused())
        end
    catch
    finally
        return usage
    end
end
oom()
Does this also fail for you? You might need to change the batchsize for your GPU.
Hi @rejuvyesh
I could run your code on my machine. It seems the value of memoryused() increases gradually.
julia> oom()
0.1589230896872148
0.3943831411266905
0.3996930224840288
0.3996930224840288
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.41296772587737496
0.41296772587737496
0.41296772587737496
⋮
0.8988218700738405
0.8988218700738405
0.8988218700738405
0.8988218700738405
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9120965734671866
0.9120965734671866
(EDIT)
Here is my hardware information.
julia> CUDA.versioninfo()
CUDA toolkit 11.7, artifact installation
NVIDIA driver 510.60.2, for CUDA 11.6
CUDA driver 11.6
Libraries:
- CUBLAS: 11.10.1
- CURAND: 10.2.10
- CUFFT: 10.7.2
- CUSOLVER: 11.3.5
- CUSPARSE: 11.7.3
- CUPTI: 17.0.0
- NVML: 11.0.0+510.60.2
- CUDNN: 8.30.2 (for CUDA 11.5.0)
- CUTENSOR: 1.4.0 (for CUDA 11.5.0)
Toolchain:
- Julia: 1.7.2
- LLVM: 12.0.1
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0
- Device capability support: sm_35, sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80
2 devices:
0: NVIDIA GeForce RTX 3060 (sm_86, 11.762 GiB / 12.000 GiB available)
1: NVIDIA GeForce RTX 3060 (sm_86, 11.752 GiB / 12.000 GiB available)
(tmp) pkg> st
Status `~/tmp/Project.toml`
[fbb218c0] BSON v0.3.5
[052768ef] CUDA v3.10.0
[53c2dc0f] DLPack v0.1.1
[2e981812] DataLoaders v0.1.3
[587475ba] Flux v0.13.1
[d9f16b24] Functors v0.2.8
[dbeba491] Metalhead v0.7.1
[3bd65402] Optimisers v0.2.4
[92933f4c] ProgressMeter v1.7.2
[438e738f] PyCall v1.93.1
[b12ccfe2] PyCallChainRules v0.3.2
[e88e6eb3] Zygote v0.6.40
Python 3.8.5 (default, Sep 4 2020, 07:30:14)
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import functorch
>>> torch.__version__
'1.11.0+cu113'
>>> functorch.__version__
'0.1.1'
Thanks for looking into this. I am currently away from my GPU workstation, but can hopefully investigate myself next week.
This is definitely a bug in the DLPack interaction
Can you tell if this is an issue in DLPack.jl or just the way it's used? If the former, I guess we should open an issue there?
The issue is definitely just with the DLPack.jl and PyTorch interaction, and maybe the issue should go there. I need to check whether this happens with JAX as well.
@terasakisatoshi can you increase the batchsize and see if it actually OOMs (earlier)?
@rejuvyesh
Setting batchsize=384 makes the oom384() function return before reaching length(usage) == 1000.
julia> function oom384()
           batchsize = 384
           usage = [memoryused()]
           try
               for _ in 1:1000
                   xs = cu(randn(Float32, 224, 224, 3, batchsize))
                   model(xs)
                   push!(usage, memoryused())
               end
           catch
           finally
               return usage
           end
       end
oom384 (generic function with 1 method)
julia> oom384()
101-element Vector{Float64}:
0.1589230896872148
0.8455571227080395
0.864141707458724
0.864141707458724
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
⋮
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
julia>
I don't think this is the best solution, but the following function notoom, which just adds GC.gc(false) in each loop iteration, does not hit the OOM error.
using ProgressMeter
function notoom()
    batchsize = 384
    usage = [memoryused()]
    @showprogress for _ in 1:1000
        xs = cu(randn(Float32, 224, 224, 3, batchsize))
        model(xs)
        push!(usage, memoryused())
        GC.gc(false) # <---
    end
    @assert length(usage) == 1+1000
end
julia> notoom()
Progress: 100%|█████████████████████████████████████████| Time: 0:11:57
Interesting 🤔. I wonder if it would be enough to do that only every few steps, and what the memory usage curve would look like.
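Something like this should show both, i.e. collect only every n-th step and keep the usage curve so the two variants can be compared (untested sketch, reusing `model` and `memoryused()` from above; `gc_every_n` is just an illustrative name):
function gc_every_n(n::Int=10)
    batchsize = 384
    usage = [memoryused()]
    for i in 1:1000
        xs = cu(randn(Float32, 224, 224, 3, batchsize))
        model(xs)
        push!(usage, memoryused())
        i % n == 0 && GC.gc(false)  # collect only every n-th iteration
    end
    return usage
end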
This is not really an issue with DLPack.jl, nor with PyCallChainRules.jl or PyCall.jl. As mentioned in https://github.com/pabloferz/DLPack.jl/issues/26, it's just that Julia has no way of knowing how often it should garbage collect PyObjects in general.
@terasakisatoshi's example above is the correct way of handling this.