Implement PartitionSHAP attribution method

Open sisird864 opened this issue 7 months ago • 0 comments

Motivation:

Captum currently provides KernelShap and ShapleyValueSampling, but lacks a fast hierarchical Shapley estimator.
PartitionSHAP (Lundberg & Erion 2021) aims to reduce the number of model evaluations from O(M · d) to O(M · log d) by recursively partitioning the feature set. Users working with large-parameter LLMs and computer-vision models are expected to see 10-100× faster attributions, ideally without sacrificing accuracy; the benchmarks in the test plan are meant to verify this.


Proposed API (mirrors Captum style):

from typing import Any, Callable, Optional

# PerturbationAttribution is defined in captum/attr/_utils/attribution.py
from captum.attr._utils.attribution import PerturbationAttribution


class PartitionShap(PerturbationAttribution):
    def __init__(
        self,
        forward_func: Callable,
        baselines: Optional[Any] = None,
        max_evals: int = 512,             # model-evaluation budget (see early-exit guard)
        perturbations_per_eval: int = 1,  # perturbed inputs batched per forward pass
        partition_agg: str = "mean",      # mean | sum | max
        cluster_method: str = "greedy",   # future-proof for alt heuristics
    ): ...

Returns → tensor of Shapley values with same shape as inputs.

Design outline:

  1. Hierarchy builder – greedy divisive split on absolute gradients (torch.topk for speed, stop at single-feature leaves).
  2. Conditional expectation – Monte-Carlo permutations per internal node, batched using perturbations_per_eval.
  3. Reuse Captum helpers:

    • _tensorize_baselines, _construct_perturbations
    • infer_target_device_dtype for GPU / dtype safety.
  4. Early-exit guard – stop once eval_count ≥ max_evals, emit UserWarning.
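
The outline above can be sketched in dependency-free Python. This is a minimal illustration under stated assumptions, not the proposed Captum implementation. The names build_hierarchy, leaves, and partition_attributions are hypothetical. Plain lists and sorted() stand in for tensors and torch.topk. Step 2 is swapped for a deterministic two-context (Owen-style) recursion rather than Monte-Carlo permutations. When the budget runs out, the remaining group's credit is shared evenly among its features, which is only one possible fallback.

```python
import warnings
from typing import List, Sequence


def build_hierarchy(scores: Sequence[float]):
    """Greedy divisive split: rank by |score|, halve until single-feature leaves."""
    order = sorted(range(len(scores)), key=lambda i: -abs(scores[i]))

    def split(idx):
        if len(idx) == 1:
            return idx[0]  # leaf: a single feature index
        mid = len(idx) // 2
        return (split(idx[:mid]), split(idx[mid:]))  # internal node: (left, right)

    return split(order)


def leaves(node) -> List[int]:
    """Feature indices contained in a (sub)tree."""
    if isinstance(node, int):
        return [node]
    return leaves(node[0]) + leaves(node[1])


def partition_attributions(f, x, baseline, tree, max_evals=512):
    """Distribute f(x) - f(baseline) over features by recursing down the tree."""
    phi = [0.0] * len(x)
    state = {"evals": 0, "warned": False}

    def evaluate(on):
        # Features in `on` take their input value, all others the baseline value.
        state["evals"] += 1
        return f([x[i] if i in on else baseline[i] for i in range(len(x))])

    def recurse(node, on, f_off, f_on, weight):
        # Invariant: f_off == f(on) and f_on == f(on plus every feature under node).
        if isinstance(node, int):
            phi[node] += weight * (f_on - f_off)
            return
        if state["evals"] >= max_evals:
            # Early-exit guard: warn once, then share this group's credit evenly.
            if not state["warned"]:
                warnings.warn("max_evals reached; sharing credit evenly", UserWarning)
                state["warned"] = True
            group = leaves(node)
            for i in group:
                phi[i] += weight * (f_on - f_off) / len(group)
            return
        left, right = frozenset(leaves(node[0])), frozenset(leaves(node[1]))
        f_left, f_right = evaluate(on | left), evaluate(on | right)
        # Each child is scored with its sibling absent and with its sibling present.
        recurse(node[0], on, f_off, f_left, weight / 2)
        recurse(node[0], on | right, f_right, f_on, weight / 2)
        recurse(node[1], on, f_off, f_right, weight / 2)
        recurse(node[1], on | left, f_left, f_on, weight / 2)

    everything = frozenset(range(len(x)))
    recurse(tree, frozenset(), evaluate(frozenset()), evaluate(everything), 1.0)
    return phi


# Toy linear model; its weights double as stand-in gradient scores.
weights = [2.0, -1.0, 0.5, 3.0, 0.0]
x = [1.0, 2.0, 3.0, 4.0, 5.0]
baseline = [0.0] * 5
model = lambda z: sum(w * v for w, v in zip(weights, z))

tree = build_hierarchy(weights)
phi = partition_attributions(model, x, baseline, tree)
```

In this sketch the attributions always sum to f(x) - f(baseline), including when the guard fires, because every split hands its parent's credit on in full. Note that the full recursion visits each child in two contexts, so its evaluation count grows faster than the tree depth; the max_evals guard is what bounds the cost here.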

Acceptance criteria / test plan

  • Correctness: PartitionSHAP ≈ KernelSHAP on a 5-feature linear model (torch.testing.assert_close(..., atol=1e-2, rtol=1e-2)).

  • Speed: On bert-base-uncased CLS token, PartitionSHAP at a 512-evaluation budget (max_evals=512) runs ≥20× faster than KernelSHAP with n_samples=2048 (CPU benchmark).

  • Coverage ≥99 % for new file (pytest --cov=...).

  • Docs: docs/source/algorithms.rst section.

  • Tutorial notebook: GPT-2 toxicity example (<3 min CPU runtime).
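
For the correctness criterion above, a ground-truth reference may be useful alongside the KernelSHAP comparison. For a linear model with a fixed baseline, the exact Shapley value of feature i is w_i · (x_i − baseline_i), so a 5-feature case can be checked by brute force. The sketch below is dependency-free; exact_shapley and the toy model are hypothetical names, and math.isclose stands in for torch.testing.assert_close with the same tolerances.

```python
from itertools import combinations
from math import factorial, isclose


def exact_shapley(f, x, baseline):
    """Exact Shapley values by enumerating every coalition (2**d evaluations per feature pair)."""
    d = len(x)

    def value(on):
        # Features in `on` take their input value, all others the baseline value.
        return f([x[i] if i in on else baseline[i] for i in range(d)])

    phi = []
    for i in range(d):
        others = [j for j in range(d) if j != i]
        total = 0.0
        for size in range(d):
            # Shapley weight for coalitions of this size: |S|! (d - |S| - 1)! / d!
            coeff = factorial(size) * factorial(d - size - 1) / factorial(d)
            for subset in combinations(others, size):
                s = set(subset)
                total += coeff * (value(s | {i}) - value(s))
        phi.append(total)
    return phi


weights = [2.0, -1.0, 0.5, 3.0, 0.0]
bias = 0.25
x = [1.0, 2.0, 3.0, 4.0, 5.0]
baseline = [0.5, 0.0, 1.0, -1.0, 2.0]
linear = lambda z: bias + sum(w * v for w, v in zip(weights, z))

reference = exact_shapley(linear, x, baseline)
closed_form = [w * (xi - bi) for w, xi, bi in zip(weights, x, baseline)]

# Same tolerances as the proposed assert_close(..., atol=1e-2, rtol=1e-2) check.
assert all(
    isclose(r, c, rel_tol=1e-2, abs_tol=1e-2) for r, c in zip(reference, closed_form)
)
```

Brute force is only practical for a handful of features, which is why it suits the 5-feature unit test but not the BERT benchmark.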

Task checklist:

  • Core implementation (captum/attr/_core/partition_shap.py)
  • Unit tests (tests/attr/test_partition_shap.py)
  • Benchmarks script (optional, skipped on CI)
  • API & math docs
  • Example notebook
  • Changelog entry

References:

G. Lundberg & G. Erion, “Partition SHAP: Explaining complex models via recursive feature partitioning,” 2021. (arXiv 2105.14814)

Assignees:

@sisird864: claiming this feature implementation. Maintainers: please add the appropriate labels (triaged, feature, help-wanted) and let me know of any design concerns before I begin. Thanks! Looking forward to contributing.

sisird864, Jul 13 '25 15:07