
Add function-specific keyword-arguments to metric functions.

Open · dkrako opened this pull request 3 years ago · 3 comments

This PR fixes the main issue from #82

This is necessary because arguments passed via the unspecific kwargs argument are otherwise forwarded to several functions at once, for example to both perturb_func() and model.predict().

This commit therefore introduces normalise_func_kwargs, model_predict_kwargs and perturb_func_kwargs. For the current version the actual behaviour does not change: passing unspecific kwargs still works as before, i.e. they are passed on to normalise_func and perturb_func. This behaviour should be deprecated and eventually removed in future versions.
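For illustration, a minimal sketch of how a metric call could look with the new arguments, assuming a model and data batches (model, x_batch, y_batch) already exist; the metric name and the dictionary contents are placeholders rather than the final API:

```python
import quantus

# Placeholder metric and values; each kwargs dict is routed to exactly one target.
metric = quantus.PixelFlipping(
    normalise=True,
    normalise_func_kwargs={},                           # forwarded to normalise_func only
    perturb_func_kwargs={"perturb_baseline": "black"},  # forwarded to perturb_func only
)

scores = metric(
    model=model,
    x_batch=x_batch,
    y_batch=y_batch,
    a_batch=None,
    model_predict_kwargs={},                            # forwarded to model.predict() only
)
```

Compared to a single kwargs dict, the routing of each argument is explicit, so a key can no longer end up in both perturb_func() and model.predict() by accident.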

For this feature, the base.Metric class was extended to contain all the general attributes and preprocessing that was previously duplicated in each specific metric. Most of this is now handled by the prepare() method of the base class. A base.PerturbationMetric class has also been added, which handles the perturbation-related attributes. This considerably improves maintainability and readability, as many issues can now be fixed in the base class alone, without having to go through all of the metrics.
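A condensed sketch of this class layout (not the actual Quantus code; attribute names beyond those mentioned above are placeholders):

```python
from typing import Callable, Dict, Optional


class Metric:
    def __init__(
        self,
        normalise_func: Optional[Callable] = None,
        normalise_func_kwargs: Optional[Dict] = None,
        **kwargs,
    ):
        # General attributes shared by all metrics live here once,
        # instead of being repeated in every specific metric.
        self.normalise_func = normalise_func
        self.normalise_func_kwargs = normalise_func_kwargs or {}
        self.last_results = []

    def prepare(self, model, x_batch, y_batch, a_batch=None, **kwargs):
        # Shared preprocessing (wrapping the model, generating missing
        # attributions, asserts, ...) that every metric needs before scoring.
        ...


class PerturbationMetric(Metric):
    def __init__(
        self,
        perturb_func: Optional[Callable] = None,
        perturb_func_kwargs: Optional[Dict] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Perturbation-specific attributes are handled in this subclass.
        self.perturb_func = perturb_func
        self.perturb_func_kwargs = perturb_func_kwargs or {}
```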

Each metric now has to implement a process_instance() method, in which a single instance of the input batch is processed and its score is returned. The base class then collects the scores in the last_results attribute. Additionally, each metric can implement preprocess() and postprocess() methods. If the structure of the base class does not fit a metric's design, it can still be implemented the old way, as seen in ModelParameterRandomisation, which nevertheless still uses the handy prepare() method from the base class.
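Continuing the sketch above, a toy metric could plug into this structure roughly like this (the scoring logic is made up purely for illustration):

```python
import numpy as np


class MyToyMetric(PerturbationMetric):
    def preprocess(self, model, x_batch, y_batch, a_batch=None, **kwargs):
        # Optional hook: metric-specific setup before instances are processed.
        pass

    def process_instance(self, model, x_instance, y_instance, a_instance, **kwargs):
        # Score a single instance; the base class collects the returned
        # values into self.last_results.
        a = a_instance
        if self.normalise_func is not None:
            a = self.normalise_func(a, **self.normalise_func_kwargs)
        return float(np.sum(a))

    def postprocess(self, **kwargs):
        # Optional hook: aggregate or reshape self.last_results.
        pass
```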

dkrako · Jun 13 '22 16:06

Fantastic @dkrako. More in-depth review to come.

In the meantime, pasting some comments for later review:

  • Asserts, deprecation warnings
  • Check that progress bar works
  • How about device?
  • Dev safety mechanisms:
    • Kwargs append… if key is overwritten: write a print (could add later, now add a print); see the sketch after this list
    • Check if defaults are used?
  • Naming: call it custom_preprocess
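
Regarding the overwritten-kwargs print, a hypothetical sketch of such a safety mechanism (function name and message are made up for illustration; nothing like this exists in the PR yet):

```python
def merge_kwargs(base: dict, extra: dict) -> dict:
    """Merge two kwargs dicts, printing a notice whenever a key is overwritten."""
    merged = dict(base)
    for key, value in extra.items():
        if key in merged and merged[key] != value:
            print(f"Warning: kwargs key '{key}' is overwritten: {merged[key]!r} -> {value!r}")
        merged[key] = value
    return merged
```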

annahedstroem · Jun 19 '22 17:06

As discussed, I have left **kwargs in each metric method, but now raise a ValueError with an informative message if any unexpected keyword arguments are passed to the base methods __init__() and __call__().

See here: https://github.com/dkrako/Quantus/blob/30b656b2620ae7517f946d178d8636571000d12d/quantus/helpers/asserts.py#L253
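A minimal sketch of what such a check might look like (the actual helper lives in quantus/helpers/asserts.py at the link above and may differ in name and wording):

```python
def check_kwargs(kwargs: dict) -> None:
    """Raise a ValueError if unexpected keyword arguments were passed."""
    if kwargs:
        raise ValueError(
            f"Unexpected keyword arguments: {list(kwargs.keys())}. "
            "Please use the function-specific arguments instead, e.g. "
            "normalise_func_kwargs, perturb_func_kwargs or model_predict_kwargs."
        )
```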

dkrako · Jul 21 '22 10:07

I have updated test_axiomatic_metrics.py to propose the change we will need to apply to every metric test. https://github.com/dkrako/Quantus/blob/feature/function-specific-kwargs/tests/metrics/test_axiomatic_metrics.py

I have nested the params dictionary by adding an init and a call key, under which the arguments for the respective methods are stored. If a_batch_generate == True, then explain_func and explain_func_kwargs will be read from the call params dict.
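A sketch of what this nesting could look like, assuming test fixtures (model, x_batch, y_batch) exist; the metric and the concrete keys and values are illustrative placeholders, not the actual test content:

```python
import quantus

params = {
    "a_batch_generate": True,
    "init": {
        # keyword arguments for the metric's __init__()
        "normalise": True,
        "disable_warnings": True,
    },
    "call": {
        # keyword arguments for the metric's __call__()
        "explain_func": quantus.explain,
        "explain_func_kwargs": {"method": "Saliency"},
    },
}

metric = quantus.Completeness(**params["init"])

# When a_batch_generate is True, no precomputed attributions are passed and
# explain_func / explain_func_kwargs are read from the "call" params instead.
a_batch = None if params["a_batch_generate"] else precomputed_a_batch
scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch, **params["call"])
```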

@annahedstroem @leanderweber Please leave me feedback here if this would be the appropriate change that we should apply for the remaining test cases.

dkrako · Jul 21 '22 12:07

Codecov Report

Merging #124 (cb3a0bb) into main (6f7ced1) will decrease coverage by 0.68%. The diff coverage is 96.02%.

@@            Coverage Diff             @@
##             main     #124      +/-   ##
==========================================
- Coverage   95.48%   94.80%   -0.69%     
==========================================
  Files          18       51      +33     
  Lines        3037     2407     -630     
==========================================
- Hits         2900     2282     -618     
+ Misses        137      125      -12     
| Impacted Files | Coverage Δ |
|---|---|
| quantus/helpers/discretise_func.py | 100.00% <ø> (ø) |
| quantus/helpers/norm_func.py | 100.00% <ø> (ø) |
| quantus/helpers/perturb_func.py | 96.00% <ø> (-0.04%) ↓ |
| quantus/helpers/similarity_func.py | 94.59% <ø> (ø) |
| ...ics/randomisation/model_parameter_randomisation.py | 82.60% <82.60%> (ø) |
| quantus/helpers/explanation_func.py | 90.00% <83.33%> (-1.09%) ↓ |
| quantus/helpers/loss_func.py | 83.33% <83.33%> (ø) |
| quantus/evaluation.py | 84.21% <85.71%> (-7.46%) ↓ |
| quantus/helpers/pytorch_model.py | 93.61% <90.00%> (+1.78%) ↑ |
| ...us/metrics/localisation/relevance_rank_accuracy.py | 90.90% <90.90%> (ø) |

... and 41 more


codecov-commenter · Sep 23 '22 16:09