Bug: Performance bottleneck in Pixel-Flipping algorithm
The Pixel-Flipping metric appears to have a significant performance bottleneck: evaluation takes far longer than expected for a batch of only one element. I am not sure what the root cause is.
Expected behavior
Calculating the Pixel-Flipping metric with Quantus for 100 steps should finish within 2 minutes, which is roughly what my own implementation of Pixel-Flipping (see demo.ipynb) takes.
Current behavior
The calculation had been running for 45 minutes (still unfinished) when I interrupted it.
Reproduction steps
Prerequisites
- Download and unzip pr-attachments.zip, which contains the tensors X and R (as .pt files) used in the Minimal Working Example (MWE) below.
- Load the tensors X and R using torch.load.
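The load step can be sketched as follows; the filenames X.pt and R.pt are placeholders, since the actual names come from pr-attachments.zip (the save call below is only there to make the snippet self-contained):

```python
import torch

# Placeholder round trip; in practice you would only call torch.load on the
# .pt files extracted from pr-attachments.zip ("X.pt"/"R.pt" are assumed names).
X = torch.randn(1, 3, 224, 224)
torch.save(X, "X.pt")
X = torch.load("X.pt")
```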
MWE
import quantus
import torchvision
import numpy
import torch
from typing import Union, Dict
# Init required arguments
inputs: torch.Tensor = X.clone().detach()  # renamed to avoid shadowing the built-in input
x_batch: numpy.ndarray = inputs.numpy()
y_batch: numpy.ndarray = numpy.array([483])
a_batch: numpy.ndarray = R.clone().detach().numpy()
model = torchvision.models.vgg16(pretrained=True)
model.eval()
# Init metric
metric_params: Dict[str, Union[str, bool, int]] = {
    'perturb_baseline': 'uniform',
    'disable_warnings': True,
    'display_progressbar': True,
    'max_steps_per_input': 98,
}
metric: quantus.Metric = quantus.PixelFlipping(abs=True, normalise=False, **metric_params)
# Run Pixel-Flipping algorithm
call_params: Dict[str, bool] = {
'channel_first': True,
}
scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch, **call_params)
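To quantify the slowdown precisely, the metric call can be wrapped in a small timing helper; a sketch (the `sum(...)` workload below is only a stand-in for the `metric(...)` call above):

```python
import time

def timed(fn, *args, **kwargs):
    """Run fn and report wall-clock time; a stand-in for profiling the metric call."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    print(f"Elapsed: {elapsed:.2f}s")
    return result, elapsed

# Example with a dummy workload; replace with the metric(...) call above.
result, elapsed = timed(sum, range(1_000_000))
```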
Details
Both X and R (the relevance scores) are in NCHW format and have shape torch.Size([1, 3, 224, 224]).
Meaning of NCHW format:
- N: number of images in the batch
- C: number of channels of the image (3 for RGB, 1 for grayscale)
- H: height of the image
- W: width of the image
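As a quick sanity check, the layout can be unpacked directly from the tensor shape (using a dummy tensor of the same shape as X and R):

```python
import torch

# Dummy tensor with the same NCHW shape as X and R above
X = torch.randn(1, 3, 224, 224)
N, C, H, W = X.shape
print(N, C, H, W)  # 1 3 224 224
```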
Hi @rodrigobdz
I haven't experienced the issues you're having myself (especially with just 1 element), so I will look into this!
In the meantime, we're working on a bigger code update to enable batch processing for the different metrics, starting with PixelFlipping (which so far shows significant performance gains).
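As an illustration of the kind of batched update such a change enables (a rough NumPy sketch, not the Quantus implementation; the shapes, the top-k step size, and the zero baseline are assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
x_batch = rng.random((4, 3, 8, 8))  # N=4 images in NCHW format
a_batch = rng.random((4, 3, 8, 8))  # per-pixel relevance scores

# Rank pixels by relevance for all images at once instead of looping per image.
flat_a = a_batch.reshape(4, -1)
order = np.argsort(-flat_a, axis=1)  # most relevant pixel first

# Flip the top-k pixels of every image to a baseline value in one vectorized step.
k = 10
flat_x = x_batch.reshape(4, -1).copy()
rows = np.arange(4)[:, None]
flat_x[rows, order[:, :k]] = 0.0  # zero baseline is an assumption
x_perturbed = flat_x.reshape(x_batch.shape)
```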
Expected to be resolved in the PR here: https://github.com/understandable-machine-intelligence-lab/Quantus/pull/87
also see here: https://github.com/understandable-machine-intelligence-lab/Quantus/issues/80
For reference, I just released the source code for my Pixel-Flipping/Region Perturbation implementation, which I developed under Grégoire's supervision.
Repo: rodrigobdz/lrp
Fixed in the 0.3.0 update.