
Metrics

Complexity

Bases: Metric

Computes the complexity of attributions.

Given attributions, calculates a fractional contribution distribution prob_mass, prob_mass[i] = hist[i] / sum(hist), where hist[i] = histogram(attributions[i]).

The complexity is defined by the entropy, evaluation = -sum(prob_mass * ln(prob_mass)).
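A minimal sketch of this computation, independent of the pnpxai API, with a random array standing in for a real attribution map:

import numpy as np
from scipy.stats import entropy

# Stand-in attribution map for a single sample; values are illustrative only.
attribution = np.random.randn(224, 224)

# Fractional contribution distribution: histogram of attribution values,
# normalized into a probability mass.
hist, _ = np.histogram(attribution, bins=10)
prob_mass = hist / hist.sum()

# Complexity is the entropy of the probability mass.
print(entropy(prob_mass))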

Parameters:

    model (Model): The model used for evaluation. Required.
    explainer (Optional[Explainer]): The explainer used for evaluation. Default: None.
    n_bins (int): The number of bins for histogram computation. Default: 10.
Reference

U. Bhatt, A. Weller, and J. M. F. Moura. Evaluating and aggregating feature-based model attributions. In Proceedings of the IJCAI (2020).
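A hedged usage sketch; the import path and the model are assumptions, and the random attributions stand in for an explainer's output:

import torch
from torchvision.models import resnet18
from pnpxai.evaluator.metrics import Complexity  # assumed import path

model = resnet18(weights=None).eval()
metric = Complexity(model, n_bins=10)

# Precomputed attributions shaped (N, C, H, W); 4D attributions are converted
# to grayscale internally before binning.
attributions = torch.rand(4, 3, 224, 224)
scores = metric.evaluate(attributions=attributions)
print(scores)  # one entropy value per sample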

Source code in pnpxai/evaluator/metrics/complexity.py
class Complexity(Metric):
    """
    Computes the complexity of attributions.

    Given `attributions`, calculates a fractional contribution distribution `prob_mass`,
    ``prob_mass[i] = hist[i] / sum(hist)``, where ``hist[i] = histogram(attributions[i])``.

    The complexity is defined by the entropy,
    ``evaluation = -sum(prob_mass * ln(prob_mass))``


    Args:
        model (Model): The model used for evaluation
        explainer (Optional[Explainer]): The explainer used for evaluation.
        n_bins (int): The number of bins for histogram computation.

    Reference:
        U. Bhatt, A. Weller, and J. M. F. Moura. Evaluating and aggregating feature-based model attributions. In Proceedings of the IJCAI (2020).
    """
    def __init__(
        self,
        model: Model,
        explainer: Optional[Explainer]=None,
        n_bins: int = 10
    ):
        super().__init__(model, explainer)
        self.n_bins = n_bins

    def evaluate(
        self,
        inputs: Optional[torch.Tensor] = None,
        targets: Optional[torch.Tensor] = None,
        attributions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Evaluate the complexity of the explainer's attributions based on their probability masses.

        Args:
            inputs (Optional[Tensor]): The input tensors to the model.
            targets (Optional[Tensor]): The target labels for the inputs.
            attributions (Optional[Tensor]): The attributions for the inputs.

        Returns:
            Tensor: A tensor of the complexity evaluations.
        """
        assert attributions.ndim in [3, 4], "Must have 2D or 3D attributions"
        if attributions.ndim == 4:
            attributions = rgb_to_grayscale(attributions)
        evaluations = []
        for attr in attributions:
            hist, _ = np.histogram(attr.detach().cpu(), bins=self.n_bins)
            prob_mass = hist / hist.sum()
            evaluations.append(entropy(prob_mass))
        return torch.tensor(evaluations)

evaluate(inputs=None, targets=None, attributions=None)

Evaluate the complexity of the explainer's attributions based on their probability masses.

Parameters:

    inputs (Optional[Tensor]): The input tensors to the model. Default: None.
    targets (Optional[Tensor]): The target labels for the inputs. Default: None.
    attributions (Optional[Tensor]): The attributions for the inputs. Default: None.

Returns:

    Tensor: A tensor of the complexity evaluations.

Source code in pnpxai/evaluator/metrics/complexity.py
def evaluate(
    self,
    inputs: Optional[torch.Tensor] = None,
    targets: Optional[torch.Tensor] = None,
    attributions: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Evaluate the complexity of the explainer's attributions based on their probability masses.

    Args:
        inputs (Optional[Tensor]): The input tensors to the model.
        targets (Optional[Tensor]): The target labels for the inputs.
        attributions (Optional[Tensor]): The attributions for the inputs.

    Returns:
        Tensor: A tensor of the complexity evaluations.
    """
    assert attributions.ndim in [3, 4], "Must have 2D or 3D attributions"
    if attributions.ndim == 4:
        attributions = rgb_to_grayscale(attributions)
    evaluations = []
    for attr in attributions:
        hist, _ = np.histogram(attr.detach().cpu(), bins=self.n_bins)
        prob_mass = hist / hist.sum()
        evaluations.append(entropy(prob_mass))
    return torch.tensor(evaluations)

MuFidelity

Bases: Metric

Computes the MuFidelity metric for attributions.

Given a model and inputs, the MuFidelity of the model with respect to an explainer at inputs is calculated as the correlation between the difference of predictions and the attributions of the masked inputs, evaluation = corr(pred_diff, masked_attr).

The masked inputs are generated by applying subset_mask to the noised inputs, masked = perturbed * subset_mask + (1.0 - subset_mask) * baseline.
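The masking step can be sketched as follows for a single input; this is a simplified, per-pixel illustration of the formula above, whereas the implementation draws the subset mask on a grid_size x grid_size grid and upsamples it:

import torch

inp = torch.rand(3, 224, 224)  # one input; values are illustrative
baseline = 0.0
noise_scale = 0.2

# Noised copy of the input (Gaussian noise scaled by noise_scale).
perturbed = inp + noise_scale * torch.randn_like(inp)

# Binary subset mask selecting the features to keep; masked-out features
# are replaced by the baseline value.
subset_mask = (torch.rand_like(inp) > 0.5).float()
masked = perturbed * subset_mask + (1.0 - subset_mask) * baseline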

Parameters:

    model (Model): The model to evaluate. Required.
    explainer (Optional[Explainer]): The explainer to evaluate. Default: None.
    n_perturb (int): Number of perturbations to generate. Default: 150.
    noise_scale (float): Scale factor for Gaussian random noise. Default: 0.2.
    batch_size (int): Batch size for model evaluation. Default: 32.
    grid_size (int): Size of the grid for creating subsets. Default: 9.
    baseline (Union[float, Tensor]): Baseline value for masked subsets. Default: 0.0.
    mask_agg_dim (Optional[int]): Dimension to aggregate masks. Default: None.
    **kwargs: Additional kwargs to compute the metric in an evaluator. Not required for single usage.
Reference

U. Bhatt, A. Weller, and J. M. F. Moura. Evaluating and aggregating feature-based model attributions. In Proceedings of the IJCAI (2020).
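A hedged usage sketch; the import path, model, and attributions are assumptions, and the perturbation counts are kept smaller than the defaults to keep the sketch fast:

import torch
from torchvision.models import resnet18
from pnpxai.evaluator.metrics import MuFidelity  # assumed import path

model = resnet18(weights=None).eval()
metric = MuFidelity(model, n_perturb=30, batch_size=16, grid_size=9)

inputs = torch.rand(2, 3, 224, 224)
targets = torch.tensor([0, 1])
attributions = torch.rand(2, 3, 224, 224)  # stand-in for an explainer's output
scores = metric.evaluate(inputs, targets, attributions)
print(scores)  # one Spearman correlation per sample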

Source code in pnpxai/evaluator/metrics/mu_fidelity.py
class MuFidelity(Metric):
    """
    Computes the MuFidelity metric for attributions.

    Given a `model` and `inputs`, the MuFidelity of the `model` with respect to an explainer at `inputs`
    is calculated as the correlation between the difference of predictions and the attributions of the
    masked inputs, ``evaluation = corr(pred_diff, masked_attr)``.

    The masked inputs are generated by applying `subset_mask` to the noised `inputs`,
    ``masked = perturbed * subset_mask + (1.0 - subset_mask) * baseline``

    Args:
        model (Model): The model to evaluate.
        explainer (Optional[Explainer]): The explainer to evaluate.
        n_perturb (int): Number of perturbations to generate.
        noise_scale (float): Scale factor for Gaussian random noise.
        batch_size (int): Batch size for model evaluation.
        grid_size (int): Size of the grid for creating subsets.
        baseline (Union[float, torch.Tensor]): Baseline value for masked subsets.
        mask_agg_dim (Optional[int]): Dimension to aggregate masks.
        **kwargs: Additional kwargs to compute metric in an evaluator. Not required for single usage.

    Reference:
        U. Bhatt, A. Weller, and J. M. F. Moura. Evaluating and aggregating feature-based model attributions. In Proceedings of the IJCAI (2020).
    """

    def __init__(
        self,
        model: Model,
        explainer: Optional[Explainer] = None,
        n_perturb: int = 150,
        noise_scale: float = 0.2,
        batch_size: int = 32,
        grid_size: int = 9,
        baseline: float = 0.,
        mask_agg_dim: Optional[int] = None,
    ):
        super().__init__(model, explainer)
        self.n_perturb = n_perturb
        self.noise_scale = noise_scale
        self.batch_size = batch_size
        self.grid_size = grid_size
        self.baseline = baseline
        self.mask_agg_dim = mask_agg_dim

    def evaluate(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        attributions: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            inputs (torch.Tensor): The input data (N x C x H x W).
            targets (torch.Tensor): The target labels for the inputs (N x 1).
            attributions (torch.Tensor): The attributions of the inputs.

        Returns:
            torch.Tensor: The result of the metric evaluation.
        """
        attributions = attributions.to(self.device)

        outputs = self.model(inputs)
        n_classes = outputs.shape[-1]
        predictions = (
            outputs * torch.eye(n_classes).to(self.device)[targets]
        ).sum(dim=-1).detach()

        # input, target, attr, pred
        evaluations = []
        zipped = zip(inputs, targets, attributions, predictions)
        for input, target, attr, pred in zipped:
            repeated = torch.stack([input]*self.n_perturb)
            # Add Gaussian random noise
            std, mean = torch.std_mean(repeated)
            noise = torch.randn_like(repeated).to(self.device) * std + mean
            perturbed = self.noise_scale * noise + repeated
            perturbed = torch.minimum(repeated, perturbed)
            perturbed = torch.maximum(repeated-1, perturbed)

            # prepare the random masks that will designate the modified subset (S in original equation)
            subset_size = int(self.grid_size ** 2 * self.noise_scale)
            subset_mask = torch.randn(
                (self.n_perturb, self.grid_size ** 2)).to(self.device)
            subset_mask = torch.argsort(subset_mask, dim=-1) > subset_size
            subset_mask = torch.reshape(subset_mask.type(
                torch.float32), (self.n_perturb, 1, self.grid_size, self.grid_size))
            subset_mask = transforms.Resize(
                perturbed.shape[-2:],
                transforms.InterpolationMode("nearest")
            ).forward(subset_mask)
            if self.mask_agg_dim is not None:
                subset_mask = subset_mask.mean(dim=self.mask_agg_dim)

            # Use the masks to set the selected subsets to baseline state
            masked = perturbed * subset_mask + \
                (1.0 - subset_mask) * self.baseline

            masked_output = _forward_batch(self.model, masked, self.batch_size)
            pred_diff = pred - masked_output[:, target]

            masked_attr = (attr * (1.0 - subset_mask))\
                .sum(dim=tuple(range(1, subset_mask.ndim)))

            corr, _ = spearmanr(
                pred_diff.cpu().detach().numpy(),
                masked_attr.cpu().detach().numpy(),
            )
            evaluations.append(corr)
        return torch.tensor(evaluations)

evaluate(inputs, targets, attributions)

Parameters:

    inputs (Tensor): The input data (N x C x H x W). Required.
    targets (Tensor): The target labels for the inputs (N x 1). Required.
    attributions (Tensor): The attributions of the inputs. Required.

Returns:

    Tensor: The result of the metric evaluation.

Source code in pnpxai/evaluator/metrics/mu_fidelity.py
def evaluate(
    self,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    attributions: torch.Tensor,
) -> torch.Tensor:
    """
    Args:
        inputs (torch.Tensor): The input data (N x C x H x W).
        targets (torch.Tensor): The target labels for the inputs (N x 1).
        attributions (torch.Tensor): The attributions of the inputs.

    Returns:
        torch.Tensor: The result of the metric evaluation.
    """
    attributions = attributions.to(self.device)

    outputs = self.model(inputs)
    n_classes = outputs.shape[-1]
    predictions = (
        outputs * torch.eye(n_classes).to(self.device)[targets]
    ).sum(dim=-1).detach()

    # input, target, attr, pred
    evaluations = []
    zipped = zip(inputs, targets, attributions, predictions)
    for input, target, attr, pred in zipped:
        repeated = torch.stack([input]*self.n_perturb)
        # Add Gaussian random noise
        std, mean = torch.std_mean(repeated)
        noise = torch.randn_like(repeated).to(self.device) * std + mean
        perturbed = self.noise_scale * noise + repeated
        perturbed = torch.minimum(repeated, perturbed)
        perturbed = torch.maximum(repeated-1, perturbed)

        # prepare the random masks that will designate the modified subset (S in original equation)
        subset_size = int(self.grid_size ** 2 * self.noise_scale)
        subset_mask = torch.randn(
            (self.n_perturb, self.grid_size ** 2)).to(self.device)
        subset_mask = torch.argsort(subset_mask, dim=-1) > subset_size
        subset_mask = torch.reshape(subset_mask.type(
            torch.float32), (self.n_perturb, 1, self.grid_size, self.grid_size))
        subset_mask = transforms.Resize(
            perturbed.shape[-2:],
            transforms.InterpolationMode("nearest")
        ).forward(subset_mask)
        if self.mask_agg_dim is not None:
            subset_mask = subset_mask.mean(dim=self.mask_agg_dim)

        # Use the masks to set the selected subsets to baseline state
        masked = perturbed * subset_mask + \
            (1.0 - subset_mask) * self.baseline

        masked_output = _forward_batch(self.model, masked, self.batch_size)
        pred_diff = pred - masked_output[:, target]

        masked_attr = (attr * (1.0 - subset_mask))\
            .sum(dim=tuple(range(1, subset_mask.ndim)))

        corr, _ = spearmanr(
            pred_diff.cpu().detach().numpy(),
            masked_attr.cpu().detach().numpy(),
        )
        evaluations.append(corr)
    return torch.tensor(evaluations)

Sensitivity

Bases: Metric

Computes the sensitivity of attributions.

Given inputs and attributions, the inputs are perturbed n_iter times with random uniform noise in [-epsilon, epsilon] and the explainer is applied again to the perturbed inputs.

The sensitivity is the maximum relative change of the attributions, evaluation = max_i ||attr - perturbed_attr[i]|| / ||attr||.
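A minimal sketch of this computation for a single sample, assuming a hypothetical explain(inputs, targets) callable that returns attributions:

import torch

def sensitivity_of(explain, inp, target, attr, n_iter=8, epsilon=0.2):
    perturbed = inp.unsqueeze(0).repeat(n_iter, *([1] * inp.ndim))
    noise = torch.rand_like(perturbed) * 2 * epsilon - epsilon  # U(-epsilon, epsilon)
    perturbed_attr = explain(perturbed + noise, target.repeat(n_iter))
    # Maximum relative change of the attribution under the perturbations.
    return max(torch.linalg.norm(attr - pa) / torch.linalg.norm(attr)
               for pa in perturbed_attr)

def explain(inputs, targets):  # trivial stand-in explainer, only to make the sketch run
    return inputs

inp, attr = torch.rand(3, 32, 32), torch.rand(3, 32, 32)
print(sensitivity_of(explain, inp, torch.tensor(0), attr))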

Parameters:

    model (Model): The model used for evaluation. Required.
    explainer (Optional[Explainer]): The explainer used for evaluation. Default: None.
    n_iter (Optional[int]): The number of iterations for perturbation. Default: 8.
    epsilon (Optional[float]): The magnitude of random uniform noise. Default: 0.2.
Source code in pnpxai/evaluator/metrics/sensitivity.py
class Sensitivity(Metric):
    """
    Computes the sensitivity of attributions.

    Given `inputs` and `attributions`, the inputs are perturbed `n_iter` times with random uniform
    noise in ``[-epsilon, epsilon]`` and the explainer is applied again to the perturbed inputs.

    The sensitivity is the maximum relative change of the attributions,
    ``evaluation = max_i ||attr - perturbed_attr[i]|| / ||attr||``


    Args:
        model (Model): The model used for evaluation
        explainer (Optional[Explainer]): The explainer used for evaluation.
        n_iter (Optional[int]): The number of iterations for perturbation.
        epsilon (Optional[float]): The magnitude of random uniform noise.
    """
    def __init__(
        self,
        model: Model,
        explainer: Optional[Explainer] = None,
        n_iter: Optional[int] = 8,
        epsilon: Optional[float] = 0.2,
    ):
        super().__init__(model, explainer)
        self.n_iter = n_iter
        self.epsilon = epsilon
        if explainer is None:
            warnings.warn('[Sensitivity] explainer is not provided. Please set explainer before evaluate.')

    def evaluate(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        attributions: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            inputs (torch.Tensor): The input data.
            targets (torch.Tensor): The target labels for the inputs.
            attributions (Optional[torch.Tensor]): The attributions of the inputs.

        Returns:
            torch.Tensor: The result of the metric evaluation.
        """
        if attributions is None:
            attributions = self.explainer.attribute(inputs, targets)
        attributions = attributions.to(self.device)
        evaluations = []
        for inp, target, attr in zip(inputs, targets, attributions):
            # Add random uniform noise which ranges [-epsilon, epsilon]
            perturbed = torch.stack([inp]*self.n_iter)
            noise = (
                torch.rand_like(perturbed).to(self.device) * self.epsilon * 2 \
                - self.epsilon
            )
            perturbed += noise
            # Get perturbed attribution results
            perturbed_attr = self.explainer.attribute(
                inputs=perturbed.to(self.device),
                targets=target.repeat(self.n_iter),
            )
            # Get maximum of the difference between the perturbed attribution and the original attribution
            attr_norm = torch.linalg.norm(attr).to(self.device)
            attr_diff = attr - perturbed_attr
            sens = max([torch.linalg.norm(diff)/attr_norm for diff in attr_diff])
            evaluations.append(sens)
        return torch.stack(evaluations).to(self.device)

evaluate(inputs, targets, attributions)

Parameters:

    inputs (Tensor): The input data. Required.
    targets (Tensor): The target labels for the inputs. Required.
    attributions (Optional[Tensor]): The attributions of the inputs. Required.

Returns:

    Tensor: The result of the metric evaluation.

Source code in pnpxai/evaluator/metrics/sensitivity.py
def evaluate(
    self,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    attributions: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Args:
        inputs (torch.Tensor): The input data.
        targets (torch.Tensor): The target labels for the inputs.
        attributions (Optional[torch.Tensor]): The attributions of the inputs.

    Returns:
        torch.Tensor: The result of the metric evaluation.
    """
    if attributions is None:
        attributions = self.explainer.attribute(inputs, targets)
    attributions = attributions.to(self.device)
    evaluations = []
    for inp, target, attr in zip(inputs, targets, attributions):
        # Add random uniform noise which ranges [-epsilon, epsilon]
        perturbed = torch.stack([inp]*self.n_iter)
        noise = (
            torch.rand_like(perturbed).to(self.device) * self.epsilon * 2 \
            - self.epsilon
        )
        perturbed += noise
        # Get perturbed attribution results
        perturbed_attr = self.explainer.attribute(
            inputs=perturbed.to(self.device),
            targets=target.repeat(self.n_iter),
        )
        # Get maximum of the difference between the perturbed attribution and the original attribution
        attr_norm = torch.linalg.norm(attr).to(self.device)
        attr_diff = attr - perturbed_attr
        sens = max([torch.linalg.norm(diff)/attr_norm for diff in attr_diff])
        evaluations.append(sens)
    return torch.stack(evaluations).to(self.device)

AbPC

Bases: PixelFlipping

A metric class for evaluating the correctness of explanations or attributions using the Area between Perturbation Curves (AbPC) technique.

This class inherits from the PixelFlipping class and assesses the quality of attributions by comparing the area between the perturbation curves obtained by perturbing input features (e.g., pixels) in both ascending and descending order of their attributed importance. The average probability change is measured, providing a comprehensive evaluation of the explainer's correctness.
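A minimal sketch of the aggregation this metric performs, assuming the per-step target probabilities of the two perturbation curves are already available; values are illustrative:

import torch

# Target probability at each perturbation step when flipping least relevant
# features first (LeRF, ascending) and most relevant first (MoRF, descending).
lerf_probs = torch.tensor([0.90, 0.85, 0.80, 0.70, 0.60])
morf_probs = torch.tensor([0.90, 0.40, 0.20, 0.10, 0.05])

# AbPC: mean gap between the curves, clamped from below at lb (default -1).
abpc = (lerf_probs - morf_probs).clamp(min=-1.0).mean()
print(abpc)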

Attributes:

    model (Module): The model.
    explainer (Optional[Explainer]): The explainer whose explanations are being evaluated. Default: None.
    channel_dim (int): Target channel dimension.
    n_steps (int): The number of perturbation steps.
    baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation.
    prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs.
    pred_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute predictions from model outputs.
    lb (float): The lower bound for clamping the probability differences.

Methods:

    evaluate: Evaluate the explainer's correctness using the AbPC technique by observing changes in model predictions.

Source code in pnpxai/evaluator/metrics/pixel_flipping.py
class AbPC(PixelFlipping):
    """
    A metric class for evaluating the correctness of explanations or attributions using the 
    Area between Perturbation Curves (AbPC) technique.

    This class inherits from the PixelFlipping class and assesses the quality of attributions by comparing 
    the area between the perturbation curves obtained by perturbing input features (e.g., pixels) in both 
    ascending and descending order of their attributed importance. The average probability change is 
    measured, providing a comprehensive evaluation of the explainer's correctness.

    Attributes:
        model (Module): The model.
        explainer (Optional[Explainer]=None): The explainer whose explanations are being evaluated.
        channel_dim (int): Target channel dimension.
        n_steps (int): The number of perturbation steps.
        baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation.
        prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs.
        pred_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute predictions from model outputs.
        lb (float): The lower bound for clamping the probability differences.

    Methods:
        evaluate(inputs, targets, attributions, attention_mask=None):
            Evaluate the explainer's correctness using the AbPC technique by observing changes in model predictions.
    """

    def __init__(
        self,
        model: Module,
        explainer: Optional[Explainer]=None,
        channel_dim: int=1,
        n_steps: int=10,
        baseline_fn: Optional[BaselineFunction]=None,
        prob_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.softmax(-1),
        pred_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.argmax(-1),
        lb: float=-1.,
    ):
        super().__init__(
            model, explainer, channel_dim, n_steps,
            baseline_fn, prob_fn, pred_fn,
        )
        self.lb = lb

    def evaluate(
        self,
        inputs: TensorOrTupleOfTensors,
        targets: Tensor,
        attributions: TensorOrTupleOfTensors,
        attention_mask: Optional[TensorOrTupleOfTensors]=None,
        return_pf=False,
    ) -> TensorOrTupleOfTensors:
        """
        Evaluate the explainer's correctness using the AbPC technique by observing changes in model predictions.

        Args:
            inputs (TensorOrTupleOfTensors): The input tensors to the model.
            targets (Tensor): The target labels for the inputs.
            attributions (TensorOrTupleOfTensors): The attributions for the inputs.
            attention_mask (Optional[TensorOrTupleOfTensors], optional): Attention masks for the inputs.
            return_pf (Optional[bool]): Whether to return the perturbation curves for ascending and descending orders.

        Returns:
            TensorOrTupleOfTensors: The mean clamped differences in probabilities at each perturbation step, 
                indicating the impact of perturbing the most and least relevant features.
        """
        # pf by ascending order: lerf
        pf_ascs = super().evaluate(inputs, targets, attributions, attention_mask, False)
        pf_ascs = format_into_tuple(pf_ascs)

        # pf by descending order: morf
        pf_descs = super().evaluate(inputs, targets, attributions, attention_mask, True)
        pf_descs = format_into_tuple(pf_descs)

        # abpc
        results = []
        for pf_asc, pf_desc in zip(pf_ascs, pf_descs):
            result = (pf_asc['probs'] - pf_desc['probs']).clamp(min=self.lb).mean(-1)
            if return_pf:
                result = tuple([result, pf_desc, pf_asc])
            results.append(result)
        if len(results) == 1:
            return results[0]
        return tuple(results)

evaluate(inputs, targets, attributions, attention_mask=None, return_pf=False)

Evaluate the explainer's correctness using the AbPC technique by observing changes in model predictions.

Parameters:

    inputs (TensorOrTupleOfTensors): The input tensors to the model. Required.
    targets (Tensor): The target labels for the inputs. Required.
    attributions (TensorOrTupleOfTensors): The attributions for the inputs. Required.
    attention_mask (Optional[TensorOrTupleOfTensors]): Attention masks for the inputs. Default: None.
    return_pf (Optional[bool]): Whether to return the perturbation curves for ascending and descending orders. Default: False.

Returns:

    TensorOrTupleOfTensors: The mean clamped differences in probabilities at each perturbation step, indicating the impact of perturbing the most and least relevant features.

Source code in pnpxai/evaluator/metrics/pixel_flipping.py
def evaluate(
    self,
    inputs: TensorOrTupleOfTensors,
    targets: Tensor,
    attributions: TensorOrTupleOfTensors,
    attention_mask: Optional[TensorOrTupleOfTensors]=None,
    return_pf=False,
) -> TensorOrTupleOfTensors:
    """
    Evaluate the explainer's correctness using the AbPC technique by observing changes in model predictions.

    Args:
        inputs (TensorOrTupleOfTensors): The input tensors to the model.
        targets (Tensor): The target labels for the inputs.
        attributions (TensorOrTupleOfTensors): The attributions for the inputs.
        attention_mask (Optional[TensorOrTupleOfTensors], optional): Attention masks for the inputs.
        return_pf (Optional[bool]): Whether to return the perturbation curves for ascending and descending orders.

    Returns:
        TensorOrTupleOfTensors: The mean clamped differences in probabilities at each perturbation step, 
            indicating the impact of perturbing the most and least relevant features.
    """
    # pf by ascending order: lerf
    pf_ascs = super().evaluate(inputs, targets, attributions, attention_mask, False)
    pf_ascs = format_into_tuple(pf_ascs)

    # pf by descending order: morf
    pf_descs = super().evaluate(inputs, targets, attributions, attention_mask, True)
    pf_descs = format_into_tuple(pf_descs)

    # abpc
    results = []
    for pf_asc, pf_desc in zip(pf_ascs, pf_descs):
        result = (pf_asc['probs'] - pf_desc['probs']).clamp(min=self.lb).mean(-1)
        if return_pf:
            result = tuple([result, pf_desc, pf_asc])
        results.append(result)
    if len(results) == 1:
        return results[0]
    return tuple(results)

LeRF

Bases: PixelFlipping

A metric class for evaluating the correctness of explanations or attributions using the Least Relevant First (LeRF) pixel flipping technique.

This class inherits from the PixelFlipping class and evaluates the quality of attributions by perturbing input features (e.g., pixels) in ascending order of their attributed importance. The average probability change is measured to assess the explainer's correctness.
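A minimal illustration of the final aggregation, mirroring result['probs'].mean(-1) in the source below; values are illustrative:

import torch

# Target probabilities after each LeRF step (least relevant features removed first).
probs = torch.tensor([0.90, 0.89, 0.87, 0.82, 0.75])
print(probs.mean(-1))  # higher is better: removing unimportant features barely hurts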

Attributes:

    model (Module): The model.
    explainer (Optional[Explainer]): The explainer whose explanations are being evaluated. Default: None.
    channel_dim (int): Target channel dimension.
    n_steps (int): The number of perturbation steps.
    baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation.
    prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs.
    pred_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute predictions from model outputs.

Methods:

    evaluate: Evaluate the explainer's correctness using the LeRF technique by observing changes in model predictions.

Source code in pnpxai/evaluator/metrics/pixel_flipping.py
class LeRF(PixelFlipping):
    """
    A metric class for evaluating the correctness of explanations or attributions using the 
    Least Relevant First (LeRF) pixel flipping technique.

    This class inherits from the PixelFlipping class and evaluates the quality of attributions by perturbing input 
    features (e.g., pixels) in ascending order of their attributed importance. The average probability change is 
    measured to assess the explainer's correctness.

    Attributes:
        model (Module): The model.
        explainer (Optional[Explainer]=None): The explainer whose explanations are being evaluated.
        channel_dim (int): Target channel dimension.
        n_steps (int): The number of perturbation steps.
        baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation.
        prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs.
        pred_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute predictions from model outputs.

    Methods:
        evaluate(inputs, targets, attributions, attention_mask=None):
            Evaluate the explainer's correctness using the LeRF technique by observing changes in model predictions.
    """

    def __init__(
        self,
        model: Module,
        explainer: Optional[Explainer]=None,
        channel_dim: int=1,
        n_steps: int=10,
        baseline_fn: Optional[BaselineFunction]=None,
        prob_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.softmax(-1),
        pred_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.argmax(-1),
    ):
        super().__init__(
            model, explainer, channel_dim, n_steps,
            baseline_fn, prob_fn, pred_fn,
        )

    def evaluate(
        self,
        inputs: TensorOrTupleOfTensors,
        targets: Tensor,
        attributions: TensorOrTupleOfTensors,
        attention_mask: TensorOrTupleOfTensors=None,
    ) -> TensorOrTupleOfTensors:
        """
        Evaluate the explainer's correctness using the LeRF technique by observing changes in model predictions.

        Args:
            inputs (TensorOrTupleOfTensors): The input tensors to the model.
            targets (Tensor): The target labels for the inputs.
            attributions (TensorOrTupleOfTensors): The attributions for the inputs.
            attention_mask (Optional[TensorOrTupleOfTensors], optional): Attention masks for the inputs. Default is None.

        Returns:
            TensorOrTupleOfTensors: The mean probabilities at each perturbation step, indicating the impact of 
                perturbing the least relevant features first.
        """
        pf_results = super().evaluate(inputs, targets, attributions, attention_mask, False)
        pf_results = format_into_tuple(pf_results)
        lerf = tuple(result['probs'].mean(-1) for result in pf_results)
        if len(lerf) == 1:
            lerf = lerf[0]
        return lerf

evaluate(inputs, targets, attributions, attention_mask=None)

Evaluate the explainer's correctness using the LeRF technique by observing changes in model predictions.

Parameters:

    inputs (TensorOrTupleOfTensors): The input tensors to the model. Required.
    targets (Tensor): The target labels for the inputs. Required.
    attributions (TensorOrTupleOfTensors): The attributions for the inputs. Required.
    attention_mask (Optional[TensorOrTupleOfTensors]): Attention masks for the inputs. Default: None.

Returns:

    TensorOrTupleOfTensors: The mean probabilities at each perturbation step, indicating the impact of perturbing the least relevant features first.

Source code in pnpxai/evaluator/metrics/pixel_flipping.py
def evaluate(
    self,
    inputs: TensorOrTupleOfTensors,
    targets: Tensor,
    attributions: TensorOrTupleOfTensors,
    attention_mask: TensorOrTupleOfTensors=None,
) -> TensorOrTupleOfTensors:
    """
    Evaluate the explainer's correctness using the LeRF technique by observing changes in model predictions.

    Args:
        inputs (TensorOrTupleOfTensors): The input tensors to the model.
        targets (Tensor): The target labels for the inputs.
        attributions (TensorOrTupleOfTensors): The attributions for the inputs.
        attention_mask (Optional[TensorOrTupleOfTensors], optional): Attention masks for the inputs. Default is None.

    Returns:
        TensorOrTupleOfTensors: The mean probabilities at each perturbation step, indicating the impact of 
            perturbing the least relevant features first.
    """
    pf_results = super().evaluate(inputs, targets, attributions, attention_mask, False)
    pf_results = format_into_tuple(pf_results)
    lerf = tuple(result['probs'].mean(-1) for result in pf_results)
    if len(lerf) == 1:
        lerf = lerf[0]
    return lerf

MoRF

Bases: PixelFlipping

A metric class for evaluating the correctness of explanations or attributions using the Most Relevant First (MoRF) pixel flipping technique.

This class inherits from the PixelFlipping class and evaluates the quality of attributions by perturbing input features (e.g., pixels) in descending order of their attributed importance. The average probability change is measured to assess the explainer's correctness (lower is better).
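As with LeRF, the score is the mean of the per-step target probabilities, result['probs'].mean(-1) in the source below; a minimal illustration with made-up values:

import torch

# Target probabilities after each MoRF step (most relevant features removed first).
probs = torch.tensor([0.90, 0.40, 0.20, 0.10, 0.05])
print(probs.mean(-1))  # lower is better: removing important features hurts quickly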

Attributes:

    model (Module): The model.
    explainer (Optional[Explainer]): The explainer whose explanations are being evaluated. Default: None.
    channel_dim (int): Target channel dimension.
    n_steps (int): The number of perturbation steps.
    baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation.
    prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs.
    pred_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute predictions from model outputs.

Methods:

    evaluate: Evaluate the explainer's correctness using the MoRF technique by observing changes in model predictions.

Source code in pnpxai/evaluator/metrics/pixel_flipping.py
class MoRF(PixelFlipping):
    """
    A metric class for evaluating the correctness of explanations or attributions using the 
    Most Relevant First (MoRF) pixel flipping technique.

    This class inherits from the PixelFlipping class and evaluates the quality of attributions by perturbing input 
    features (e.g., pixels) in descending order of their attributed importance. The average probability change is 
    measured to assess the explainer's correctness (lower is better).

    Attributes:
        model (Module): The model.
        explainer (Optional[Explainer]=None): The explainer whose explanations are being evaluated.
        channel_dim (int): Target channel dimension.
        n_steps (int): The number of perturbation steps.
        baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation.
        prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs.
        pred_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute predictions from model outputs.

    Methods:
        evaluate(inputs, targets, attributions, attention_mask=None):
            Evaluate the explainer's correctness using the MoRF technique by observing changes in model predictions.
    """

    def __init__(
        self,
        model: Module,
        explainer: Optional[Explainer]=None,
        channel_dim: int=1,
        n_steps: int=10,
        baseline_fn: Optional[BaselineFunction]=None,
        prob_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.softmax(-1),
        pred_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.argmax(-1),
    ):
        super().__init__(
            model, explainer, channel_dim, n_steps,
            baseline_fn, prob_fn, pred_fn,
        )

    def evaluate(
        self,
        inputs: TensorOrTupleOfTensors,
        targets: Tensor,
        attributions: TensorOrTupleOfTensors,
        attention_mask: Optional[TensorOrTupleOfTensors]=None
    ) -> TensorOrTupleOfTensors:
        """
        Evaluate the explainer's correctness using the MoRF technique by observing changes in model predictions.

        Args:
            inputs (TensorOrTupleOfTensors): The input tensors to the model.
            targets (Tensor): The target labels for the inputs.
            attributions (TensorOrTupleOfTensors): The attributions for the inputs.
            attention_mask (Optional[TensorOrTupleOfTensors], optional): Attention masks for the inputs. Default is None.

        Returns:
            TensorOrTupleOfTensors: The mean probabilities at each perturbation step, indicating the impact of 
                perturbing the most relevant features first.
        """
        pf_results = super().evaluate(inputs, targets, attributions, attention_mask, True)
        pf_results = format_into_tuple(pf_results)
        morf = tuple(result['probs'].mean(-1) for result in pf_results)
        if len(morf) == 1:
            morf = morf[0]
        return morf

evaluate(inputs, targets, attributions, attention_mask=None)

Evaluate the explainer's correctness using the MoRF technique by observing changes in model predictions.

Parameters:

    inputs (TensorOrTupleOfTensors): The input tensors to the model. Required.
    targets (Tensor): The target labels for the inputs. Required.
    attributions (TensorOrTupleOfTensors): The attributions for the inputs. Required.
    attention_mask (Optional[TensorOrTupleOfTensors]): Attention masks for the inputs. Default: None.

Returns:

    TensorOrTupleOfTensors: The mean probabilities at each perturbation step, indicating the impact of perturbing the most relevant features first.

Source code in pnpxai/evaluator/metrics/pixel_flipping.py
def evaluate(
    self,
    inputs: TensorOrTupleOfTensors,
    targets: Tensor,
    attributions: TensorOrTupleOfTensors,
    attention_mask: Optional[TensorOrTupleOfTensors]=None
) -> TensorOrTupleOfTensors:
    """
    Evaluate the explainer's correctness using the MoRF technique by observing changes in model predictions.

    Args:
        inputs (TensorOrTupleOfTensors): The input tensors to the model.
        targets (Tensor): The target labels for the inputs.
        attributions (TensorOrTupleOfTensors): The attributions for the inputs.
        attention_mask (Optional[TensorOrTupleOfTensors], optional): Attention masks for the inputs. Default is None.

    Returns:
        TensorOrTupleOfTensors: The mean probabilities at each perturbation step, indicating the impact of 
            perturbing the most relevant features first.
    """
    pf_results = super().evaluate(inputs, targets, attributions, attention_mask, True)
    pf_results = format_into_tuple(pf_results)
    morf = tuple(result['probs'].mean(-1) for result in pf_results)
    if len(morf) == 1:
        morf = morf[0]
    return morf

PixelFlipping

Bases: Metric

A metric class for evaluating the correctness of explanations or attributions provided by an explainer using the pixel flipping technique.

This class assesses the quality of attributions by perturbing input features (e.g., pixels) in the order of their attributed importance and measuring the resulting change in the model's predictions. Correct attributions should lead to significant changes in model predictions when the most important features are perturbed.
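A minimal sketch of the flipping procedure for a single image, independent of this class; the real implementation additionally pools the channel dimension and supports tuples of inputs, attention masks, and configurable baseline, probability, and prediction functions:

import torch
from torchvision.models import resnet18

model = resnet18(weights=None).eval()
image = torch.rand(1, 3, 224, 224)
attribution = torch.rand(1, 3, 224, 224)  # stand-in for an explainer's output
target, n_steps = 0, 10

# Rank features by attributed importance (most relevant first) and flip them to a
# zero baseline in equally sized chunks, re-reading the target probability each time.
order = attribution.flatten().argsort(descending=True)
n_per_step = order.numel() // n_steps
flipped = image.clone().flatten()

with torch.no_grad():
    probs = [model(image).softmax(-1)[0, target].item()]
    for step in range(n_steps):
        idx = order[step * n_per_step:(step + 1) * n_per_step]
        flipped[idx] = 0.0
        probs.append(model(flipped.view_as(image)).softmax(-1)[0, target].item())

print(probs)  # faithful attributions should make this curve drop quickly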

Attributes:

    model (Module): The model.
    explainer (Optional[Explainer]): The explainer whose explanations are being evaluated. Default: None.
    channel_dim (int): Target channel dimension.
    n_steps (int): The number of perturbation steps.
    baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation.
    prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs.
    pred_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute predictions from model outputs.

Methods:

    evaluate: Evaluate the explainer's correctness based on the attributions by observing changes in model predictions.

Source code in pnpxai/evaluator/metrics/pixel_flipping.py
class PixelFlipping(Metric):
    """
    A metric class for evaluating the correctness of explanations or attributions provided by an explainer 
    using the pixel flipping technique.

    This class assesses the quality of attributions by perturbing input features (e.g., pixels) in the order 
    of their attributed importance and measuring the resulting change in the model's predictions. Correct attributions 
    should lead to significant changes in model predictions when the most important features are perturbed.

    Attributes:
        model (Module): The model.
        explainer (Optional[Explainer]=None): The explainer whose explanations are being evaluated.
        channel_dim (int): Target channel dimension.
        n_steps (int): The number of perturbation steps.
        baseline_fn (Optional[BaselineFunction]): Function to generate baseline inputs for perturbation.
        prob_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute probabilities from model outputs.
        pred_fn (Optional[Callable[[Tensor], Tensor]]): Function to compute predictions from model outputs.

    Methods:
        evaluate(inputs, targets, attributions, attention_mask=None, descending=True):
            Evaluate the explainer's correctness based on the attributions by observing changes in model predictions.
    """

    def __init__(
        self,
        model: Module,
        explainer: Optional[Explainer]=None,
        channel_dim: int=1,
        n_steps: int=10,
        baseline_fn: Optional[BaselineFunction]=None,
        prob_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.softmax(-1),
        pred_fn: Optional[Callable[[Tensor], Tensor]]=lambda outputs: outputs.argmax(-1),
    ):
        super().__init__(model, explainer)
        self.channel_dim = channel_dim
        self.n_steps = n_steps
        self.baseline_fn = baseline_fn
        self.prob_fn = prob_fn
        self.pred_fn = pred_fn

    @torch.no_grad()
    def evaluate(
        self,
        inputs: TensorOrTupleOfTensors,
        targets: Tensor,
        attributions: TensorOrTupleOfTensors,
        attention_mask: Optional[TensorOrTupleOfTensors]=None,
        descending: bool=True,
    ) -> Union[Dict[str, Tensor], Tuple[Dict[str, Tensor]]]:
        """
        Evaluate the explainer's correctness based on the attributions by observing changes in model predictions.

        Args:
            inputs (TensorOrTupleOfTensors): The input tensors to the model.
            targets (Tensor): The target labels for the inputs.
            attributions (TensorOrTupleOfTensors): The attributions for the inputs.
            attention_mask (Optional[TensorOrTupleOfTensors], optional): Attention masks for the inputs. Default is None.
            descending (bool, optional): Whether to flip pixels in descending order of attribution. Default is True.

        Returns:
            Union[Dict[str, Tensor], Tuple[Dict[str, Tensor]]]: A dictionary or tuple of dictionaries containing
                the probabilities and predictions at each perturbation step.
        """
        forward_args, additional_forward_args = self.explainer._extract_forward_args(inputs)
        formatted: Dict[str, Tuple[Any]] = format_into_tuple_all(
            forward_args=forward_args,
            additional_forward_args=additional_forward_args,
            attributions=attributions,
            channel_dim=self.channel_dim,
            baseline_fn=self.baseline_fn,
            attention_mask=attention_mask or (None,)*len(format_into_tuple(forward_args)),
        )
        assert all(
            len(formatted['forward_args']) == len(formatted[k]) for k in formatted
            if k != 'additional_forward_args'
        )

        bsz = formatted['forward_args'][0].size(0)
        results = []

        outputs = self.model(
            *formatted['forward_args'],
            *formatted['additional_forward_args'],
        )
        init_probs = self.prob_fn(outputs)
        init_preds = self.pred_fn(outputs)

        for pos, forward_arg in enumerate(formatted['forward_args']):
            baseline_fn = formatted['baseline_fn'][pos]
            attrs = formatted['attributions'][pos]
            attrs, original_size = _flatten_if_not_1d(attrs)
            if formatted['attention_mask'][pos] is not None:
                attn_mask, _ = _flatten_if_not_1d(formatted['attention_mask'][pos])
                mask_value = -torch.inf if descending else torch.inf
                attrs = torch.where(attn_mask == 1, attrs, mask_value)

            valid_n_features = (~attrs.isinf()).sum(-1)
            n_flipped_per_step = valid_n_features // self.n_steps
            n_flipped_per_step = n_flipped_per_step.clamp(min=1) # ensure at least a pixel flipped
            sorted_indices = torch.argsort(
                attrs,
                descending=descending,
                stable=True,
            )
            probs = [_extract_target_probs(init_probs, targets)]
            preds = [init_preds]
            for step in range(1, self.n_steps):
                n_flipped = n_flipped_per_step * step
                if step + 1 == self.n_steps:
                    n_flipped = valid_n_features
                is_index_of_flipped = (
                    F.one_hot(n_flipped-1, num_classes=attrs.size(-1)).to(self.device)
                    .flip(-1).cumsum(-1).flip(-1)
                )
                is_flipped = _sort_by_order(
                    is_index_of_flipped, sorted_indices.argsort(-1))
                is_flipped = _recover_shape_if_flattened(is_flipped, original_size)
                is_flipped = _match_channel_dim_if_pooled(
                    is_flipped,
                    formatted['channel_dim'][pos],
                    forward_arg.size()
                )

                baseline = baseline_fn(forward_arg)
                flipped_forward_arg = baseline * is_flipped + forward_arg * (1 - is_flipped)

                flipped_forward_args = tuple(
                    flipped_forward_arg if i == pos else formatted['forward_args'][i]
                    for i in range(len(formatted['forward_args']))
                )
                flipped_outputs = self.model(
                    *flipped_forward_args,
                    *formatted['additional_forward_args'],
                )
                probs.append(_extract_target_probs(self.prob_fn(flipped_outputs), targets))
                preds.append(self.pred_fn(flipped_outputs))
            results.append({
                'probs': torch.stack(probs).transpose(1, 0),
                'preds': torch.stack(preds).transpose(1, 0),
            })
        if len(results) == 1:
            return results[0]
        return tuple(results)

evaluate(inputs, targets, attributions, attention_mask=None, descending=True)

Evaluate the explainer's correctness based on the attributions by observing changes in model predictions.

Parameters:

    inputs (TensorOrTupleOfTensors): The input tensors to the model. Required.
    targets (Tensor): The target labels for the inputs. Required.
    attributions (TensorOrTupleOfTensors): The attributions for the inputs. Required.
    attention_mask (Optional[TensorOrTupleOfTensors]): Attention masks for the inputs. Default: None.
    descending (bool): Whether to flip pixels in descending order of attribution. Default: True.

Returns:

    Union[Dict[str, Tensor], Tuple[Dict[str, Tensor]]]: A dictionary or tuple of dictionaries containing the probabilities and predictions at each perturbation step.
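A hedged sketch of consuming the returned dictionary for a single (non-tuple) input tensor; the stand-in values below only illustrate the shapes, which follow from the source: 'probs' and 'preds' are stacked per step and transposed to batch-first:

import torch

targets = torch.tensor([3, 7])
result = {                                   # stand-in for the returned dict
    'probs': torch.rand(2, 10),              # (N, n_steps) target probability per step
    'preds': torch.randint(0, 10, (2, 10)),  # (N, n_steps) predicted class per step
}

per_sample_curve_mean = result['probs'].mean(-1)  # aggregation used by MoRF/LeRF
frac_still_correct = (result['preds'] == targets.unsqueeze(-1)).float().mean(-1)
print(per_sample_curve_mean, frac_still_correct)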

Source code in pnpxai/evaluator/metrics/pixel_flipping.py
@torch.no_grad()
def evaluate(
    self,
    inputs: TensorOrTupleOfTensors,
    targets: Tensor,
    attributions: TensorOrTupleOfTensors,
    attention_mask: Optional[TensorOrTupleOfTensors]=None,
    descending: bool=True,
) -> Union[Dict[str, Tensor], Tuple[Dict[str, Tensor]]]:
    """
    Evaluate the explainer's correctness based on the attributions by observing changes in model predictions.

    Args:
        inputs (TensorOrTupleOfTensors): The input tensors to the model.
        targets (Tensor): The target labels for the inputs.
        attributions (TensorOrTupleOfTensors): The attributions for the inputs.
        attention_mask (Optional[TensorOrTupleOfTensors], optional): Attention masks for the inputs. Default is None.
        descending (bool, optional): Whether to flip pixels in descending order of attribution. Default is True.

    Returns:
        Union[Dict[str, Tensor], Tuple[Dict[str, Tensor]]]: A dictionary or tuple of dictionaries containing
            the probabilities and predictions at each perturbation step.
    """
    forward_args, additional_forward_args = self.explainer._extract_forward_args(inputs)
    formatted: Dict[str, Tuple[Any]] = format_into_tuple_all(
        forward_args=forward_args,
        additional_forward_args=additional_forward_args,
        attributions=attributions,
        channel_dim=self.channel_dim,
        baseline_fn=self.baseline_fn,
        attention_mask=attention_mask or (None,)*len(format_into_tuple(forward_args)),
    )
    assert all(
        len(formatted['forward_args']) == len(formatted[k]) for k in formatted
        if k != 'additional_forward_args'
    )

    bsz = formatted['forward_args'][0].size(0)
    results = []

    outputs = self.model(
        *formatted['forward_args'],
        *formatted['additional_forward_args'],
    )
    init_probs = self.prob_fn(outputs)
    init_preds = self.pred_fn(outputs)

    for pos, forward_arg in enumerate(formatted['forward_args']):
        baseline_fn = formatted['baseline_fn'][pos]
        attrs = formatted['attributions'][pos]
        attrs, original_size = _flatten_if_not_1d(attrs)
        if formatted['attention_mask'][pos] is not None:
            attn_mask, _ = _flatten_if_not_1d(formatted['attention_mask'][pos])
            mask_value = -torch.inf if descending else torch.inf
            attrs = torch.where(attn_mask == 1, attrs, mask_value)

        valid_n_features = (~attrs.isinf()).sum(-1)
        n_flipped_per_step = valid_n_features // self.n_steps
        n_flipped_per_step = n_flipped_per_step.clamp(min=1) # ensure at least a pixel flipped
        sorted_indices = torch.argsort(
            attrs,
            descending=descending,
            stable=True,
        )
        probs = [_extract_target_probs(init_probs, targets)]
        preds = [init_preds]
        for step in range(1, self.n_steps):
            n_flipped = n_flipped_per_step * step
            if step + 1 == self.n_steps:
                n_flipped = valid_n_features
            is_index_of_flipped = (
                F.one_hot(n_flipped-1, num_classes=attrs.size(-1)).to(self.device)
                .flip(-1).cumsum(-1).flip(-1)
            )
            is_flipped = _sort_by_order(
                is_index_of_flipped, sorted_indices.argsort(-1))
            is_flipped = _recover_shape_if_flattened(is_flipped, original_size)
            is_flipped = _match_channel_dim_if_pooled(
                is_flipped,
                formatted['channel_dim'][pos],
                forward_arg.size()
            )

            baseline = baseline_fn(forward_arg)
            flipped_forward_arg = baseline * is_flipped + forward_arg * (1 - is_flipped)

            flipped_forward_args = tuple(
                flipped_forward_arg if i == pos else formatted['forward_args'][i]
                for i in range(len(formatted['forward_args']))
            )
            flipped_outputs = self.model(
                *flipped_forward_args,
                *formatted['additional_forward_args'],
            )
            probs.append(_extract_target_probs(self.prob_fn(flipped_outputs), targets))
            preds.append(self.pred_fn(flipped_outputs))
        results.append({
            'probs': torch.stack(probs).transpose(1, 0),
            'preds': torch.stack(preds).transpose(1, 0),
        })
    if len(results) == 1:
        return results[0]
    return tuple(results)