# Compressors

A small library with distillation, quantization and pruning pipelines.

Warning! Alpha version! This is not a production-ready solution yet.

Compressors is a library of pipelines for model compression with no significant loss in performance.

- Compressors
  - Why Compressors?
  - Install
  - Features
    - Distillation
    - Pruning
  - Minimal Examples
    - Distillation
      - MNIST
      - CIFAR100 ResNet
      - AG NEWS BERT (transformers)
    - Pruning
      - Distillation
  - Examples
## Why Compressors?
Compressors provides many ways to compress your model. You can use it for both CV and NLP tasks.
The library is split into three parts:
- Distillation
- Pruning
- Quantization
There are two ways to use Compressors: with the Catalyst framework or directly through the functional API.
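The Catalyst path is what the minimal examples below follow. For the functional path you apply the same losses inside your own training loop; the snippet below is only a sketch of that idea in plain PyTorch (the classic temperature-scaled KL loss from Hinton et al.), not a mirror of Compressors' exact function names.

```python
import torch
import torch.nn.functional as F


def distillation_loss(s_logits, t_logits, targets, T=4.0, alpha=0.9):
    """Weighted sum of the usual cross-entropy and a temperature-scaled
    KL term between student and teacher distributions (Hinton et al.)."""
    task_loss = F.cross_entropy(s_logits, targets)
    kd_loss = F.kl_div(
        F.log_softmax(s_logits / T, dim=-1),
        F.softmax(t_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    return alpha * kd_loss + (1.0 - alpha) * task_loss


# usage inside your own loop (teacher frozen):
# with torch.no_grad():
#     t_logits = teacher(inputs)
# loss = distillation_loss(student(inputs), t_logits, targets)
```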
## Install

```bash
pip install git+https://github.com/elephantmipt/compressors.git
```
## Features
### Distillation
| Name | References | Status |
|---|---|---|
| KL-divergence | Hinton et al. | Implemented |
| MSE | Hinton et al. | Implemented |
| Probabilistic KT | Passalis et al. | Implemented |
| Cosine | ??? | Implemented |
| Attention Transfer | Zagoruyko et al. | Implemented |
| Contrastive Representation Distillation | Tian et al. | Implemented (without dataset) |
| Probability Shift | Wen et al. | Implemented and tested |
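For a sense of what these losses look like in code, here is a minimal plain-PyTorch sketch of one of them, attention transfer (Zagoruyko et al.). Compressors ships its own implementation (used through `AttentionHiddenStatesCallback` in the CIFAR100 example below), so treat this purely as an illustration of the math.

```python
import torch
import torch.nn.functional as F


def attention_map(feat: torch.Tensor) -> torch.Tensor:
    """Collapse a (B, C, H, W) feature map into an l2-normalized spatial attention map."""
    am = feat.pow(2).mean(dim=1)              # (B, H, W): mean of squared activations over channels
    return F.normalize(am.flatten(1), dim=1)  # (B, H*W) with unit l2 norm per sample


def attention_transfer_loss(s_feat: torch.Tensor, t_feat: torch.Tensor) -> torch.Tensor:
    """Squared distance between the student's and teacher's attention maps."""
    return (attention_map(s_feat) - attention_map(t_feat)).pow(2).mean()
```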
### Pruning
| Name | References | Status |
|---|---|---|
| Lottery ticket hypothesis | Frankle et al. | Implemented |
| Iterative pruning | Paganini et al. | Implemented |
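Both entries come down to repeated prune-then-finetune rounds built on `torch.nn.utils.prune`, the same primitives the `PruningCallback` in the pruning example below is configured with (`pruning_fn="l1_unstructured"`). A hand-rolled sketch of iterative L1 pruning, with arbitrary numbers and the fine-tuning step left out:

```python
import torch.nn.utils.prune as prune
from torch import nn


def iterative_l1_prune(model: nn.Module, rounds: int = 10, amount: float = 0.2) -> None:
    """Each round removes a fraction `amount` of the remaining weights
    (smallest L1 magnitude first) from every Linear layer."""
    for _ in range(rounds):
        for module in model.modules():
            if isinstance(module, nn.Linear):
                prune.l1_unstructured(module, name="weight", amount=amount)
        # ... fine-tune `model` here before the next round ...
```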
## Minimal Examples
### Distillation
#### MNIST

```python
from itertools import chain
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T
from catalyst.contrib.datasets import MNIST
from catalyst.callbacks import AccuracyCallback, OptimizerCallback
from compressors.distillation.runners import EndToEndDistilRunner
from compressors.models import MLP
from compressors.utils.data import TorchvisionDatasetWrapper as Wrp
teacher = MLP(num_layers=4)
student = MLP(num_layers=3)
datasets = {
"train": Wrp(MNIST("./data", train=True, download=True, transform=T.ToTensor())),
"valid": Wrp(MNIST("./data", train=False, transform=T.ToTensor())),
}
loaders = {
dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
for dl_key, dataset in datasets.items()
}
optimizer = torch.optim.Adam(chain(teacher.parameters(), student.parameters()))
runner = EndToEndDistilRunner(
hidden_state_loss="mse",
num_train_teacher_epochs=5
)
runner.train(
    model=torch.nn.ModuleDict({"teacher": teacher, "student": student}),
loaders=loaders,
optimizer=optimizer,
num_epochs=4,
callbacks=[
OptimizerCallback(metric_key="loss"),
AccuracyCallback(input_key="logits", target_key="targets")
],
valid_metric="accuracy01",
minimize_valid_metric=False,
logdir="./logs",
valid_loader="valid",
criterion=torch.nn.CrossEntropyLoss()
)
```
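In this example `EndToEndDistilRunner` covers both stages: it first trains the teacher on its own (`num_train_teacher_epochs=5`) and then distills it into the smaller student, using an MSE loss between hidden states alongside the cross-entropy criterion.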
#### CIFAR100 ResNet

```python
from catalyst.callbacks import (
AccuracyCallback,
ControlFlowCallback,
CriterionCallback,
OptimizerCallback,
SchedulerCallback,
)
import torch
from torch.hub import load_state_dict_from_url
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100
from compressors.distillation.callbacks import (
AttentionHiddenStatesCallback,
KLDivCallback,
MetricAggregationCallback,
)
from compressors.distillation.runners import DistilRunner
from compressors.models.cv import resnet_cifar_8, resnet_cifar_56
from compressors.utils.data import TorchvisionDatasetWrapper as Wrp
transform_train = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
transform_test = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
datasets = {
"train": Wrp(CIFAR100(root=".", train=True, download=True, transform=transform_train)),
"valid": Wrp(CIFAR100(root=".", train=False, transform=transform_test)),
}
loaders = {
k: DataLoader(v, batch_size=32, shuffle=k == "train", num_workers=2)
for k, v in datasets.items()
}
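# teacher: a CIFAR-100 ResNet-56 with pretrained weights; student: a much smaller ResNet-8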
teacher_sd = load_state_dict_from_url(
"https://github.com/chenyaofo/CIFAR-pretrained-models/releases/download/resnet/cifar100-resnet56-2f147f26.pth"
)
teacher_model = resnet_cifar_56(num_classes=100)
teacher_model.load_state_dict(teacher_sd)
student_model = resnet_cifar_8(num_classes=100)
optimizer = torch.optim.SGD(
student_model.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-4
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
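# probability shift (Wen et al.): make the teacher's largest logit correspond to the ground-truth class before distilling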
runner = DistilRunner(apply_probability_shift=True)
runner.train(
model={"teacher": teacher_model, "student": student_model},
loaders=loaders,
optimizer=optimizer,
scheduler=scheduler,
valid_metric="accuracy",
minimize_valid_metric=False,
logdir="./cifar100_logs",
callbacks=[
ControlFlowCallback(AttentionHiddenStatesCallback(), loaders="train"),
ControlFlowCallback(KLDivCallback(temperature=4), loaders="train"),
CriterionCallback(input_key="s_logits", target_key="targets", metric_key="cls_loss"),
ControlFlowCallback(
MetricAggregationCallback(
prefix="loss",
metrics={
"attention_loss": 1000,
"kl_div_loss": 0.9,
"cls_loss": 0.1,
},
mode="weighted_sum",
),
loaders="train",
),
AccuracyCallback(input_key="s_logits", target_key="targets"),
OptimizerCallback(metric_key="loss", model_key="student"),
SchedulerCallback(),
],
valid_loader="valid",
num_epochs=200,
criterion=torch.nn.CrossEntropyLoss(),
)
```
#### AG NEWS BERT (transformers)

```python
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from catalyst.callbacks import ControlFlowCallback, OptimizerCallback
from catalyst.callbacks.metric import LoaderMetricCallback
from compressors.distillation.callbacks import (
HiddenStatesSelectCallback,
KLDivCallback,
LambdaPreprocessCallback,
MetricAggregationCallback,
MSEHiddenStatesCallback,
)
from compressors.distillation.runners import HFDistilRunner
from compressors.metrics.hf_metric import HFMetric
from compressors.runners.hf_runner import HFRunner
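# AG News is a 4-class news topic classification dataset; tokenize it once and reuse for teacher and student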
datasets = load_dataset("ag_news")
tokenizer = AutoTokenizer.from_pretrained("google/bert_uncased_L-4_H-128_A-2")
datasets = datasets.map(
lambda e: tokenizer(e["text"], truncation=True, padding="max_length", max_length=128),
batched=True,
)
datasets = datasets.map(lambda e: {"labels": e["label"]}, batched=True)
datasets.set_format(
type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"],
)
loaders = {
"train": DataLoader(datasets["train"], batch_size=64, shuffle=True),
"valid": DataLoader(datasets["test"], batch_size=64),
}
metric_callback = LoaderMetricCallback(
metric=HFMetric(metric=load_metric("accuracy")), input_key="logits", target_key="labels",
)
################### Teacher Training #####################
teacher_model = AutoModelForSequenceClassification.from_pretrained(
"google/bert_uncased_L-4_H-128_A-2", num_labels=4
)
runner = HFRunner()
runner.train(
model=teacher_model,
loaders=loaders,
optimizer=torch.optim.Adam(teacher_model.parameters(), lr=1e-4),
callbacks=[metric_callback],
num_epochs=5,
valid_metric="accuracy",
minimize_valid_metric=False,
verbose=True
)
############### Distillation ##################
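# the teacher (L-4) is deeper than the student (L-2), so keep only teacher hidden states 1 and 3 for matching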
slct_callback = ControlFlowCallback(
HiddenStatesSelectCallback(hiddens_key="t_hidden_states", layers=[1, 3]), loaders="train",
)
lambda_hiddens_callback = ControlFlowCallback(
LambdaPreprocessCallback(
lambda s_hiddens, t_hiddens: (
[c_s[:, 0] for c_s in s_hiddens],
            [t_s[:, 0] for t_s in t_hiddens],  # take only the CLS token
)
),
loaders="train",
)
mse_hiddens = ControlFlowCallback(MSEHiddenStatesCallback(), loaders="train")
kl_div = ControlFlowCallback(KLDivCallback(temperature=4), loaders="train")
aggregator = ControlFlowCallback(
MetricAggregationCallback(
prefix="loss",
metrics={"kl_div_loss": 0.2, "mse_loss": 0.2, "task_loss": 0.6},
mode="weighted_sum",
),
loaders="train",
)
runner = HFDistilRunner()
student_model = AutoModelForSequenceClassification.from_pretrained(
"google/bert_uncased_L-2_H-128_A-2", num_labels=4
)
metric_callback = LoaderMetricCallback(
metric=HFMetric(metric=load_metric("accuracy")), input_key="s_logits", target_key="labels",
)
runner.train(
model=torch.nn.ModuleDict({"teacher": teacher_model, "student": student_model}),
loaders=loaders,
optimizer=torch.optim.Adam(student_model.parameters(), lr=1e-4),
callbacks=[
metric_callback,
slct_callback,
lambda_hiddens_callback,
mse_hiddens,
kl_div,
aggregator,
OptimizerCallback(metric_key="loss"),
],
num_epochs=5,
valid_metric="accuracy",
minimize_valid_metric=False,
valid_loader="valid",
verbose=True
)
```
### Pruning

```python
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from catalyst.callbacks import (
PruningCallback,
OptimizerCallback,
CriterionCallback,
AccuracyCallback,
ControlFlowCallback
)
from catalyst.contrib.datasets import MNIST
from compressors.distillation.callbacks import MetricAggregationCallback
from compressors.distillation.callbacks import KLDivCallback
from compressors.models import MLP
from compressors.pruning.runners import PruneRunner
from compressors.utils.data import TorchvisionDatasetWrapper as Wrp
model = MLP(num_layers=3)
# load weights from an already trained model checkpoint
model.load_state_dict(torch.load("trained_model.pth"))
datasets = {
"train": Wrp(MNIST("./data", train=True, download=True, transform=ToTensor())),
"valid": Wrp(MNIST("./data", train=False, transform=ToTensor())),
}
loaders = {
dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
for dl_key, dataset in datasets.items()
}
optimizer = torch.optim.Adam(model.parameters())
runner = PruneRunner(num_sessions=10)
runner.train(
model=model,
loaders=loaders,
optimizer=optimizer,
criterion=torch.nn.CrossEntropyLoss(),
callbacks=[
PruningCallback(pruning_fn="l1_unstructured", amount=0.2, remove_reparametrization_on_stage_end=False),
OptimizerCallback(metric_key="loss"),
CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
AccuracyCallback(input_key="logits", target_key="targets"),
],
logdir="./pruned_model",
valid_loader="valid",
valid_metric="accuracy",
minimize_valid_metric=False,
)
```