
Infinite Loop in Dependency Graph Construction When Pruning Qwen2.5VL's Visual Module

zrrraa opened this issue on Mar 19, 2025 · 1 comment

Problem Description

I'm encountering problems when pruning the Qwen2.5VL model with Torch-Pruning, most likely (I suspect) in its visual component: the dependency graph construction enters an endless loop. Here is what I observed:

Observations

  1. Basic Pruning Attempt:

    • With a simple text-only input (torch.randint(0, 100000, (3, 28, 28))), the language model (LM) is pruned correctly, but the visual module remains unpruned. This causes a dimension mismatch between the Merger's output features and the lm_head's input features (see the sketch after this list).
  2. Multimodal Input Adjustment:

    • When using proper multimodal inputs containing both pixel_values and image_grid_thw, the visual module is included in the dependency graph. However, the pruning process then enters an infinite loop in _fix_dependency_graph_non_recursive (an isolation sketch follows below):
    Traceback (most recent call last):
      [...]
      File "/path/to/dependency.py", line 504, in _fix_dependency_graph_non_recursive
      if (new_dep.target in visited_node) and group.has_pruning_op(
    KeyboardInterrupt
    
  3. Rotary Position Embedding Experiments:

    • Inspired by SlimSAM's approach (#337), I tried removing rotary_pos_emb, but hit an AttributeError when constructing pruner = tp.pruner.MetaPruner(...):
    AttributeError: 'Qwen2_5_VisionTransformerPretrainedModel' object has no attribute 'rotary_pos_emb'
    
    • I then tried substituting a dummy rotary_pos_emb:
    class DummyRotaryPosEmb(torch.nn.Module):
        def __init__(self, original_module):
            super().__init__()
            self.original_module = original_module
            self.inv_freq = torch.nn.Parameter(torch.randn(64)) if hasattr(original_module, 'inv_freq') else None
    
        def forward(self, *args, **kwargs):
            shape = self.original_module(*args, **kwargs).shape if self.original_module else (1, 16, 1280)
            return torch.zeros(shape, device=args[0].device)
    

    This avoids the attribute error but still leads to infinite looping.
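
For reference, the failing text-only attempt from point 1 boiled down to something like this (a minimal sketch, assuming model is the already-loaded Qwen2_5_VLForConditionalGeneration from the full script below):

# Text-only sketch (point 1): only input_ids are traced, so the visual
# tower never enters the dependency graph and keeps its original width.
example_inputs = torch.randint(0, 100000, (3, 28, 28), dtype=torch.long, device="cuda:0")
pruner = tp.pruner.MetaPruner(
    model,
    example_inputs=example_inputs,
    importance=tp.importance.GroupNormImportance(p=2),
    pruning_ratio=0.428571428,
    ignored_layers=[model.lm_head],  # skip the vocab projection, as in the full script
)
pruner.step()  # LM shrinks, model.visual does not -> Merger/lm_head mismatch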

Anyway, what is the correct way to prune Qwen2.5VL? Am I making a mistake with my current approach?
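
For anyone who wants to reproduce the hang in isolation, tracing only the visual tower with tp.DependencyGraph (the graph builder MetaPruner uses internally) should be enough to reach the same code path (a sketch; the forward argument names hidden_states/grid_thw follow the Qwen2.5-VL source and may differ across transformers versions, and processed_inputs is built exactly as in the full script below):

# Isolation sketch: build the dependency graph for the visual tower alone.
DG = tp.DependencyGraph()
DG.build_dependency(
    model.visual,
    example_inputs={
        "hidden_states": processed_inputs.pixel_values,
        "grid_thw": processed_inputs.image_grid_thw,
    },
)
# Requesting any pruning group walks the graph; this is the call path
# that reaches _fix_dependency_graph_non_recursive:
group = DG.get_pruning_group(
    model.visual.blocks[0].attn.qkv, tp.prune_linear_out_channels, idxs=[0]
)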

Code

Here's my full script:

import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch_pruning as tp
from qwen_vl_utils import process_vision_info

def prune_model(model, processor, pruning_ratio):
    
    class DummyRotaryPosEmb(torch.nn.Module):
        def __init__(self, original_module):
            super().__init__()
            self.original_module = original_module
            # Add a fake inv_freq parameter so that downstream None checks still pass
            self.inv_freq = torch.nn.Parameter(torch.randn(64)) if hasattr(original_module, 'inv_freq') else None
            
        def forward(self, *args, **kwargs):
            # Return a zero tensor matching the shape of the original module's output
            if self.original_module is not None:
                shape = self.original_module(*args, **kwargs).shape
            else:
                shape = (1, 16, 1280)
            return torch.zeros(shape, device=args[0].device)
    
    # Save and replace the rotary position embedding module
    original_visual_rotary = model.visual.rotary_pos_emb
    model.visual.rotary_pos_emb = DummyRotaryPosEmb(original_visual_rotary)
    
    # Configure LM attention-head counts; the visual part needs no handling here
    num_heads = {}
    for name, module in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[module.q_proj] = model.config.num_attention_heads
            num_heads[module.k_proj] = model.config.num_key_value_heads
            num_heads[module.v_proj] = model.config.num_key_value_heads

    # Channel grouping: the LM part needs none, but the visual multi-head attention (fused qkv) does
    out_channel_groups = {}
    for m in model.modules():
        if isinstance(m, model.visual.blocks[0].attn.__class__):
            out_channel_groups[m.qkv] = m.num_heads * 3  # is the *3 needed? (the fused qkv projection stacks q, k and v, giving 3 * num_heads head-sized output groups)
    print("out_channel_groups", out_channel_groups)

    importance = tp.importance.GroupNormImportance(p=2) #tp.importance.ActivationImportance(p=2, target_types=[torch.nn.Linear])
    
    # Register parameters that are not wrapped by prunable modules (norm weights)
    unwrapped_parameters = []
    for name, param in model.named_parameters():
        if ('norm' in name or 'ln_q' in name) and 'weight' in name:
            unwrapped_parameters.append((param, 0))  # prune along dim 0
    
    # Ignore the final lm_head (out_features == 151936, the Qwen2.5 vocab size)
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, torch.nn.Linear) and m.out_features == 151936:
            ignored_layers.append(m)
    print("ignored_layers", ignored_layers)

    # Build example inputs
    # example_inputs = torch.randint(0, 100000, (3, 28, 28), dtype=torch.long, device='cuda:0')  # text-only variant from observation 1

    # Build a proper multimodal input so the visual tower is traced
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": url,
                },
                {"type": "text", "text": "Describe this image"},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(messages)
    processed_inputs = processor(
        text=[text],
        images=image_inputs,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    example_inputs = {
        "input_ids": processed_inputs.input_ids,
        "attention_mask": processed_inputs.attention_mask,
        "pixel_values": processed_inputs.pixel_values,
        "image_grid_thw": processed_inputs.image_grid_thw
    }

    # Create the pruner
    model.config.use_cache = False
    pruner = tp.pruner.MetaPruner(
        model,
        example_inputs=example_inputs,
        importance=importance,
        global_pruning=False,
        pruning_ratio=pruning_ratio,
        ignored_layers=ignored_layers,
        num_heads=num_heads,
        prune_num_heads=True,
        head_pruning_ratio=pruning_ratio,
        prune_head_dims=True,
        out_channel_groups=out_channel_groups,
        round_to=4,
        unwrapped_parameters=unwrapped_parameters,
    )
    
    # Run pruning
    for g in pruner.step(interactive=True):
        # print(g)
        g.prune()

    # Restore the original rotary position embedding module
    model.visual.rotary_pos_emb = original_visual_rotary

    return model
            
def main():
    model_path = "***/Qwen2.5-VL-3B-Instruct"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        # torch_dtype=torch.bfloat16,
        device_map="cuda:0"
    )
    processor = AutoProcessor.from_pretrained(model_path)

    print("========= Before Pruning =========")
    print(model)
    
    print("Starting pruning process...")
    pruned_model = prune_model(model, processor, pruning_ratio=0.428571428)
    print("========= After Pruning =========")
    print(pruned_model)

if __name__ == "__main__":
    main()
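
For completeness, a quick boundary check after pruning would look roughly like this (a sketch; the attribute paths model.model.embed_tokens and model.visual.merger.mlp[-1] are assumptions based on the printed model structure and may differ across transformers versions):

# Boundary sanity-check sketch: the Merger's final projection must match the
# LM embedding width and lm_head input, or generation crashes at the boundary.
# Note: model.config.hidden_size is NOT updated by Torch-Pruning, so read the
# real layer shapes instead of the config.
emb_in = model.model.embed_tokens.embedding_dim        # assumed attribute path
merger_out = model.visual.merger.mlp[-1].out_features  # assumed attribute path
head_in = model.lm_head.in_features
print(merger_out, emb_in, head_in)
assert merger_out == emb_in == head_in, "vision-language boundary mismatch"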

zrrraa · Mar 19 '25

Have you solved this problem?

Junzhou-Chen · Jul 11 '25