PyCallChainRules.jl

Error differentiating ResNet from `torchvision`

Open • lorenzoh opened this issue 3 years ago • 21 comments

While trying to get an image classification example working for FastAI.jl, I tried training a pretrained ResNet model from torchvision. The forward pass works fine, but differentiating throws an error.

I think this is actually a limitation of functorch, but figured I'd report it here.

Minimal working example (the last line fails on both CPU and GPU):

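using CUDA, Flux, PyCall, PyCallChainRules, Zygote
using PyCallChainRules.Torch: TorchModuleWrapper

torchvision = pyimport("torchvision")

model = TorchModuleWrapper(torchvision.models.resnet18(pretrained=true).to("cuda:0"))
xs = randn(Float32, 128, 128, 3, 1) |> cu
ys = model(xs)
# Differentiate the loss with respect to the wrapped PyTorch model
Zygote.gradient(m -> Flux.mse(m(xs), ys), model)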
Stacktrace
julia> Zygote.gradient(() -> Flux.mae(model(xs), ys))
ERROR: PyError ($(Expr(:escape, :(ccall(#= /home/lorenz/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:43 =# @pysym(:PyObject_Call), PyPtr, (PyPtr, PyPtr, PyPtr), o, pyargsptr, kw))))) 
RuntimeError('During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::add_.Tensor) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.')
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 243, in vjp
    try:
  File "/home/lorenz/.julia/packages/PyCall/7a7w0/src/pyeval.jl", line 3, in newfn
    const Py_eval_input = 258
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/functorch/_src/make_functional.py", line 259, in forward
    @staticmethod
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torchvision/models/resnet.py", line 283, in forward
    return self._forward_impl(x)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torchvision/models/resnet.py", line 267, in _forward_impl
    x = self.bn1(x)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 148, in forward
    self.num_batches_tracked.add_(1)  # type: ignore[has-type]

Stacktrace:
 [1] pyerr_check @ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:62 [inlined]
 [2] pyerr_check @ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:66 [inlined]
 [3] _handle_error(msg::String) @ PyCall ~/.julia/packages/PyCall/7a7w0/src/exception.jl:83
 [4] macro expansion @ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:97 [inlined]
 [5] #107 @ ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:43 [inlined]
 [6] disable_sigint @ ./c.jl:458 [inlined]
 [7] __pycall! @ ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:42 [inlined]
 [8] _pycall!(ret::PyObject, o::PyObject, args::Tuple{PyObject, NTuple{62, PyObject}, PyObject}, nargs::Int64, kw::Ptr{Nothing}) @ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:29
 [9] _pycall!(ret::PyObject, o::PyObject, args::Tuple{PyObject, NTuple{62, PyObject}, PyObject}, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) @ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:11
 [10] (::PyObject)(::PyObject, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) @ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:86
 [11] (::PyObject)(::PyObject, ::Vararg{Any}) @ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:86
 [12] rrule(wrap::TorchModuleWrapper, args::Array{Float32, 4}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) @ PyCallChainRules.Torch ~/.julia/packages/PyCallChainRules/ebIKG/src/pytorch.jl:62
 [13] rrule @ ~/.julia/packages/PyCallChainRules/ebIKG/src/pytorch.jl:57 [inlined]
 [14] rrule @ ~/.julia/packages/ChainRulesCore/RbX5a/src/rules.jl:134 [inlined]
 [15] chain_rrule @ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:216 [inlined]
 [16] macro expansion @ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0 [inlined]
 [17] _pullback(ctx::Zygote.Context, f::TorchModuleWrapper, args::Array{Float32, 4}) @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:9
 [18] _pullback @ ~/.julia/dev/_InteractiveSessions/22_03/03_25_pychain_fastai.jl:86 [inlined]
 [19] _pullback(::Zygote.Context, ::var"#27#28") @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [20] _pullback(::Function) @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:34
 [21] pullback(::Function) @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:40
 [22] gradient(::Function) @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:75
 [23] top-level scope @ ~/.julia/dev/_InteractiveSessions/22_03/03_25_pychain_fastai.jl:86

lorenzoh, Apr 27 '22 09:04

https://github.com/rejuvyesh/PyCallChainRules.jl/blob/main/test/test_pytorch_hub.jl might be of interest. functorch recommends replacing in-place BatchNorm with alternatives such as GroupNorm, which works equally well.

rejuvyesh, Apr 27 '22 15:04

I see! Thanks for sharing that. I'll try it out and get back here once I have a working FastAI.jl example. Feel free to close the issue, though.

lorenzoh, Apr 27 '22 16:04

I used the linked code to load a pretrained ResNet, and the forward and backward passes work:

[image: screenshot of the forward and backward passes running]

I then started training it using the standard FastAI.jl image classification workflow, which also works. However, after 50 or so steps, PyTorch throws a CUDA out-of-memory error. Since training ran fine for 50 batches and reducing the batch size didn't help, I assume there is a GPU memory leak somewhere.

[image: screenshot of the CUDA out-of-memory error]

Have you run into this, and do you have any advice on pinpointing or alleviating the problem? Thanks for your help!

lorenzoh, Apr 27 '22 18:04

Would it be possible to share the script you are running? It's definitely possible that DLPack.jl's memory pool is not freeing the tensors properly.

rejuvyesh, Apr 27 '22 18:04

Sure, here it is (adapted lines from test_pytorch_hub.jl included for completeness):

using PyCall, PyCallChainRules, Zygote, Flux
using PyCallChainRules.Torch: TorchModuleWrapper, torch
using FastAI
using CUDA

py"""
import torch
def bn2group(module):
    num_groups = 16 # hyper_parameter of GroupNorm
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module_output = torch.nn.GroupNorm(num_groups,
                                           module.num_features,
                                           module.eps,
                                           module.affine,
                                          )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, bn2group(child))
    del module
    return module_output
"""


function loadresnet(c::Int)
    model = torch.hub.load("pytorch/vision", "resnet18")
    model.fc = torch.nn.Linear(model.fc.in_features, c)  # change number of output classes
    model_gn = py"bn2group"(model)
    return TorchModuleWrapper(model_gn)
end

Flux.gpu(m::TorchModuleWrapper) = fmap(CUDA.cu, m)


# FastAI.jl part
data, blocks = loaddataset("imagenette2-320")
task = ImageClassificationSingle(blocks)
learner = tasklearner(
    task, data;
    callbacks=[ToGPU()],
    batchsize=4,
    model=gpu(loadresnet(length(blocks[2].classes))))  # model being loaded here

# Training
fitonecycle!(learner, 1)

lorenzoh, Apr 27 '22 19:04

Epoch 1 TrainingPhase(): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:03:50
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   1.0 │ 2.45761 │
└───────────────┴───────┴─────────┘
Epoch 1 ValidationPhase(): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:19
┌─────────────────┬───────┬───────┐
│           Phase │ Epoch │  Loss │
├─────────────────┼───────┼───────┤
│ ValidationPhase │   1.0 │ 2.464 │
└─────────────────┴───────┴───────┘

Again, this seems to run correctly for me. I also kept nvtop open on the side, and memory usage never went above 70% and stayed quite constant. Given this and #18, I suspect the common factor is an older CUDA version (and possibly older NVIDIA drivers) on your end?

Edit:

Also:

julia> torch.__version__
"1.11.0"
julia> functorch.__version__
"0.1.0"

in case that matters.

rejuvyesh, Apr 27 '22 20:04

I also let this run for a few more epochs. While I do see a slight uptick in memory usage with multiple epochs, it didn't become drastic enough to kill training.

rejuvyesh, Apr 27 '22 20:04

By the way, I found this page in the functorch docs, which gives some options for dealing with batch norm layers that may be more convenient, e.g.:

from functorch.experimental import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)

lorenzoh, Apr 28 '22 08:04

Nice, I'll move to using this function then!

rejuvyesh, Apr 28 '22 16:04

I updated my CUDA drivers to 11.6, but I am still experiencing the memory leak described above :cry:

I also tried a vanilla training loop to take FluxTraining.jl out of the equation, and the leak persists, so the training loop is not the issue.

My updated CUDA version info:

CUDA.versioninfo()
CUDA toolkit 11.6, artifact installation
NVIDIA driver 510.47.3, for CUDA 11.6
CUDA driver 11.6

Libraries: 
- CUBLAS: 11.8.1
- CURAND: 10.2.9
- CUFFT: 10.7.0
- CUSOLVER: 11.3.2
- CUSPARSE: 11.7.1
- CUPTI: 16.0.0
- NVML: 11.0.0+510.47.3
- CUDNN: 8.30.2 (for CUDA 11.5.0)
- CUTENSOR: 1.4.0 (for CUDA 11.5.0)

Toolchain:
- Julia: 1.8.0-beta3
- LLVM: 13.0.1
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0, 7.1, 7.2
- Device capability support: sm_35, sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80, sm_86

1 device:
  0: NVIDIA GeForce GTX 1080 Ti (sm_61, 23.625 MiB / 11.000 GiB available)

lorenzoh, Apr 30 '22 17:04

I put together a smaller MWE that reproduces the GPU OOM error:

using CUDA, PyCall, PyCallChainRules
using PyCallChainRules.Torch: TorchModuleWrapper, torch
fexp = pyimport("functorch.experimental")

model_py = torch.hub.load("pytorch/vision", "resnet18")
model_pygn = fexp.replace_all_batch_norm_modules_(model_py).to(device="cuda")
model = TorchModuleWrapper(model_pygn)

function memoryused()
    info = CUDA.MemoryInfo()
    return 1 - (info.free_bytes / info.total_bytes)
end

function oom()
    xs = cu(randn(Float32, 224, 224, 3, 16))
    usage = [memoryused()]
    try
        for _ in 1:1000
            model(xs)
            push!(usage, memoryused())
        end
    catch
    finally
        return usage
    end
end

oom()

This produces linearly growing utilization values before the error:

[image: plot of GPU memory utilization per iteration, growing linearly until the OOM error]

@rejuvyesh any idea where this leak may be coming from or how to get started debugging this?
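To compare runs (for example with and without extra garbage collection), it can help to reduce a recorded `usage` vector, such as the one returned by `oom()` above, to a single growth rate. This is a minimal sketch using an ordinary least-squares slope; `growth_per_step` is a hypothetical helper, not part of CUDA.jl or PyCallChainRules.jl.

```julia
"""
    growth_per_step(usage)

Hypothetical helper: ordinary least-squares slope of `usage` against the
sample index `1:length(usage)`, i.e. the average change in memory
utilization per recorded step.
"""
function growth_per_step(usage::AbstractVector{<:Real})
    n = length(usage)
    n >= 2 || throw(ArgumentError("need at least two samples"))
    x = 1:n
    xmean = (n + 1) / 2          # mean of 1:n
    ymean = sum(usage) / n
    dx = x .- xmean
    return sum(dx .* (usage .- ymean)) / sum(dx .^ 2)
end
```

For instance, `growth_per_step(oom())` would give the average increase in the utilized fraction of device memory per forward pass; a value near zero would indicate a flat curve.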

lorenzoh, May 04 '22 15:05

Just wanted to say that I can reproduce this, but haven't had the time to figure out the cause. We likely need a reproducer with just DLPack.jl, since this is only a forward pass, with no gradients. DLPack.jl keeps a memory pool of shared tensors at https://github.com/pabloferz/DLPack.jl/blob/2e491ac7e839a7428d817b652c4d525faa52ceac/src/DLPack.jl#L173, and we will need to track the state of this variable to figure out what's happening.

rejuvyesh, May 12 '22 03:05

This is definitely a bug in the dlpack interaction:

using CUDA, PyCall, DLPack, Functors

dlpack = pyimport("torch.utils.dlpack")
torch = pyimport("torch")
fexp = pyimport("functorch.experimental")

pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject


struct TorchModel
    fn::PyObject
end

function (wrap::TorchModel)(args...; kwargs...)
    return wrap.fn(fmap(x -> DLPack.share(x, PyObject, pyfrom_dlpack), args)...; kwargs...)
end


model_py = torch.hub.load("pytorch/vision", "resnet18")
model_pygn = fexp.replace_all_batch_norm_modules_(model_py).to(device="cuda")
model = TorchModel(model_pygn)

function memoryused()
    info = CUDA.MemoryInfo()
    return 1 - (info.free_bytes / info.total_bytes)
end

function oom()
    batchsize = 128
    usage = [memoryused()]
    try
        for _ in 1:1000
            xs = cu(randn(Float32, 224, 224, 3, batchsize))
            model(xs)
            push!(usage, memoryused())
        end
    catch
    finally
        return usage
    end
end

oom()

Does this also fail for you? You might need to change the batch size for your GPU.

rejuvyesh, May 16 '22 01:05

Hi @rejuvyesh

I could run your code on my machine. The value of memoryused() seems to increase gradually:

julia> oom()
 0.1589230896872148
 0.3943831411266905
 0.3996930224840288
 0.3996930224840288
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.41296772587737496
 0.41296772587737496
 0.41296772587737496
 ⋮
 0.8988218700738405
 0.8988218700738405
 0.8988218700738405
 0.8988218700738405
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9120965734671866
 0.9120965734671866

(EDIT)

Here is my hardware information.

julia> CUDA.versioninfo()
CUDA toolkit 11.7, artifact installation
NVIDIA driver 510.60.2, for CUDA 11.6
CUDA driver 11.6

Libraries: 
- CUBLAS: 11.10.1
- CURAND: 10.2.10
- CUFFT: 10.7.2
- CUSOLVER: 11.3.5
- CUSPARSE: 11.7.3
- CUPTI: 17.0.0
- NVML: 11.0.0+510.60.2
- CUDNN: 8.30.2 (for CUDA 11.5.0)
- CUTENSOR: 1.4.0 (for CUDA 11.5.0)

Toolchain:
- Julia: 1.7.2
- LLVM: 12.0.1
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0
- Device capability support: sm_35, sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80

2 devices:
  0: NVIDIA GeForce RTX 3060 (sm_86, 11.762 GiB / 12.000 GiB available)
  1: NVIDIA GeForce RTX 3060 (sm_86, 11.752 GiB / 12.000 GiB available)

(tmp) pkg> st
      Status `~/tmp/Project.toml`
  [fbb218c0] BSON v0.3.5
  [052768ef] CUDA v3.10.0
  [53c2dc0f] DLPack v0.1.1
  [2e981812] DataLoaders v0.1.3
  [587475ba] Flux v0.13.1
  [d9f16b24] Functors v0.2.8
  [dbeba491] Metalhead v0.7.1
  [3bd65402] Optimisers v0.2.4
  [92933f4c] ProgressMeter v1.7.2
  [438e738f] PyCall v1.93.1
  [b12ccfe2] PyCallChainRules v0.3.2
  [e88e6eb3] Zygote v0.6.40
Python 3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import functorch
>>> torch.__version__
'1.11.0+cu113'
>>> functorch.__version__
'0.1.1'

terasakisatoshi, May 18 '22 15:05

Thanks for looking into this. I am currently away from my GPU workstation, but can hopefully investigate myself next week.

"This is definitely a bug in the dlpack interaction"

Can you tell if this is an issue in DLPack.jl or just the way it's used? If the former, I guess we should open an issue there?

lorenzoh, May 20 '22 08:05

The issue is definitely just with the DLPack.jl and PyTorch interaction, so maybe the issue should go there. I need to check whether this happens with JAX as well.

rejuvyesh, May 20 '22 21:05

@terasakisatoshi can you increase the batch size and see whether it actually OOMs (earlier)?

rejuvyesh, May 20 '22 21:05

@rejuvyesh

With batchsize = 384, the oom384() function returns before completing all 1000 iterations (after 100 iterations, judging by the 101-element usage vector below):

julia> function oom384()
           batchsize = 384
           usage = [memoryused()]
           try
               for _ in 1:1000
                   xs = cu(randn(Float32, 224, 224, 3, batchsize))
                   model(xs)
                   push!(usage, memoryused())
               end
           catch
           finally
               return usage
           end
       end
oom384 (generic function with 1 method)

julia> oom384()
101-element Vector{Float64}:
 0.1589230896872148
 0.8455571227080395
 0.864141707458724
 0.864141707458724
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 ⋮
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538

julia>

terasakisatoshi, May 21 '22 13:05

I don't think this is the best solution, but the following function, notoom, which just adds a GC.gc(false) call to each loop iteration, does not run out of memory.

using ProgressMeter
function notoom()
    batchsize = 384
    usage = [memoryused()]
    @showprogress for _ in 1:1000
        xs = cu(randn(Float32, 224, 224, 3, batchsize))
        model(xs)
        push!(usage, memoryused())
        GC.gc(false) # <---
    end
    @assert length(usage) == 1+1000
end

julia> notoom()
Progress: 100%|█████████████████████████████████████████| Time: 0:11:57

terasakisatoshi, May 22 '22 06:05

Interesting 🤔. I wonder whether doing that only every few steps would be enough, and what the memory usage curve would look like.
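One way to try that is to trigger the collection every `every` iterations instead of on every iteration. This is a minimal sketch, assuming the `model` and `memoryused` definitions from the snippets above; `run_with_periodic_gc` and its `gcfun` keyword are hypothetical names, not part of Julia, CUDA.jl, or PyCallChainRules.jl. By default it calls `GC.gc(false)`, the incremental collection used in notoom().

```julia
"""
    run_with_periodic_gc(step, n; every = 10, gcfun = () -> GC.gc(false))

Hypothetical helper: call `step(i)` for `i in 1:n` and run `gcfun()` after
every `every`-th iteration. Returns how many times `gcfun` was called.
"""
function run_with_periodic_gc(step, n::Integer; every::Integer = 10,
                              gcfun = () -> GC.gc(false))
    every > 0 || throw(ArgumentError("`every` must be positive"))
    ncollections = 0
    for i in 1:n
        step(i)
        if i % every == 0
            gcfun()            # incremental collection by default, as in notoom()
            ncollections += 1
        end
    end
    return ncollections
end

# Usage with the definitions from this thread (not run here):
# usage = [memoryused()]
# run_with_periodic_gc(1000; every = 5) do _
#     xs = cu(randn(Float32, 224, 224, 3, 384))
#     model(xs)
#     push!(usage, memoryused())
# end
```

Passing the collection as `gcfun` keeps the loop testable and makes it easy to compare `GC.gc(false)` against a full `GC.gc()` at different intervals.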

lorenzoh, May 22 '22 08:05

This is not really an issue with DLPack.jl, PyCallChainRules.jl, or PyCall.jl. As mentioned in https://github.com/pabloferz/DLPack.jl/issues/26, it's just that Julia has no way of knowing how often it should garbage-collect PyObjects in general.

@terasakisatoshi's example above is the correct way of handling this.
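The mechanism can be illustrated without PyTorch or a GPU. PyCall.jl releases its reference to a Python object in a finalizer on the `PyObject`, so memory owned by the Python side is only freed once Julia's GC runs that finalizer, and Julia schedules collections based on its own heap usage, which these small wrapper objects barely affect. This is a minimal sketch with a hypothetical `ForeignHandle` type standing in for such a `PyObject`; the byte counter is simulated, and nothing is actually allocated outside Julia.

```julia
# Hypothetical stand-in for a PyObject that keeps a GPU tensor alive.
# Only a counter is updated; no memory is allocated outside Julia.
mutable struct ForeignHandle
    nbytes::Int
end

const outstanding = Ref(0)   # simulated bytes held "outside" the Julia heap

function acquire(nbytes::Int)
    h = ForeignHandle(nbytes)
    outstanding[] += nbytes
    # The simulated foreign memory is released only when the GC finalizes `h`.
    finalizer(h) do obj
        outstanding[] -= obj.nbytes
    end
    return h
end

# Create and immediately drop many handles, like calling `model(xs)` in a loop.
function churn(n::Int, nbytes::Int)
    for _ in 1:n
        acquire(nbytes)
    end
    return nothing
end

churn(1000, 2^20)
println("simulated MiB still held: ", outstanding[] ÷ 2^20)
GC.gc()   # an explicit full collection finalizes the unreachable handles
println("after GC.gc(): ", outstanding[] ÷ 2^20)
```

Each handle is tiny from Julia's point of view, so nothing forces a collection between iterations; that is why an explicit `GC.gc` call in the training or inference loop helps.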

p-zubieta, Jul 16 '22 00:07