Quantus
Use batched processing instead of processing instance by instance
Currently, all of the metrics more or less follow this scheme:
```python
x: array
y: array
a: array

for x_instance, y_instance, a_instance in zip(x, y, a):
    for perturbation_step in range(perturbation_steps):
        x_perturbed = perturb_instance(x_instance, a_instance, perturbation_step)
        y_perturbed = model(x_perturbed)
        score = calculate_score_for_instance(y_instance, y_perturbed)
```
The choice of `perturb_instance` arguments is just for simplicity; the actual code is of course more complex than presented here.
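To make the current pattern concrete, here is a minimal runnable sketch of the per-instance scheme. The model, the top-k zeroing perturbation, and the score (drop in prediction) are hypothetical stand-ins, not Quantus's actual implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def model(batch):
    # Dummy batched model (hypothetical stand-in): one score per row.
    return batch.reshape(batch.shape[0], -1).sum(axis=-1)

def perturb_instance(x_instance, a_instance, step):
    """Zero out the (step + 1) most-attributed features of one instance."""
    idx = np.argsort(a_instance)[::-1][: step + 1]
    x_perturbed = x_instance.copy()
    x_perturbed[idx] = 0.0
    return x_perturbed

x = rng.normal(size=(4, 8))  # 4 instances, 8 features
y = model(x)
a = np.abs(x)                # toy attributions
perturbation_steps = 3

scores = []
for x_i, y_i, a_i in zip(x, y, a):
    for step in range(perturbation_steps):
        x_p = perturb_instance(x_i, a_i, step)
        y_p = model(x_p[None, :])[0]  # one model call per instance and step
        scores.append(y_i - y_p)      # e.g. drop in prediction
```

Note that the model is invoked once per instance per perturbation step, which is exactly the overhead the batched approach below removes.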
But this kind of implementation doesn't take advantage of batched model prediction and vectorized numpy functions. We could instead speed up the computation by an order of magnitude with the following approach:
```python
x: array
y: array
a: array
batch_size: int

generator = BatchGenerator(x, y, a, batch_size)
for x_batch, y_batch, a_batch in generator:
    for perturbation_step in range(perturbation_steps):
        x_batch_perturbed = perturb_batch(x_batch, a_batch, perturbation_step)
        y_batch_perturbed = model(x_batch_perturbed)
        score = calculate_score_for_batch(y_batch, y_batch_perturbed)
```
Some `perturb_batch` functions may still need an inner for-loop, but others could certainly be computed on the whole batch at once.
Depending on the dataset size and model complexity, this should lead to significant improvements in performance.
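A hedged sketch of the batched variant, assuming plain numpy slicing for the batch generator and a fully vectorized top-k zeroing perturbation (all names and the perturbation itself are illustrative assumptions, not Quantus's actual API):

```python
import numpy as np

rng = np.random.default_rng(0)

def model(batch):
    # Dummy batched model (hypothetical stand-in): one score per row.
    return batch.sum(axis=-1)

def batch_generator(x, y, a, batch_size):
    """Yield aligned (x, y, a) slices of at most batch_size rows."""
    for start in range(0, len(x), batch_size):
        s = slice(start, start + batch_size)
        yield x[s], y[s], a[s]

def perturb_batch(x_batch, a_batch, step):
    """Zero the (step + 1) most-attributed features of every row at once."""
    # One argsort per batch; top (step + 1) indices per row, no Python loop.
    top = np.argsort(a_batch, axis=-1)[:, ::-1][:, : step + 1]
    x_perturbed = x_batch.copy()
    np.put_along_axis(x_perturbed, top, 0.0, axis=-1)
    return x_perturbed

x = rng.normal(size=(8, 16))
y = model(x)
a = np.abs(x)  # toy attributions
perturbation_steps, batch_size = 3, 4

scores = []
for x_b, y_b, a_b in batch_generator(x, y, a, batch_size):
    for step in range(perturbation_steps):
        x_p = perturb_batch(x_b, a_b, step)
        y_p = model(x_p)          # one model call per batch, not per instance
        scores.append(y_b - y_p)  # score is vectorized over the batch
scores = np.concatenate(scores)
```

The number of model calls drops from `n_instances * perturbation_steps` to `n_batches * perturbation_steps`, and the perturbation itself runs as a handful of numpy operations over the whole batch.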