nn_pruning
nn_pruning doesn't seem to work for T5 or RoBERTa-based models
Hi @madlag @julien-c @co42 @srush @Narsil
I am trying to use nn_pruning to prune different transformer models.

Code (`mpc` is a `ModelPatchingCoordinator` created earlier in the notebook):

```python
from transformers import AutoModelForSeq2SeqLM

model_checkpoint = "t5-small"
t5small_model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
mpc.patch_model(t5small_model)
t5small_model.save_pretrained("models/patched")
```
Error:

```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-47-602943fc51a1> in <module>()
      1
      2 t5small_model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
----> 3 mpc.patch_model(t5small_model)
      4
      5 t5small_model.save_pretrained("models/patched")

/usr/local/lib/python3.7/dist-packages/nn_pruning/patch_coordinator.py in patch_model(self, model, trial)
    640             patched_count += 2 * layers_count
    641
--> 642         assert (patcher.stats["patched"] == patched_count)
    643
    644         if layer_norm_patch:

AssertionError:
```
[Colab](https://colab.research.google.com/drive/1Gz7rozG8NbeBtsiWXjGNQ5wnVU7SE_Wl?usp=sharing)
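For anyone hitting this assertion, the likely cause (this is my reading of the traceback, not a confirmed fix): nn_pruning's patcher matches attention sub-modules by their BERT-style names (`query`, `key`, `value`), while T5 names its projections `q`, `k`, `v`, so nothing matches, `patcher.stats["patched"]` stays below the expected `patched_count`, and the assertion fires. The naming mismatch is easy to see with tiny randomly-initialised configs (no model download needed; the configs below are illustrative):

```python
# Compare attention projection module names in BERT vs T5.
# nn_pruning's patterns target the BERT-style names on the left;
# T5's q/k/v names on the right never match them.
from transformers import BertConfig, BertModel, T5Config, T5Model

bert = BertModel(BertConfig(num_hidden_layers=1, hidden_size=32,
                            num_attention_heads=2, intermediate_size=64))
t5 = T5Model(T5Config(num_layers=1, num_decoder_layers=1, d_model=32,
                      num_heads=2, d_ff=64, d_kv=16))

# Collect the names of the attention projection sub-modules.
bert_names = {n for n, _ in bert.named_modules()
              if n.endswith(("query", "key", "value"))}
t5_names = {n for n, _ in t5.named_modules()
            if n.endswith((".q", ".k", ".v"))}

print(sorted(bert_names))  # BERT: ...attention.self.query / .key / .value
print(sorted(t5_names))    # T5:   ...SelfAttention.q / .k / .v (and EncDecAttention)
```

So patching a T5 (or similarly non-BERT-named) model would need name patterns that cover its architecture; out of the box the coordinator appears to assume BERT-family layer names.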
I have the same problem. Did you fix it? @shubham-krishna
I have the same problem. Did you fix it? @robotsp @ghost
Sorry, I don't know; I can't really help.