use batched processing instead of processing by instance

Open dkrako opened this issue 3 years ago • 0 comments

Currently all of the metrics are more or less structured by the following scheme:

x: array
y: array
a: array

for x_instance, y_instance, a_instance in zip(x, y, a):
    for perturbation_step in range(perturbation_steps):
        x_perturbed = perturb_instance(x_instance, a_instance, perturbation_step)
        y_perturbed = model(x_perturbed)
        score = calculate_score_for_instance(y_instance, y_perturbed)

The perturb_instance arguments are chosen just for simplicity; the actual code is of course more complex than presented.

But this kind of implementation doesn't exploit the performance benefits of batched model prediction and vectorized NumPy functions. We could speed up computations by an order of magnitude by using the following approach instead:

x: array
y: array
a: array
batch_size: int

generator = BatchGenerator(x, y, a, batch_size)
for x_batch, y_batch, a_batch in generator:
    for perturbation_step in range(perturbation_steps):
        x_batch_perturbed = perturb_batch(x_batch, a_batch, perturbation_step)
        y_batch_perturbed = model(x_batch_perturbed)
        score = calculate_score_for_batch(y_batch, y_batch_perturbed)
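
To make the idea concrete, here is a minimal sketch of what such a batch generator could look like. The name batch_generator and its signature are hypothetical and not part of Quantus; it only assumes that x, y and a are indexable along their first axis and hold the same number of instances.

```python
import numpy as np


def batch_generator(x, y, a, batch_size):
    """Yield aligned (x, y, a) slices of at most batch_size instances.

    Hypothetical helper for illustration, not an existing Quantus API.
    """
    if not (len(x) == len(y) == len(a)):
        raise ValueError("x, y and a must contain the same number of instances")
    for start in range(0, len(x), batch_size):
        stop = start + batch_size
        # Slicing past the end is safe: the last batch is simply smaller.
        yield x[start:stop], y[start:stop], a[start:stop]


# Toy data: 5 instances with 2 features each.
x = np.arange(10, dtype=float).reshape(5, 2)
y = np.arange(5)
a = np.ones_like(x)

batches = list(batch_generator(x, y, a, batch_size=2))
```

Iterating over the generator directly (rather than calling next on it inside the for statement) hands out one batch per loop iteration, and the last batch just holds the remaining instances.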

Some perturb_batch functions may still need an inner for-loop, but others can certainly be computed on the whole batch at once. Depending on the dataset size and model complexity, this should lead to significant improvements in performance.
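
As an illustration of a perturbation that vectorizes cleanly, the sketch below replaces the most highly attributed features with a baseline value, once per instance and once for a whole batch. Both functions are hypothetical stand-ins rather than Quantus code, and they assume flattened inputs of shape (n_instances, n_features) with a scalar baseline.

```python
import numpy as np


def perturb_instance(x_instance, a_instance, n_features, baseline=0.0):
    """Replace the n_features most highly attributed entries of one instance."""
    x_perturbed = x_instance.copy()
    top = np.argsort(-a_instance, kind="stable")[:n_features]
    x_perturbed[top] = baseline
    return x_perturbed


def perturb_batch(x_batch, a_batch, n_features, baseline=0.0):
    """Same perturbation applied to every row of the batch without a Python loop."""
    x_perturbed = x_batch.copy()
    # Column indices of the top-attributed features, one row of indices per instance.
    top = np.argsort(-a_batch, axis=1, kind="stable")[:, :n_features]
    np.put_along_axis(x_perturbed, top, baseline, axis=1)
    return x_perturbed


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 6))
a = rng.normal(size=(4, 6))

looped = np.stack([perturb_instance(xi, ai, 2) for xi, ai in zip(x, a)])
batched = perturb_batch(x, a, 2)
```

Both variants produce the same array, but the batched one does its sorting and assignment inside NumPy. Perturbations that depend on per-instance control flow would still need the inner loop mentioned above.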

dkrako, Mar 16 '22 13:03