RecursionError during pruning
Hi - thanks for a wonderful tool. I am trying to test it out with a pretrained model from here. However I am encountering the following error:
module name: encode1.conv0
pruning_idxs: [4, 5, 8, 9, 13, 14, 16, 17, 18, 19, 21, 23, 24, 26, 28, 30, 31, 32, 35, 36, 37, 38, 39, 40, 41, 43, 46, 48, 49, 53, 56, 58]
Traceback (most recent call last):
  File "/home/nikhil/projects/green_comp_neuro/FastSurfer/FastSurferCNN/torch_prune_test.py", line 159, in <module>
    load_pretrained(pretrained_ckpt, params_model, model, dummy_data, save_path)
  File "/home/nikhil/projects/green_comp_neuro/FastSurfer/FastSurferCNN/torch_prune_test.py", line 120, in load_pretrained
    model = torch_prune(model, dummy_data, params_model['prune_type'], params_model['prune_percent'])
  File "/home/nikhil/projects/green_comp_neuro/FastSurfer/FastSurferCNN/torch_prune_test.py", line 92, in torch_prune
    pruning_plan = DG.get_pruning_plan( module, tp.prune_conv, idxs=pruning_idxs )
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 398, in get_pruning_plan
    _fix_denpendency_graph(root_node, pruning_fn, idxs)
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 397, in _fix_denpendency_graph
    _fix_denpendency_graph(dep.broken_node, dep.handler, new_indices)
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 397, in _fix_denpendency_graph
    _fix_denpendency_graph(dep.broken_node, dep.handler, new_indices)
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 397, in _fix_denpendency_graph
    _fix_denpendency_graph(dep.broken_node, dep.handler, new_indices)
  [Previous line repeated 990 more times]
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 387, in _fix_denpendency_graph
    new_indices = dep.index_transform(indices)
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 148, in __call__
    if self.reverse==True:
RecursionError: maximum recursion depth exceeded in comparison
The network architecture is based on this paper. Here is a figure showing the details:

[figure: network architecture diagram]
Below is my test script, which uses the model definition and pretrained weights from the model repo:
# IMPORTS
import argparse
import nibabel as nib
import numpy as np
from datetime import datetime
import time
import sys
import os
import glob
import os.path as op
import logging
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms, utils
from scipy.ndimage.filters import median_filter, gaussian_filter
from skimage.measure import label, regionprops
from collections import OrderedDict
from os import makedirs
from models.networks import FastSurferCNN
import pandas as pd
# torch-pruning
sys.path.append('../../Torch-Pruning')
import torch_pruning as tp
def options_parse():
    """
    Command line option parser
    """
    parser = argparse.ArgumentParser()

    # Options for model parameters setup (only change if model training was changed)
    parser.add_argument('--num_filters', type=int, default=64,
                        help='Filter dimensions for DenseNet (all layers same). Default=64')
    parser.add_argument('--num_classes_ax_cor', type=int, default=79,
                        help='Number of classes to predict in axial and coronal net, including background. Default=79')
    parser.add_argument('--num_classes_sag', type=int, default=51,
                        help='Number of classes to predict in sagittal net, including background. Default=51')
    parser.add_argument('--num_channels', type=int, default=7,
                        help='Number of input channels. Default=7 (thick slices)')
    parser.add_argument('--kernel_height', type=int, default=5, help='Height of Kernel (Default 5)')
    parser.add_argument('--kernel_width', type=int, default=5, help='Width of Kernel (Default 5)')
    parser.add_argument('--stride', type=int, default=1, help="Stride during convolution (Default 1)")
    parser.add_argument('--stride_pool', type=int, default=2, help="Stride during pooling (Default 2)")
    parser.add_argument('--pool', type=int, default=2, help='Size of pooling filter (Default 2)')

    sel_option = parser.parse_args()
    return sel_option
def torch_prune(model, dummy_data, prune_type, prune_percent):
    print(f'compressing model with prune type: {prune_type}, sparsity: {prune_percent}')

    # 1. Set up the pruning strategy (L1 norm)
    strategy = tp.strategy.L1Strategy()  # or tp.strategy.RandomStrategy()

    # 2. Build the layer dependency graph for the model
    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=dummy_data)

    # 3. Get a pruning plan from the dependency graph for each Conv2d layer
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            print(f'module name: {name}')
            pruning_idxs = strategy(module.weight, amount=prune_percent)  # or manually selected pruning_idxs=[2, 6, 9, ...]
            print(f'pruning_idxs: {pruning_idxs}')
            pruning_plan = DG.get_pruning_plan(module, tp.prune_conv, idxs=pruning_idxs)
            print(pruning_plan)

            # 4. Execute this plan (prune the model in place)
            pruning_plan.exec()

    # Return the pruned model so the caller's assignment receives it
    return model
def load_pretrained(pretrained_ckpt, params_model, model):
    model_state = torch.load(pretrained_ckpt, map_location=params_model["device"])
    new_state_dict = OrderedDict()

    # FastSurfer model specific configs: strip or add the "module." prefix
    # depending on whether the model is wrapped for model parallelism
    for k, v in model_state["model_state_dict"].items():
        if k[:7] == "module." and not params_model["model_parallel"]:
            new_state_dict[k[7:]] = v
        elif k[:7] != "module." and params_model["model_parallel"]:
            new_state_dict["module." + k] = v
        else:
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
    model.eval()
    return model
if __name__ == "__main__":
    args = options_parse()

    plane = "Axial"
    pretrained_ckpt = f'../checkpoints/{plane}_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl'

    # Put it onto the GPU or CPU
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Set up model for axial and coronal networks
    params_model = {'num_channels': args.num_channels, 'num_filters': args.num_filters,
                    'kernel_h': args.kernel_height, 'kernel_w': args.kernel_width,
                    'stride_conv': args.stride, 'pool': args.pool,
                    'stride_pool': args.stride_pool, 'num_classes': args.num_classes_ax_cor,
                    'kernel_c': 1, 'kernel_d': 1,
                    'model_parallel': False,
                    'device': device
                    }

    # Select the model
    model = FastSurferCNN(params_model)
    model.to(device)

    # Load pretrained weights
    model = load_pretrained(pretrained_ckpt, params_model, model)

    # Prune model (the dummy input must live on the same device as the model)
    dummy_data = torch.ones(1, 7, 256, 256).to(device)
    model = torch_prune(model, dummy_data, prune_type='L1', prune_percent=0.5)

    # Save pruned model
    # save_path = f'./{plane}_pruned.pth'
    # torch.save(model, save_path)
I would appreciate any help or suggestions! Thanks!
Hi @VainF - just checking if you have any update on this issue. Thanks!
This is not a bug with the code per se. Python's default recursion limit is 1000 calls; it is set to avoid stack overflow.
File "../../Torch-Pruning/torch_pruning/dependency.py", line 397, in _fix_denpendency_graph _fix_denpendency_graph(dep.broken_node, dep.handler, new_indices) [Previous line repeated 990 more times]
As you can see, the RecursionError was raised after 990 further recursive calls of _fix_denpendency_graph.
You can try raising the recursion limit for your Python interpreter, or lower the pruning percentage (pruning 50% of the smallest weights is why so many indices are selected for pruning, which drives the recursion deep enough to trip the limit).
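For example, a quick workaround is to raise the interpreter's limit before building the pruning plan. A minimal sketch; the value 10000 is just a guess and should exceed the depth of your model's dependency chain:

import sys

# Python's default limit is 1000 frames; raise it before calling
# DG.get_pruning_plan so the recursive graph traversal can finish.
sys.setrecursionlimit(10000)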
If you have a huge model or combination of models, then it throws a recursion error. Would it be possible to implement a dependency graph without recursion? @VainF
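For illustration, here is a minimal sketch of how the recursive fix-up could be rewritten with an explicit stack, so graph depth never hits Python's recursion limit. The node/dependency attributes below (dependencies, trigger, broken_node, index_transform) are simplified stand-ins, not the actual Torch-Pruning internals:

def fix_dependency_graph_iterative(root_node, pruning_fn, idxs):
    # An explicit stack replaces the interpreter call stack; each entry
    # mirrors one recursive call: (node, handler, indices).
    stack = [(root_node, pruning_fn, idxs)]
    visited = set()
    while stack:
        node, handler, indices = stack.pop()
        if (node, handler) in visited:  # guard against cycles and revisits
            continue
        visited.add((node, handler))
        for dep in node.dependencies:  # hypothetical adjacency list
            if dep.trigger is handler:  # hypothetical: does `handler` break this dep?
                new_indices = dep.index_transform(indices) if dep.index_transform else indices
                stack.append((dep.broken_node, dep.handler, new_indices))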
Hi @vinayak-sharan, thank you for your advice. I will try to re-implement it in the next version.
Hi everyone, the non-recursive implementation of the dependency graph has been uploaded. I will keep this issue open for further discussion!