Add function-specific keyword arguments to metric functions.
This PR fixes the main issue from #82: arguments passed via the unspecific **kwargs are forwarded to several callables at once, for example to both perturb_func() and model.predict().
This PR therefore introduces normalise_func_kwargs, model_predict_kwargs and perturb_func_kwargs. For the current version the actual behaviour does not change: passing unspecific kwargs still works as before, i.e. they are passed on to normalise_func and perturb_func. This should be deprecated and eventually removed in future versions.
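For illustration, a minimal before/after sketch of the kwargs routing (the function names and signatures below are made up for the example, not the actual Quantus API):

```python
# Sketch of the kwargs routing this PR addresses (all names illustrative).

# Before: one shared **kwargs is forwarded to several callables, so an
# argument meant for perturb_func() also reaches model.predict().
def evaluate_before(model, x, a, perturb_func, normalise_func, **kwargs):
    a = normalise_func(a, **kwargs)
    x_perturbed = perturb_func(x, **kwargs)       # same kwargs ...
    return model.predict(x_perturbed, **kwargs)   # ... reused here too


# After: each callable receives its own explicit kwargs dict.
def evaluate_after(
    model, x, a, perturb_func, normalise_func,
    normalise_func_kwargs=None,
    perturb_func_kwargs=None,
    model_predict_kwargs=None,
):
    a = normalise_func(a, **(normalise_func_kwargs or {}))
    x_perturbed = perturb_func(x, **(perturb_func_kwargs or {}))
    return model.predict(x_perturbed, **(model_predict_kwargs or {}))
```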
For this feature, the base.Metric class was extended to contain all the general attributes and preprocessing that were previously duplicated in every specific metric. Most of this is now handled by the base class's prepare() method. A base.PerturbationMetric class was also added, which handles the perturbation-related attributes. This considerably improves maintainability and readability, as many fixes can now be made in the base class alone, without having to touch every metric.
Each metric now has to implement a process_instance() method, which processes a single instance of the input batch and simply returns its score. The base class then writes the scores into the last_results attribute. Additionally, each metric can implement preprocess() and postprocess() methods. If the structure of the base class doesn't fit a metric's design, it can still be implemented the old way, as seen in ModelParameterRandomisation, which nevertheless still uses the handy prepare() method from the base class.
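To make the design concrete, here is a condensed sketch of the structure described above (heavily simplified; all signatures and the toy metric are illustrative, not the actual implementation):

```python
from abc import ABC, abstractmethod

import numpy as np


class Metric(ABC):
    """Condensed sketch of the extended base class; heavily simplified."""

    def __init__(self, normalise=False, normalise_func=None,
                 normalise_func_kwargs=None):
        self.normalise = normalise
        self.normalise_func = normalise_func
        self.normalise_func_kwargs = normalise_func_kwargs or {}
        self.last_results = []

    def prepare(self, model, x_batch, y_batch, a_batch):
        # Shared preprocessing that previously lived in every metric,
        # e.g. normalising the attributions.
        if self.normalise and self.normalise_func is not None:
            a_batch = self.normalise_func(a_batch, **self.normalise_func_kwargs)
        return model, x_batch, y_batch, a_batch

    def __call__(self, model, x_batch, y_batch, a_batch):
        model, x_batch, y_batch, a_batch = self.prepare(
            model, x_batch, y_batch, a_batch
        )
        self.preprocess(x_batch, y_batch, a_batch)  # optional hook
        # The base class collects the per-instance scores into last_results.
        self.last_results = [
            self.process_instance(model, x, y, a)
            for x, y, a in zip(x_batch, y_batch, a_batch)
        ]
        self.postprocess()  # optional hook
        return self.last_results

    @abstractmethod
    def process_instance(self, model, x, y, a):
        """Score a single instance of the input batch."""

    def preprocess(self, x_batch, y_batch, a_batch):
        pass  # optional metric-specific hook

    def postprocess(self):
        pass  # optional metric-specific hook


class PerturbationMetric(Metric):
    """Adds the perturbation-related attributes on top of Metric."""

    def __init__(self, perturb_func=None, perturb_func_kwargs=None, **kwargs):
        super().__init__(**kwargs)
        self.perturb_func = perturb_func
        self.perturb_func_kwargs = perturb_func_kwargs or {}


class ToyMetric(Metric):
    """Toy metric: only the per-instance scoring is metric-specific."""

    def process_instance(self, model, x, y, a):
        return float(np.mean(np.abs(a)))  # placeholder score
```

The point is that the batch loop, score collection and shared preprocessing live in one place, so a concrete metric only has to supply process_instance().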
Fantastic @dkrako. More in-depth review to come.
In the meantime, pasting some comments for later review:
- Asserts, deprecation warnings
- Check that the progress bar works
- How about device?
- Dev safety mechanisms:
  - Kwargs append… if a key is overwritten: write a print (could add later, for now add a print)
  - Check if defaults are used?
- Naming: call it custom_preprocess
As discussed, I have kept the **kwargs in each metric method, but now raise a ValueError with an informative message if any unexpected keyword arguments are passed to the base methods __init__() and __call__().
See here: https://github.com/dkrako/Quantus/blob/30b656b2620ae7517f946d178d8636571000d12d/quantus/helpers/asserts.py#L253
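For reference, a simplified sketch of what that check does (the helper name and signature here are hypothetical; the linked asserts.py contains the actual implementation):

```python
def assert_no_unexpected_kwargs(kwargs: dict, expected: set, method: str) -> None:
    """Sketch: **kwargs is still accepted, but anything unexpected fails
    loudly instead of being silently ignored."""
    unexpected = set(kwargs) - expected
    if unexpected:
        raise ValueError(
            f"Unexpected keyword arguments passed to {method}: "
            f"{sorted(unexpected)}. Pass function-specific arguments via "
            f"normalise_func_kwargs, perturb_func_kwargs or "
            f"model_predict_kwargs instead."
        )
```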
I have updated test_axiomatic_metrics.py to propose the change we will need to apply to every metric test.
https://github.com/dkrako/Quantus/blob/feature/function-specific-kwargs/tests/metrics/test_axiomatic_metrics.py
I have nested the params dictionary by adding "init" and "call" keys, under which the arguments for the respective methods are stored.
If a_batch_generate == True, then explain_func and explain_func_kwargs will be read from the "call" params dict, as sketched below.
@annahedstroem @leanderweber Please leave feedback here on whether this is the appropriate change to apply to the remaining test cases.
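For illustration, a sketch of the nested structure (parameter values are made up; the linked test file contains the real version):

```python
import pytest

# Nested params: "init" holds constructor arguments, "call" holds
# __call__ arguments. Values here are purely illustrative.
@pytest.mark.parametrize(
    "params",
    [
        {
            "init": {
                "normalise": True,
                "perturb_func_kwargs": {"perturb_baseline": "black"},
            },
            "call": {
                "a_batch_generate": True,
                "explain_func_kwargs": {"method": "Saliency"},
            },
        },
    ],
)
def test_metric_sketch(params):
    init_kwargs = params["init"]
    call_kwargs = params["call"]
    # If a_batch_generate is True, explain_func / explain_func_kwargs are
    # read from call_kwargs to generate a_batch before invoking the metric.
    ...
```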
Codecov Report
Merging #124 (cb3a0bb) into main (6f7ced1) will decrease coverage by 0.68%. The diff coverage is 96.02%.
```
@@            Coverage Diff             @@
##             main     #124      +/-   ##
==========================================
- Coverage   95.48%   94.80%   -0.69%
==========================================
  Files          18       51      +33
  Lines        3037     2407     -630
==========================================
- Hits         2900     2282     -618
+ Misses        137      125      -12
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| quantus/helpers/discretise_func.py | 100.00% <ø> (ø) | |
| quantus/helpers/norm_func.py | 100.00% <ø> (ø) | |
| quantus/helpers/perturb_func.py | 96.00% <ø> (-0.04%) | :arrow_down: |
| quantus/helpers/similarity_func.py | 94.59% <ø> (ø) | |
| ...ics/randomisation/model_parameter_randomisation.py | 82.60% <82.60%> (ø) | |
| quantus/helpers/explanation_func.py | 90.00% <83.33%> (-1.09%) | :arrow_down: |
| quantus/helpers/loss_func.py | 83.33% <83.33%> (ø) | |
| quantus/evaluation.py | 84.21% <85.71%> (-7.46%) | :arrow_down: |
| quantus/helpers/pytorch_model.py | 93.61% <90.00%> (+1.78%) | :arrow_up: |
| ...us/metrics/localisation/relevance_rank_accuracy.py | 90.90% <90.90%> (ø) | |
| ... and 41 more | | |