Skip to content

Base

Explainer

Bases: ABC

Abstract base class for implementing attribution explanations for machine learning models.

This class provides methods for extracting forward arguments, loading baseline and feature mask functions, and applying them during attribution.

Parameters:

Name Type Description Default
model Module

The PyTorch model for which attribution is to be computed.

required
forward_arg_extractor Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]

Optional function to extract forward arguments from inputs.

None
additional_forward_arg_extractor Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]

Optional function to extract additional forward arguments.

None
**kwargs

Additional keyword arguments. They are accepted by the constructor but are not used by this base class.

{}
Notes
  • Subclasses must implement the attribute method to define how attributions are computed.
  • The forward_arg_extractor and additional_forward_arg_extractor functions allow for customization in extracting forward arguments from the inputs.
Source code in pnpxai/explainers/base.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
class Explainer(ABC):
    """
    Abstract base class for implementing attribution explanations for machine learning models.

    This class provides methods for extracting forward arguments, loading baseline and
    feature mask functions, and applying them during attribution.

    Parameters:
        model (Module): The PyTorch model for which attribution is to be computed. It is switched to evaluation mode via ``model.eval()``.
        forward_arg_extractor (Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]): Optional function to extract forward arguments from inputs.
        additional_forward_arg_extractor (Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]): Optional function to extract additional forward arguments.
        **kwargs: Additional keyword arguments. Accepted by the constructor but not used by this base class.

    Notes:
        - Subclasses must implement the `attribute` method to define how attributions are computed.
        - The `forward_arg_extractor` and `additional_forward_arg_extractor` functions allow for customization in extracting forward arguments from the inputs.
        - `_get_baselines` and `_get_feature_masks` read the `baseline_fn` / `feature_mask_fn` attributes, which this base class does not define; subclasses that use those helpers must set them.
    """

    EXPLANATION_TYPE: ExplanationType = "attribution"
    # NOTE: these are mutable class-level objects shared by every subclass and
    # instance that does not override them; treat them as read-only defaults.
    SUPPORTED_MODULES = []
    TUNABLES = {}

    def __init__(
        self,
        model: Module,
        forward_arg_extractor: Optional[ForwardArgumentExtractor] = None,
        additional_forward_arg_extractor: Optional[ForwardArgumentExtractor] = None,
        **kwargs,
    ) -> None:
        # `**kwargs` is intentionally ignored here; it only keeps the
        # signature tolerant of extra keyword arguments.
        self.model = model.eval()
        self.forward_arg_extractor = forward_arg_extractor
        self.additional_forward_arg_extractor = additional_forward_arg_extractor

    @property
    def device(self):
        """
        Device on which the wrapped model lives.

        The device of the first parameter is used. If the model has no
        parameters, the first buffer is used instead, and if it has neither,
        the CPU device is returned rather than leaking a bare ``StopIteration``.
        """
        first_tensor = next(self.model.parameters(), None)
        if first_tensor is None:
            first_tensor = next(self.model.buffers(), None)
        if first_tensor is None:
            # Local import: only needed on this rarely taken fallback path.
            import torch

            return torch.device("cpu")
        return first_tensor.device

    def __repr__(self):
        """Render ``ClassName(attr=value, ...)`` from the displayable, non-``None`` instance attributes."""
        kwargs_repr = ", ".join(
            f"{key}={value}"
            for key, value in self.__dict__.items()
            if key not in NON_DISPLAYED_ATTRS and value is not None
        )
        return f"{self.__class__.__name__}({kwargs_repr})"

    def _extract_forward_args(
        self, inputs: Union[Tensor, Tuple[Tensor]]
    ) -> Tuple[
        Union[Tensor, Tuple[Tensor]],
        Optional[Union[Tensor, Tuple[Tensor]]],
    ]:
        """
        Split ``inputs`` into forward arguments and additional forward arguments.

        Args:
            inputs (Union[Tensor, Tuple[Tensor]]): Raw inputs handed to the explainer.

        Returns:
            A ``(forward_args, additional_forward_args)`` pair. ``forward_args`` is
            ``inputs`` itself when no ``forward_arg_extractor`` is set, and
            ``additional_forward_args`` is ``None`` when no
            ``additional_forward_arg_extractor`` is set.
        """
        forward_args = (
            self.forward_arg_extractor(inputs) if self.forward_arg_extractor else inputs
        )
        additional_forward_args = (
            self.additional_forward_arg_extractor(inputs)
            if self.additional_forward_arg_extractor
            else None
        )
        return forward_args, additional_forward_args

    def copy(self):
        """Return a shallow copy of this explainer (attributes such as the model are shared, not duplicated)."""
        return copy.copy(self)

    def set_kwargs(self, **kwargs):
        """
        Return a shallow copy of this explainer with the given attributes set.

        The original instance is left untouched; each keyword becomes an
        attribute on the returned clone.
        """
        clone = self.copy()
        for k, v in kwargs.items():
            setattr(clone, k, v)
        return clone

    def _load_util_fn(
        self, util_attr: str, util_fn_class: Type[UtilFunction]
    ) -> Optional[Union[UtilFunction, Tuple[UtilFunction]]]:
        """
        Resolve the utility function(s) stored under the attribute ``util_attr``.

        Args:
            util_attr (str): Name of the instance attribute to read.
            util_fn_class (Type[UtilFunction]): Class whose ``from_method`` builds a
                function from a string entry.

        Returns:
            ``None`` when the attribute is ``None``; otherwise the attribute's
            entries with every string replaced by
            ``util_fn_class.from_method(method=...)``, passed through
            ``format_out_tuple_if_single``.

        Raises:
            AttributeError: If the instance has no attribute named ``util_attr``.
        """
        attr = getattr(self, util_attr)
        if attr is None:
            return None

        attr_values = tuple(
            util_fn_class.from_method(method=attr_value)
            if isinstance(attr_value, str)
            else attr_value
            for attr_value in format_into_tuple(attr)
        )
        return format_out_tuple_if_single(attr_values)

    def _get_baselines(self, forward_args) -> Optional[Union[Tensor, Tuple[Tensor]]]:
        """
        Compute one baseline per forward argument using ``self.baseline_fn``.

        Args:
            forward_args: A single forward argument or a tuple of them.

        Returns:
            ``None`` when no baseline function is configured; otherwise the
            baselines, passed through ``format_out_tuple_if_single``.

        Raises:
            AssertionError: If the number of baseline functions differs from the
                number of forward arguments.
        """
        baseline_fns = self._load_util_fn("baseline_fn", BaselineFunction)
        if baseline_fns is None:
            return None

        forward_args = format_into_tuple(forward_args)
        baseline_fns = format_into_tuple(baseline_fns)

        # Raised explicitly (not via `assert`) so the check survives `python -O`;
        # otherwise `zip` below would silently drop the unmatched entries.
        if len(forward_args) != len(baseline_fns):
            raise AssertionError(
                "Expected one baseline function per forward argument, got "
                f"{len(baseline_fns)} function(s) for {len(forward_args)} argument(s)."
            )
        baselines = tuple(
            baseline_fn(forward_arg)
            for baseline_fn, forward_arg in zip(baseline_fns, forward_args)
        )
        return format_out_tuple_if_single(baselines)

    def _get_feature_masks(self, forward_args) -> Optional[Union[Tensor, Tuple[Tensor]]]:
        """
        Compute one feature mask per forward argument using ``self.feature_mask_fn``.

        Every mask after the first is shifted by the previous mask's per-row
        maximum plus one (rows being the entries along the first dimension),
        presumably so that feature ids do not collide across forward arguments.

        Args:
            forward_args: A single forward argument or a tuple of them.

        Returns:
            ``None`` when no feature mask function is configured; otherwise the
            masks, passed through ``format_out_tuple_if_single``.

        Raises:
            AssertionError: If the number of feature mask functions differs from
                the number of forward arguments.
        """
        feature_mask_fns = self._load_util_fn("feature_mask_fn", FeatureMaskFunction)
        if feature_mask_fns is None:
            return None

        feature_mask_fns = format_into_tuple(feature_mask_fns)
        forward_args = format_into_tuple(forward_args)

        # Raised explicitly (not via `assert`) so the check survives `python -O`;
        # otherwise `zip` below would silently drop the unmatched entries.
        if len(forward_args) != len(feature_mask_fns):
            raise AssertionError(
                "Expected one feature mask function per forward argument, got "
                f"{len(feature_mask_fns)} function(s) for {len(forward_args)} argument(s)."
            )
        feature_masks = []
        max_vals = None
        for feature_mask_fn, forward_arg in zip(feature_mask_fns, forward_args):
            feature_mask = feature_mask_fn(forward_arg)
            if max_vals is not None:
                # Broadcast the per-row maxima over the trailing mask dimensions.
                offset = max_vals[(...,) + (None,) * (feature_mask.dim() - 1)] + 1
                # Clone before the in-place add so a tensor owned (and possibly
                # cached) by the mask function is never mutated.
                feature_mask = feature_mask.clone()
                feature_mask += offset
            feature_masks.append(feature_mask)

            # update max_vals
            # `reshape` (unlike `view`) also accepts non-contiguous masks.
            _, *size = feature_mask.size()
            max_vals = feature_mask.reshape(-1, math.prod(size)).max(dim=1).values
        feature_masks = tuple(feature_masks)
        return format_out_tuple_if_single(feature_masks)

    @abstractmethod
    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor]],
        targets: Tensor,
    ) -> Union[Tensor, Tuple[Tensor]]:
        """
        Computes attributions for the given inputs and targets.

        Args:
            inputs (Union[Tensor, Tuple[Tensor]]): The inputs for the model.
            targets (Tensor): The target labels.

        Returns:
            Union[Tensor, Tuple[Tensor]]: The computed attributions.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def get_tunables(self) -> Dict[str, Tuple[type, dict]]:
        """
        Returns a dictionary of tunable parameters for the explainer.

        The base implementation returns a new empty dictionary.

        Returns:
            Dict[str, Tuple[type, dict]]: Dictionary of tunable parameters.
        """
        return {}

attribute(inputs, targets) abstractmethod

Computes attributions for the given inputs and targets.

Parameters:

Name Type Description Default
inputs Union[Tensor, Tuple[Tensor]]

The inputs for the model.

required
targets Tensor

The target labels.

required

Returns:

Type Description
Union[Tensor, Tuple[Tensor]]

The computed attributions.

Source code in pnpxai/explainers/base.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@abstractmethod
def attribute(
    self,
    inputs: Union[Tensor, Tuple[Tensor]],
    targets: Tensor,
) -> Union[Tensor, Tuple[Tensor]]:
    """
    Produce attributions for ``inputs`` with respect to ``targets``.

    This is the abstract hook that concrete explainers override; the
    implementation here does no work and always raises.

    Args:
        inputs (Union[Tensor, Tuple[Tensor]]): The inputs for the model.
        targets (Tensor): The target labels.

    Returns:
        Union[Tensor, Tuple[Tensor]]: The computed attributions.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

get_tunables()

Returns a dictionary of tunable parameters for the explainer.

Returns:

Type Description
Dict[str, Tuple[type, dict]]

Dictionary of tunable parameters (empty in this base implementation).

Source code in pnpxai/explainers/base.py
170
171
172
173
174
175
176
177
def get_tunables(self) -> Dict[str, Tuple[type, dict]]:
    """
    Report the tunable parameters of this explainer.

    The implementation here has nothing to report, so it hands back a
    newly created empty dictionary on every call.

    Returns:
        Dict[str, Tuple[type, dict]]: Dictionary of tunable parameters.
    """
    tunables: Dict[str, Tuple[type, dict]] = dict()
    return tunables