Torch-Pruning
Infinite Loop in Dependency Graph Construction When Pruning Qwen2.5VL's Visual Module
Problem Description
I'm running into problems when pruning the Qwen2.5VL model with Torch-Pruning, most likely in its visual component: the dependency graph construction enters an endless loop. Here are my observations:
Observations

- Basic Pruning Attempt:
  - With simple text inputs (torch.randint(0, 100000, (3, 28, 28))), the language model (LM) prunes correctly but the visual module remains unpruned. This causes a dimension mismatch between the Merger's output features and the lm_head's input features.

- Multimodal Input Adjustment:
  - When using proper multimodal inputs containing both pixel_values and image_grid_thw, the visual module gets included in the dependency graph. However, the pruning process enters an infinite loop during _fix_dependency_graph_non_recursive, and I have to interrupt it manually (see the isolation sketch after this list):

    Traceback (most recent call last):
      [...]
      File "/path/to/dependency.py", line 504, in _fix_dependency_graph_non_recursive
        if (new_dep.target in visited_node) and group.has_pruning_op(
    KeyboardInterrupt

- Rotary Position Embedding Experiments:
  - Inspired by SlimSAM's approach (#337), I tried removing rotary_pos_emb, but encountered an attribute error in pruner = tp.pruner.MetaPruner:

    AttributeError: 'Qwen2_5_VisionTransformerPretrainedModel' object has no attribute 'rotary_pos_emb'

  - Then I tried substituting a dummy rotary_pos_emb:

    class DummyRotaryPosEmb(torch.nn.Module):
        def __init__(self, original_module):
            super().__init__()
            self.original_module = original_module
            self.inv_freq = torch.nn.Parameter(torch.randn(64)) if hasattr(original_module, 'inv_freq') else None

        def forward(self, *args, **kwargs):
            shape = self.original_module(*args, **kwargs).shape if self.original_module else (1, 16, 1280)
            return torch.zeros(shape, device=args[0].device)

    This avoids the attribute error but still leads to infinite looping.
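To narrow down where the loop comes from, one diagnostic I've been considering is building the dependency graph over the visual tower alone, with only the vision inputs. This is just a sketch, not something I've verified: the hidden_states / grid_thw forward signature is what I read from the HF implementation of Qwen2_5_VisionTransformerPretrainedModel, and processed_inputs refers to the processor output in the code below.

import torch_pruning as tp

# Trace only the vision tower, feeding it the already-processed patches.
# If this alone hangs, the cycle is inside the vision blocks; if it finishes,
# the loop is more likely created at the vision-to-language (merger/lm_head) junction.
DG = tp.DependencyGraph().build_dependency(
    model.visual,
    example_inputs={
        "hidden_states": processed_inputs.pixel_values,   # flattened image patches
        "grid_thw": processed_inputs.image_grid_thw,      # (num_images, 3)
    },
)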
Anyway, what is the correct way to prune Qwen2.5VL? Am I making a mistake with my current approach?
Code
Here's my whole code:
import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch_pruning as tp
from qwen_vl_utils import process_vision_info
def prune_model(model, processor, pruning_ratio):
    class DummyRotaryPosEmb(torch.nn.Module):
        def __init__(self, original_module):
            super().__init__()
            self.original_module = original_module
            # Add a dummy parameter so downstream None checks pass
            self.inv_freq = torch.nn.Parameter(torch.randn(64)) if hasattr(original_module, 'inv_freq') else None

        def forward(self, *args, **kwargs):
            # Return a zero tensor with the same shape as the original module's output
            if self.original_module is not None:
                shape = self.original_module(*args, **kwargs).shape
            else:
                shape = (1, 16, 1280)
            return torch.zeros(shape, device=args[0].device)

    # Save and replace the rotary position embedding module
    original_visual_rotary = model.visual.rotary_pos_emb
    model.visual.rotary_pos_emb = DummyRotaryPosEmb(original_visual_rotary)
    # Register the LM attention head layout; the visual part is not handled here
    num_heads = {}
    for name, module in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[module.q_proj] = model.config.num_attention_heads
            num_heads[module.k_proj] = model.config.num_key_value_heads
            num_heads[module.v_proj] = model.config.num_key_value_heads

    # Output channel grouping: the LM needs nothing here, but the visual
    # multi-head attention uses a fused qkv projection
    out_channel_groups = {}
    for m in model.modules():
        if isinstance(m, model.visual.blocks[0].attn.__class__):
            out_channel_groups[m.qkv] = m.num_heads * 3  # do we need to multiply by 3?
    print("out_channel_groups", out_channel_groups)
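    # I'm also unsure whether the fused visual qkv needs an entry in num_heads as well;
    # Torch-Pruning's timm ViT example registers the fused qkv that way, but I haven't
    # verified this for Qwen2.5-VL, so it is left commented out:
    # for m in model.modules():
    #     if isinstance(m, model.visual.blocks[0].attn.__class__):
    #         num_heads[m.qkv] = m.num_heads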
    importance = tp.importance.GroupNormImportance(p=2)  # tp.importance.ActivationImportance(p=2, target_types=[torch.nn.Linear])

    # Parameters that are not wrapped in prunable modules
    unwrapped_parameters = []
    for name, param in model.named_parameters():
        if 'norm' in name and 'weight' in name:
            unwrapped_parameters.append((param, 0))  # prune along dim 0
        if 'ln_q' in name and 'weight' in name:
            unwrapped_parameters.append((param, 0))  # prune along dim 0

    # Ignore the final lm_head
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, torch.nn.Linear) and m.out_features == 151936:
            ignored_layers.append(m)
    print("ignored_layers", ignored_layers)

    # Build example inputs
    # example_inputs = torch.randint(0, 100000, (3, 28, 28), dtype=torch.long, device='cuda:0')
    # Construct a proper multimodal input instead
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": url,
                },
                {"type": "text", "text": "Describe this image"},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(messages)
    processed_inputs = processor(
        text=[text],
        images=image_inputs,
        return_tensors="pt",
        padding=True,
    ).to(model.device)
    example_inputs = {
        "input_ids": processed_inputs.input_ids,
        "attention_mask": processed_inputs.attention_mask,
        "pixel_values": processed_inputs.pixel_values,
        "image_grid_thw": processed_inputs.image_grid_thw
    }
    # Create the pruner
    model.config.use_cache = False
    pruner = tp.pruner.MetaPruner(
        model,
        example_inputs=example_inputs,
        importance=importance,
        global_pruning=False,
        pruning_ratio=pruning_ratio,
        ignored_layers=ignored_layers,
        num_heads=num_heads,
        prune_num_heads=True,
        head_pruning_ratio=pruning_ratio,
        prune_head_dims=True,
        out_channel_groups=out_channel_groups,
        round_to=4,
        unwrapped_parameters=unwrapped_parameters,
    )

    # Run pruning
    for g in pruner.step(interactive=True):
        # print(g)
        g.prune()

    # Restore the original rotary position embedding module
    model.visual.rotary_pos_emb = original_visual_rotary
    return model
def main():
    model_path = "***/Qwen2.5-VL-3B-Instruct"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        # torch_dtype=torch.bfloat16,
        device_map="cuda:0"
    )
    processor = AutoProcessor.from_pretrained(model_path)

    print("========= Before Pruning =========")
    print(model)

    print("Starting pruning process...")
    pruned_model = prune_model(model, processor, pruning_ratio=0.428571428)

    print("========= After Pruning =========")
    print(pruned_model)

if __name__ == "__main__":
    main()
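For reference, here is the quick sanity check I run on the returned model to see whether the visual tower was actually touched and whether its output still matches the lm_head input (this is where the dimension mismatch from the first observation shows up). The merger attribute path (model.visual.merger.mlp[-1]) is an assumption based on my reading of the HF Qwen2.5-VL code, so it may need adjusting.

def check_vision_lm_compatibility(model):
    # If the visual tower's parameter count is unchanged while the total shrank,
    # only the language model was covered by the dependency graph.
    visual_params = sum(p.numel() for p in model.visual.parameters())
    total_params = sum(p.numel() for p in model.parameters())
    print(f"visual params: {visual_params}, total params: {total_params}")

    # The merger projects vision features into the LM embedding space, so its last
    # Linear's out_features must equal lm_head.in_features after pruning.
    merger_out = model.visual.merger.mlp[-1].out_features  # assumed attribute path
    lm_head_in = model.lm_head.in_features
    print(f"merger out_features: {merger_out}, lm_head in_features: {lm_head_in}")
    assert merger_out == lm_head_in, "dimension mismatch between visual merger and lm_head"

# e.g. called on the model returned by prune_model:
# check_vision_lm_compatibility(pruned_model)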
Have you solved this problem yet?