GradientXInput

Bases: Explainer

Grad X Input explainer.
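Gradient × Input attributes each input feature by the elementwise product of the input with the gradient of the target output: for an input \(x\) and model output \(f_c(x)\) for target class \(c\), the attribution is \(x \odot \nabla_x f_c(x)\).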

Supported Modules: Linear, Convolution, LSTM, RNN, Attention

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The PyTorch model for which attribution is to be computed. | *required* |
| `layer` | `Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]` | The target module to be explained. | `None` |
| `forward_arg_extractor` | `Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]` | A function that extracts, from the input batch(es), the forward arguments to which the attribution scores are assigned. | `None` |
| `additional_forward_arg_extractor` | `Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]` | A secondary function that extracts additional forward arguments from the input batch(es). | `None` |
| `**kwargs` | | Keyword arguments forwarded to the base `Explainer` implementation. | *required* |
Reference

Avanti Shrikumar, Peyton Greenside, Anna Shcherbina, Anshul Kundaje. Not Just a Black Box: Learning Important Features Through Propagating Activation Differences.
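For illustration, a minimal usage sketch. It assumes `GradientXInput` is importable from `pnpxai.explainers` (matching the source path below); the model and data are placeholders:

```python
import torch
from torchvision.models import resnet18

from pnpxai.explainers import GradientXInput  # import path assumed from the source path below

model = resnet18(weights=None).eval()
inputs = torch.randn(4, 3, 224, 224)    # dummy image batch
targets = torch.randint(0, 1000, (4,))  # dummy target labels

explainer = GradientXInput(model)
attrs = explainer.attribute(inputs, targets)  # attributions with the same shape as inputs
```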

Source code in pnpxai/explainers/grad_x_input.py
class GradientXInput(Explainer):
    """
    Grad X Input explainer.

    Supported Modules: `Linear`, `Convolution`, `LSTM`, `RNN`, `Attention`

    Parameters:
        model (Module): The PyTorch model for which attribution is to be computed.
        layer (Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]]): The target module to be explained.
        forward_arg_extractor: A function that extracts, from the input batch(es), the forward arguments to which the attribution scores are assigned.
        additional_forward_arg_extractor: A secondary function that extracts additional forward arguments from the input batch(es).
        **kwargs: Keyword arguments forwarded to the base implementation of the Explainer.

    Reference:
        Avanti Shrikumar, Peyton Greenside, Anna Shcherbina, Anshul Kundaje. Not Just a Black Box: Learning Important Features Through Propagating Activation Differences.
    """

    SUPPORTED_MODULES = [Linear, Convolution, LSTM, RNN, Attention]

    def __init__(
        self,
        model: Module,
        layer: Optional[Union[Union[str, Module], Sequence[Union[str, Module]]]] = None,
        forward_arg_extractor: Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]] = None,
        additional_forward_arg_extractor: Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]] = None,
    ) -> None:
        super().__init__(model, forward_arg_extractor, additional_forward_arg_extractor)
        self.layer = layer

    @property
    def _layer_explainer(self) -> CaptumLayerGradientXInput:
        wrapped_model = captum_wrap_model_input(self.model)
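        # Resolve string layer names to modules via the wrapped model's input map;
        # Module instances pass through unchanged.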
        layers = [
            wrapped_model.input_maps[layer] if isinstance(layer, str)
            else layer for layer in self.layer
        ] if isinstance(self.layer, Sequence) else self.layer
        return CaptumLayerGradientXInput(
            forward_func=wrapped_model,
            layer=layers,
        )

    @property
    def _explainer(self) -> CaptumGradientXInput:
        return CaptumGradientXInput(forward_func=self.model)

    @property
    def explainer(self) -> Union[CaptumGradientXInput, CaptumLayerGradientXInput]:
        if self.layer is None:
            return self._explainer
        return self._layer_explainer

    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor]],
        targets: Tensor
    ) -> Union[Tensor, Tuple[Tensor]]:
        """
        Computes attributions for the given inputs and targets.

        Args:
            inputs (torch.Tensor): The input data.
            targets (torch.Tensor): The target labels for the inputs.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor]]: The result of the explanation.
        """

        forward_args, additional_forward_args = self._extract_forward_args(inputs)
        attrs = self.explainer.attribute(
            inputs=forward_args,
            target=targets,
            additional_forward_args=additional_forward_args,
        )
        if isinstance(attrs, list):
            attrs = tuple(attrs)
        if isinstance(attrs, tuple) and len(attrs) == 1:
            attrs = attrs[0]
        return attrs

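When `layer` is set, attributions are computed at that module's activations rather than at the raw inputs, via the Captum layer variant shown above. A hedged sketch, reusing the ResNet model from the earlier example and passing a `Module` instance:

```python
# Hypothetical layer-wise use: attribute at the activations of a submodule.
layer_explainer = GradientXInput(model, layer=model.layer4)
layer_attrs = layer_explainer.attribute(inputs, targets)  # shaped like layer4's output
```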
attribute(inputs, targets)

Computes attributions for the given inputs and targets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inputs` | `Tensor` | The input data. | *required* |
| `targets` | `Tensor` | The target labels for the inputs. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Union[Tensor, Tuple[Tensor]]` | The result of the explanation. |
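For models that take multiple forward arguments, the extractors control which tensors receive attributions. A hedged sketch, assuming a toy two-input model (everything here except the `GradientXInput` API is illustrative):

```python
import torch
import torch.nn as nn

from pnpxai.explainers import GradientXInput  # import path assumed, as above

class TwoInputNet(nn.Module):
    """Toy model taking a feature tensor plus an auxiliary mask."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 3)

    def forward(self, x, mask):
        return self.fc(x * mask)

model = TwoInputNet().eval()
x, mask = torch.randn(4, 16), torch.ones(4, 16)
targets = torch.zeros(4, dtype=torch.long)

explainer = GradientXInput(
    model,
    forward_arg_extractor=lambda batch: batch[0],             # attribute w.r.t. x only
    additional_forward_arg_extractor=lambda batch: batch[1],  # mask is passed through unattributed
)
attrs = explainer.attribute((x, mask), targets)  # attributions for x, shape (4, 16)
```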
