Batched processing
This is a work in progress for implementing the following issue: use batched processing instead of processing by instance #80
Given a sufficient number of data samples to process, batch processing will most likely improve throughput and performance (always depending on the system). However, batching the data along the sample axis has its limits in terms of flexibility. Consider the case where a user wants to perform a perturbation analysis in which single pixels are replaced and each perturbed sample is evaluated by the model, but only for a single data point. In this case, no performance benefit applies: parallelizing the perturbation over the available samples still results in a sequential evaluation of that one sample's perturbed states.
I therefore propose the following approach, which does not create data batches from the samples to be evaluated, but from "work packages" resulting from all combinations of data samples times perturbation steps that need to be evaluated. This way, the perturbations can be computed and prepared before the perturbed samples are passed through the network, in almost arbitrary order (similar to how a data generator for NN training operates), and the delta f(x) for each evaluated instance can be collected easily. Even when single samples are perturbed numerous times, those perturbed states can be packaged into neural network input batches and evaluated more efficiently. That is, I think it can safely be assumed that the input-level perturbations, even though the individual perturbation steps usually are not independent of each other (in most cases we want to perform successive data perturbation, building upon the previously perturbed sample state), will be far less computationally expensive than a pass through the network.
Consider the following pseudo code:
```python
# net : a neural network function
# X   : data samples
# A   : attributions corresponding to X
# k   : nn_input_batch_size

# Use a worker pool to smartly prepare the perturbed samples.
# create_workers(n) creates n workers in an e.g. CPU-bound worker pool
# (the user should specify n here).
with create_workers(n) as workers:
    # pseudo context to make sure workers are eliminated once they are done working
    for next_perturbed_samples in generate_perturbations(workers, X, A, k):
        pred = net.forward(next_perturbed_samples)
        # [...] do stuff with pred
```
where generate_perturbations should behave similarly to the function below, but be implemented more intelligently so as not to occupy RAM unnecessarily via all_generated_samples. generate_perturbations could be treated as a generator that is able to generate newly perturbed samples on the fly, in the most atomic quantities. The workers should also not be bound to the context of a single sample x with attribution a, but share the workload over X and A.
```python
def generate_perturbations(workers, X, A, k):
    all_generated_samples = []
    for x, a in zip(X, A):  # each x in X with its corresponding a in A
        # all perturbed states of x, prepared by the worker pool
        all_xt = workers.create_all_xt(x, a)
        all_generated_samples.extend(all_xt)
    # yield k-sized neural network input batches
    for i in range(0, len(all_generated_samples), k):
        yield all_generated_samples[i:i + k]
```
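Since each such k-sized batch can mix perturbation states of different samples and steps, some bookkeeping is needed to collect the delta f(x) per instance mentioned above. A minimal sketch of that bookkeeping (generate_indexed_batches, create_all_xt and scores are hypothetical names for illustration, not existing Quantus code):

```python
import numpy as np

def generate_indexed_batches(workers, X, A, k):
    """Like generate_perturbations, but also yields (sample_idx, step_idx) per batch element."""
    states, indices = [], []
    for sample_idx, (x, a) in enumerate(zip(X, A)):
        for step_idx, x_t in enumerate(workers.create_all_xt(x, a)):
            states.append(x_t)
            indices.append((sample_idx, step_idx))
    for i in range(0, len(states), k):
        yield indices[i:i + k], np.stack(states[i:i + k])

# scores[s, t] then holds the model output for sample s at perturbation step t:
# scores = np.empty((len(X), n_steps))
# for batch_indices, batch in generate_indexed_batches(workers, X, A, k):
#     pred = net.forward(batch)
#     for (s, t), p in zip(batch_indices, pred):
#         scores[s, t] = p
```

The delta f(x) for sample s at step t is then simply the difference between scores[s, t] and the unperturbed prediction of sample s.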
The basic idea is to use generate_perturbations to prepare k-sized batches, using n workers, from X according to A. Since a large number of combinations of data points times perturbation states will be generated overall, generate_perturbations could internally work with a buffer of a small number of pre-computed batches to ensure a good and fluid data transfer between the worker output and the neural network. Further, generate_perturbations must not be bound to parallelizing over the data axis of X, which would again limit applicability in the single-sample case, but should just return the next k perturbed states from the data points in X, i.e. also parallelize over the perturbation-state axis.
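As a rough illustration of that buffering idea, a lazy variant could be built on a standard worker pool: each worker computes the perturbation states of one sample (the successive steps within a sample stay sequential, since they build on the previous state), while the consumer flattens the incoming states into k-sized batches independent of the sample axis. This is only a sketch under the assumption of a hypothetical, picklable create_all_xt(x, a) perturbation routine, not a concrete Quantus implementation:

```python
import itertools
from collections import deque
from concurrent.futures import ProcessPoolExecutor

import numpy as np


def generate_perturbations_lazy(X, A, k, create_all_xt, n_workers=4, buffer_size=8):
    """Lazily yield k-sized batches of perturbed states prepared by a worker pool.

    At most buffer_size samples have their perturbation states in flight
    ahead of the consumer, so RAM usage stays bounded.
    """
    samples = iter(zip(X, A))
    pending = deque()

    def flat_states(pool):
        # Fill the buffer with a few per-sample jobs.
        for _ in range(buffer_size):
            try:
                x, a = next(samples)
            except StopIteration:
                break
            pending.append(pool.submit(create_all_xt, x, a))
        while pending:
            # Stream out the perturbed states of the oldest job ...
            yield from pending.popleft().result()
            # ... and refill the buffer with the next sample, if any.
            try:
                x, a = next(samples)
                pending.append(pool.submit(create_all_xt, x, a))
            except StopIteration:
                pass

    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        states = flat_states(pool)
        while True:
            batch = list(itertools.islice(states, k))
            if not batch:
                break
            yield np.stack(batch)
```

Used like the pseudo code above (for batch in generate_perturbations_lazy(X, A, k, create_all_xt): pred = net.forward(batch)), the batches are cut along the perturbation-state axis rather than the sample axis, so the forward passes stay batched even in the single-sample case.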
Bump, pseudocode update: @dkrako, @annahedstroem and @leanderweber
Just noting this so we remember: PR #88 contains some indexing functions which are tailored to the current non-batching implementation of metrics in the main branch. These functions would also need to be (slightly) altered for batched processing. I have also left some TODOs regarding this in the code.
An updated version of this work can be found in: https://github.com/understandable-machine-intelligence-lab/Quantus/pull/168.