AttentionRollout

Bases: AttentionRolloutBase

Implementation of the AttentionRollout explainer.

Supported Modules: Attention

Parameters:

Name Type Description Default
model Module

The PyTorch model for which attribution is to be computed.

required
interpolate_mode str

The interpolation mode used to upsample the patch-level attribution map to the input resolution. Available modes are: "bilinear" and "bicubic".

'bilinear'
head_fusion_method Literal['min', 'max', 'mean']

Method used to fuse the attention heads of each layer. Available methods are: "min", "max", "mean".

'min'
discard_ratio float

Fraction of the lowest fused attention values to discard (set to zero) in each layer's attention map.

0.9
forward_arg_extractor Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]

Optional function to extract forward arguments from inputs.

None
additional_forward_arg_extractor Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]

Optional function to extract additional forward arguments.

None
n_classes Optional[int]

Number of classes.

None
Reference

Samira Abnar, Willem Zuidema. Quantifying Attention Flow in Transformers.
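The computation carried out by `rollout()` and `attribute()` in the source below can be summarized as follows. This is a reading of the code, not the paper's notation: $A^{(l)}$ is layer $l$'s attention after head fusion and discarding, $\delta_{ij}$ is the identity, $L$ is the number of layers, $N$ the number of patches, and the layers are assumed to be listed from first to last.

```latex
\begin{aligned}
\hat{A}^{(l)}_{ij} &= \frac{\tfrac{1}{2}A^{(l)}_{ij} + \tfrac{1}{2}\delta_{ij}}{\sum_k \left(\tfrac{1}{2}A^{(l)}_{ik} + \tfrac{1}{2}\delta_{ik}\right)} \\
\tilde{A} &= \hat{A}^{(1)}\,\hat{A}^{(2)}\cdots\hat{A}^{(L)} \\
\mathrm{attr}_j &= \tilde{A}_{0j}, \qquad j = 1, \dots, N
\end{aligned}
```

Row 0 corresponds to the CLS token; dropping column 0 leaves one score per patch, which is then reshaped to the patch grid and upsampled.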

Source code in pnpxai/explainers/attention_rollout.py
class AttentionRollout(AttentionRolloutBase):
    """
    Implementation of the `AttentionRollout` explainer.

    Supported Modules: `Attention`

    Parameters:
        model (Module): The PyTorch model for which attribution is to be computed.
        interpolate_mode (str): The interpolation mode used to upsample the patch-level attribution map to the input resolution. Available modes are: "bilinear" and "bicubic".
        head_fusion_method (str): Method used to fuse the attention heads of each layer. Available methods are: `"min"`, `"max"`, `"mean"`.
        discard_ratio (float): Fraction of the lowest fused attention values to discard (set to zero) in each layer's attention map.
        forward_arg_extractor (Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]): Optional function to extract forward arguments from inputs.
        additional_forward_arg_extractor (Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]): Optional function to extract additional forward arguments.
        n_classes (Optional[int]): Number of classes.

    Reference:
        Samira Abnar, Willem Zuidema. Quantifying Attention Flow in Transformers.
    """
    def __init__(
        self,
        model: Module,
        interpolate_mode: Literal['bilinear', 'bicubic']='bilinear',
        head_fusion_method: Literal['min', 'max', 'mean']='min',
        discard_ratio: float=0.9,
        forward_arg_extractor: Optional[ForwardArgumentExtractor]=None,
        additional_forward_arg_extractor: Optional[ForwardArgumentExtractor]=None,
        n_classes: Optional[int]=None,
    ) -> None:
        super().__init__(
            model,
            interpolate_mode,
            head_fusion_method,
            discard_ratio,
            forward_arg_extractor,
            additional_forward_arg_extractor,
            n_classes
        )

    def collect_attention_map(self, inputs, targets):
        # get all attn maps
        with SavingAttentionAttributor(model=self.model) as attributor:
            weights_all = attributor(inputs, None)
        return (weights_all,)

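    def rollout(self, weights_all):
        # every layer must contribute attention maps of the same shape
        sz = weights_all[0].size()
        assert all(
            attn_weights.size() == sz
            for attn_weights in weights_all
        ), "All attention maps must have the same shape"
        bsz, num_heads, tgt_len, src_len = sz
        rollout = torch.eye(tgt_len).repeat(bsz, 1, 1).to(self.device)
        for attn_weights in weights_all:
            # fuse heads: (bsz, num_heads, tgt_len, src_len) -> (bsz, tgt_len, src_len)
            attn_map = self.head_fusion_function(attn_weights)
            # zero out the lowest `discard_ratio` fraction of attention values
            attn_map = self._discard(attn_map)
            # account for residual connections, then re-normalize each row
            identity = torch.eye(tgt_len).repeat(bsz, 1, 1).to(self.device)
            attn_map = .5 * attn_map + .5 * identity
            attn_map /= attn_map.sum(dim=-1, keepdim=True)
            # accumulate the attention flow across layers
            rollout = torch.matmul(rollout, attn_map)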
        return rollout
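The rollout loop above can be reproduced without pnpxai as a small standalone sketch. It omits the discard step, assumes the attention tensors have shape `(batch, heads, tokens, tokens)` and are ordered from the first to the last layer, and assumes head fusion reduces over dimension 1. The function name `attention_rollout` is hypothetical, not part of the library.

```python
import torch


def attention_rollout(attn_layers, head_fusion="min"):
    """Minimal attention-rollout sketch (no discarding step).

    attn_layers: list of tensors of shape (batch, heads, tokens, tokens),
    assumed to be ordered from the first to the last layer.
    """
    bsz, _, n_tokens, _ = attn_layers[0].shape
    eye = torch.eye(n_tokens).expand(bsz, n_tokens, n_tokens)
    rollout = eye.clone()
    for attn in attn_layers:
        # Fuse heads along dim=1 (assumed to mirror the library's head fusion).
        if head_fusion == "min":
            fused = attn.min(dim=1).values
        elif head_fusion == "max":
            fused = attn.max(dim=1).values
        else:
            fused = attn.mean(dim=1)
        # Account for residual connections, then re-normalize each row.
        fused = 0.5 * fused + 0.5 * eye
        fused = fused / fused.sum(dim=-1, keepdim=True)
        # Accumulate the attention flow across layers.
        rollout = rollout @ fused
    return rollout


torch.manual_seed(0)
layers = [torch.softmax(torch.randn(2, 4, 5, 5), dim=-1) for _ in range(3)]
result = attention_rollout(layers, head_fusion="mean")
cls_to_patches = result[:, 0, 1:]  # one score per patch token
```

Because every normalized layer map is row-stochastic, their product is too, so each row of the rollout sums to 1.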

AttentionRolloutBase

Bases: ZennitExplainer

Base class for AttentionRollout and TransformerAttribution explainers.

Supported Modules: Attention

Parameters:

Name Type Description Default
model Module

The PyTorch model for which attribution is to be computed.

required
interpolate_mode str

The interpolation mode used to upsample the patch-level attribution map to the input resolution. Available modes are: "bilinear" and "bicubic".

'bilinear'
head_fusion_method Literal['min', 'max', 'mean']

Method used to fuse the attention heads of each layer. Available methods are: "min", "max", "mean".

'min'
discard_ratio float

Fraction of the lowest fused attention values to discard (set to zero) in each layer's attention map.

0.9
forward_arg_extractor Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]

Optional function that extracts the forward arguments, i.e. the input batch(es) to which attribution scores are assigned.

None
additional_forward_arg_extractor Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]

Optional function that extracts additional forward arguments from the input batch(es).

None
n_classes Optional[int]

Number of classes.

None
Reference

Samira Abnar, Willem Zuidema. Quantifying Attention Flow in Transformers.
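The `head_fusion_method` options reduce each layer's per-head attention to a single map. The helper that implements this (`_get_rollout_head_fusion_function`) is not shown on this page, so the sketch below uses a hypothetical stand-in, `get_head_fusion_function`. It assumes the reduction runs over the head dimension (dim 1) of a `(batch, heads, tgt, src)` tensor, which is what `rollout()` implies when it adds a `(batch, tgt, tgt)` identity to the fused map.

```python
from typing import Callable

import torch


def get_head_fusion_function(method: str) -> Callable[[torch.Tensor], torch.Tensor]:
    """Hypothetical stand-in for the library's head-fusion helper.

    Each returned function maps (batch, heads, tgt, src) -> (batch, tgt, src).
    """
    fusions = {
        "min": lambda attn: attn.min(dim=1).values,
        "max": lambda attn: attn.max(dim=1).values,
        "mean": lambda attn: attn.mean(dim=1),
    }
    if method not in fusions:
        raise ValueError(f"Unknown head fusion method: {method!r}")
    return fusions[method]


attn = torch.tensor([[[[0.2, 0.8]], [[0.6, 0.4]]]])  # (batch=1, heads=2, tgt=1, src=2)
fused_min = get_head_fusion_function("min")(attn)    # element-wise min over the two heads
fused_mean = get_head_fusion_function("mean")(attn)  # element-wise mean over the two heads
```

"min" keeps only attention that every head agrees on, "max" keeps the strongest head per entry, and "mean" averages them.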

Source code in pnpxai/explainers/attention_rollout.py
class AttentionRolloutBase(ZennitExplainer):
    """
    Base class for `AttentionRollout` and `TransformerAttribution` explainers.

    Supported Modules: `Attention`

    Parameters:
        model (Module): The PyTorch model for which attribution is to be computed.
        interpolate_mode (str): The interpolation mode used to upsample the patch-level attribution map to the input resolution. Available modes are: "bilinear" and "bicubic".
        head_fusion_method (str): Method used to fuse the attention heads of each layer. Available methods are: `"min"`, `"max"`, `"mean"`.
        discard_ratio (float): Fraction of the lowest fused attention values to discard (set to zero) in each layer's attention map.
        forward_arg_extractor (Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]): Optional function that extracts the forward arguments, i.e. the input batch(es) to which attribution scores are assigned.
        additional_forward_arg_extractor (Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]): Optional function that extracts additional forward arguments from the input batch(es).
        n_classes (Optional[int]): Number of classes.

    Reference:
        Samira Abnar, Willem Zuidema. Quantifying Attention Flow in Transformers.
    """

    SUPPORTED_MODULES = [Attention]

    def __init__(
        self,
        model: Module,
        interpolate_mode: Literal['bilinear', 'bicubic']='bilinear',
        head_fusion_method: Literal['min', 'max', 'mean']='min',
        discard_ratio: float=0.9,
        forward_arg_extractor: Optional[ForwardArgumentExtractor]=None,
        additional_forward_arg_extractor: Optional[ForwardArgumentExtractor]=None,
        n_classes: Optional[int]=None,
    ) -> None:
        super().__init__(
            model,
            forward_arg_extractor,
            additional_forward_arg_extractor,
            n_classes,
        )
        self.interpolate_mode = interpolate_mode
        self.head_fusion_method = head_fusion_method
        self.discard_ratio = discard_ratio

    @property
    def head_fusion_function(self):
        return _get_rollout_head_fusion_function(self.head_fusion_method)

    @abstractmethod
    def collect_attention_map(self, inputs, targets):
        raise NotImplementedError

    @abstractmethod
    def rollout(self, *args):
        raise NotImplementedError

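    def _discard(self, fused_attn_map):
        # Zero out the lowest `discard_ratio` fraction of entries in each sample's map.
        # flatten(1) returns a view for contiguous inputs, so fused_attn_map is modified in place.
        org_size = fused_attn_map.size() # keep size to recover it after discard
        flattened = fused_attn_map.flatten(1)
        bsz, n_entries = flattened.size()
        # keep the CLS-to-CLS score (flat index 0) to recover it after discard;
        # clone() is required because flattened[:, 0] is a view that the discard step would also zero
        attn_cls = flattened[:, 0].clone()
        _, indices = flattened.topk(
            k=int(n_entries*self.discard_ratio),
            dim=-1,
            largest=False,
        )
        flattened[torch.arange(bsz)[:, None], indices] = 0. # discard
        flattened[:, 0] = attn_cls # recover the CLS-to-CLS score
        discarded = flattened.view(*org_size)
        return discarded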

    def attribute(
        self,
        inputs: Union[Tensor, Tuple[Tensor]],
        targets: Tensor
    ) -> Union[Tensor, Tuple[Tensor]]:
        """
        Computes attributions for the given inputs and targets.

        Args:
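            inputs (torch.Tensor): Batch of input images of shape `(batch_size, channels, height, width)`.
            targets (torch.Tensor): The target labels for the inputs.

        Returns:
            torch.Tensor: Attribution map of shape `(batch_size, 1, height, width)`, upsampled to the input resolution.
        """

        attn_maps = self.collect_attention_map(inputs, targets)
        with torch.no_grad():
            rollout = self.rollout(*attn_maps)

        # attention flow from the CLS token (index 0) to every patch token
        attrs = rollout[:, 0, 1:]
        n_patches = attrs.size(-1)
        bsz, _, h, w = inputs.size()
        # patch grid with the input's aspect ratio: p_h * p_w == n_patches and p_h / p_w == h / w
        p_h = round((n_patches * h / w) ** .5)
        p_w = n_patches // p_h
        attrs = attrs.view(bsz, 1, p_h, p_w)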

        # upsampling
        attrs = LayerAttribution.interpolate(
            layer_attribution=attrs,
            interpolate_dims=(h, w),
            interpolate_mode=self.interpolate_mode,
        )
        return attrs

    def get_tunables(self):
        """
        Provides tunable parameters for the hyperparameter optimizer.

        Tunable parameters:
            `interpolate_mode` (str): One of `"bilinear"` or `"bicubic"`

            `head_fusion_method` (str): One of `"min"`, `"max"`, or `"mean"`

            `discard_ratio` (float): A value from 0.0 to 0.95 in steps of 0.05
        """
        return {
            'interpolate_mode': (list, {'choices': ['bilinear', 'bicubic']}),
            'head_fusion_method': (list, {'choices': ['min', 'max', 'mean']}),
            'discard_ratio': (float, {'low': 0., 'high': .95, 'step': .05}),
        }
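The discard step in `_discard` can be checked in isolation with the standalone sketch below. `discard_lowest` is a hypothetical name; unlike the method above, it works on a copy so the caller's tensor is left unchanged, and it clones the CLS-to-CLS entry before zeroing so that the restore step actually restores the original value.

```python
import torch


def discard_lowest(fused_attn_map: torch.Tensor, discard_ratio: float) -> torch.Tensor:
    """Zero the lowest `discard_ratio` fraction of entries per sample.

    Follows the logic of `_discard`, but operates on a copy of the input.
    """
    org_size = fused_attn_map.size()
    flattened = fused_attn_map.flatten(1).clone()
    bsz, n_entries = flattened.size()
    cls_to_cls = flattened[:, 0].clone()  # flat index 0 is the CLS->CLS score
    k = int(n_entries * discard_ratio)
    _, indices = flattened.topk(k=k, dim=-1, largest=False)
    flattened[torch.arange(bsz)[:, None], indices] = 0.0
    flattened[:, 0] = cls_to_cls  # restore CLS->CLS even if it was discarded
    return flattened.view(*org_size)


attn = torch.tensor([[[0.05, 0.50], [0.30, 0.15]]])  # (batch=1, tgt=2, src=2)
out = discard_lowest(attn, discard_ratio=0.5)
# k = int(4 * 0.5) = 2: the two lowest entries (0.05 and 0.15) are zeroed,
# then flat index 0 (the CLS->CLS score, 0.05) is restored.
```

Note that the ratio is applied over the whole flattened `(tgt, src)` map of each sample, not row by row.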

attribute(inputs, targets)

Computes attributions for the given inputs and targets.

Parameters:

Name Type Description Default
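inputs Tensor

Batch of input images of shape (batch_size, channels, height, width).

required
targets Tensor

The target labels for the inputs.

required

Returns:

Type Description
Union[Tensor, Tuple[Tensor]]

torch.Tensor: Attribution map of shape (batch_size, 1, height, width), upsampled to the input resolution.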

Source code in pnpxai/explainers/attention_rollout.py
def attribute(
    self,
    inputs: Union[Tensor, Tuple[Tensor]],
    targets: Tensor
) -> Union[Tensor, Tuple[Tensor]]:
    """
    Computes attributions for the given inputs and targets.

    Args:
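        inputs (torch.Tensor): Batch of input images of shape `(batch_size, channels, height, width)`.
        targets (torch.Tensor): The target labels for the inputs.

    Returns:
        torch.Tensor: Attribution map of shape `(batch_size, 1, height, width)`, upsampled to the input resolution.
    """

    attn_maps = self.collect_attention_map(inputs, targets)
    with torch.no_grad():
        rollout = self.rollout(*attn_maps)

    # attention flow from the CLS token (index 0) to every patch token
    attrs = rollout[:, 0, 1:]
    n_patches = attrs.size(-1)
    bsz, _, h, w = inputs.size()
    # patch grid with the input's aspect ratio: p_h * p_w == n_patches and p_h / p_w == h / w
    p_h = round((n_patches * h / w) ** .5)
    p_w = n_patches // p_h
    attrs = attrs.view(bsz, 1, p_h, p_w)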

    # upsampling
    attrs = LayerAttribution.interpolate(
        layer_attribution=attrs,
        interpolate_dims=(h, w),
        interpolate_mode=self.interpolate_mode,
    )
    return attrs
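The reshape-and-upsample step at the end of `attribute` can be sketched without captum. This version calls `torch.nn.functional.interpolate` directly in place of `LayerAttribution.interpolate` (a substitution, not the library's code path), assumes token 0 is the CLS token, and picks a patch grid whose aspect ratio matches the input. The helper name `patch_scores_to_image` is hypothetical.

```python
import torch
import torch.nn.functional as F


def patch_scores_to_image(rollout: torch.Tensor, height: int, width: int,
                          mode: str = "bilinear") -> torch.Tensor:
    """Turn CLS->patch rollout scores into an image-sized attribution map.

    rollout: (batch, tokens, tokens), with token 0 assumed to be CLS.
    """
    attrs = rollout[:, 0, 1:]  # CLS row, CLS column dropped
    bsz, n_patches = attrs.shape
    # Grid with the image's aspect ratio: p_h * p_w == n_patches, p_h / p_w == height / width.
    p_h = round((n_patches * height / width) ** 0.5)
    p_w = n_patches // p_h
    attrs = attrs.reshape(bsz, 1, p_h, p_w)
    return F.interpolate(attrs, size=(height, width), mode=mode, align_corners=False)


square = patch_scores_to_image(torch.rand(2, 1 + 14 * 14, 1 + 14 * 14), 224, 224)
tall = patch_scores_to_image(torch.rand(1, 1 + 4 * 2, 1 + 4 * 2), 64, 32, mode="bicubic")
```

For a 64x32 input with 8 patches, the grid is 4x2; a formula that scales the square-root side length by `h / w` instead would pick 5 rows and fail to reshape.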

get_tunables()

Provides tunable parameters for the hyperparameter optimizer.

Tunable parameters

interpolate_mode (str): One of "bilinear" or "bicubic"

head_fusion_method (str): One of "min", "max", or "mean"

discard_ratio (float): A value from 0.0 to 0.95 in steps of 0.05

Source code in pnpxai/explainers/attention_rollout.py
def get_tunables(self):
    """
    Provides tunable parameters for the hyperparameter optimizer.

    Tunable parameters:
        `interpolate_mode` (str): One of `"bilinear"` or `"bicubic"`

        `head_fusion_method` (str): One of `"min"`, `"max"`, or `"mean"`

        `discard_ratio` (float): A value from 0.0 to 0.95 in steps of 0.05
    """
    return {
        'interpolate_mode': (list, {'choices': ['bilinear', 'bicubic']}),
        'head_fusion_method': (list, {'choices': ['min', 'max', 'mean']}),
        'discard_ratio': (float, {'low': 0., 'high': .95, 'step': .05}),
    }
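To get a feel for the size of the search space returned by `get_tunables`, the sketch below expands it into an exhaustive grid. It does not call pnpxai or its optimizer; `expand_spec` is a hypothetical helper, and treating `high` as inclusive is an assumption about how the optimizer reads the `(float, {...})` spec.

```python
import itertools

# Search space copied from get_tunables() above.
tunables = {
    'interpolate_mode': (list, {'choices': ['bilinear', 'bicubic']}),
    'head_fusion_method': (list, {'choices': ['min', 'max', 'mean']}),
    'discard_ratio': (float, {'low': 0., 'high': .95, 'step': .05}),
}


def expand_spec(spec):
    """Hypothetical helper: list every value a tunable spec can take."""
    kind, params = spec
    if kind is list:
        return list(params['choices'])
    n_steps = round((params['high'] - params['low']) / params['step'])
    # Round to avoid float drift such as 0.15000000000000002.
    return [round(params['low'] + i * params['step'], 10) for i in range(n_steps + 1)]


discard_values = expand_spec(tunables['discard_ratio'])
grid = list(itertools.product(*(expand_spec(spec) for spec in tunables.values())))
```

Under these assumptions the space has 2 x 3 x 20 = 120 configurations, small enough that an exhaustive search is feasible for a single model.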

TransformerAttribution

Bases: AttentionRolloutBase

Implementation of the TransformerAttribution explainer.

Supported Modules: Attention

Parameters:

Name Type Description Default
model Module

The PyTorch model for which attribution is to be computed.

required
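interpolate_mode str

The interpolation mode used to upsample the patch-level attribution map to the input resolution. Available modes are: "bilinear" and "bicubic".

'bilinear'
head_fusion_method Literal['min', 'max', 'mean']

Method used to fuse the attention heads of each layer. Available methods are: "min", "max", "mean".

'mean'
discard_ratio float

Fraction of the lowest fused attention values to discard (set to zero) in each layer's attention map.

0.9
alpha float

The alpha parameter of the alpha-beta LRP rule (weight of positive contributions), passed to AlphaBeta for Linear layers and to CGWAttentionPropagation for MultiheadAttention layers.

2.0
beta float

The beta parameter of the alpha-beta LRP rule (weight of negative contributions), passed to the same rules as alpha.

1.0
stabilizer float

Numerical stabilizer passed to the LRP rules.

1e-6
zennit_canonizers Optional[List[Canonizer]]

Additional zennit canonizers, appended to the default attention converters.

None
layer Optional[Union[Module, Sequence[Module]]]

Optional layer(s) passed to the underlying LayerGradient attributor. If None, a plain Gradient attributor is used.

None
forward_arg_extractor Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]

Optional function to extract forward arguments from inputs.

None
additional_forward_arg_extractor Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]

Optional function to extract additional forward arguments.

None
n_classes Optional[int]

Number of classes.

None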
Reference

Chefer H., Gur S., and Wolf L. Transformer Interpretability Beyond Attention Visualization.
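`TransformerAttribution` applies zennit's `AlphaBeta` rule (configured by `alpha`, `beta`, and `stabilizer`) to `Linear` layers. The sketch below is a from-scratch, bias-free simplification of the alpha-beta rule for one linear layer, not zennit's implementation; zennit's handling of biases and stabilization may differ. `alpha_beta_linear` is a hypothetical name.

```python
import torch


def alpha_beta_linear(x, weight, relevance_out, alpha=2.0, beta=1.0, stabilizer=1e-6):
    """LRP alpha-beta rule for a bias-free linear layer (illustrative only).

    x: (in_features,), weight: (out_features, in_features),
    relevance_out: (out_features,). Returns relevance per input feature.
    """
    z = weight * x                # contributions z[j, i] = w[j, i] * x[i]
    z_pos = z.clamp(min=0)
    z_neg = z.clamp(max=0)
    # Normalize positive and negative contributions separately for each output j.
    pos_share = z_pos / (z_pos.sum(dim=1, keepdim=True) + stabilizer)
    neg_share = z_neg / (z_neg.sum(dim=1, keepdim=True) - stabilizer)
    # Redistribute each output's relevance to the inputs.
    return ((alpha * pos_share - beta * neg_share) * relevance_out[:, None]).sum(dim=0)


x = torch.tensor([1.0, -1.0, 2.0])
weight = torch.tensor([[1.0, 1.0, 0.5], [-1.0, 2.0, 1.0]])
r_out = torch.tensor([1.0, 0.5])
r_in = alpha_beta_linear(x, weight, r_out)
```

With the defaults `alpha=2` and `beta=1`, `alpha - beta = 1`, so relevance is (approximately) conserved whenever each output has both positive and negative contributions, as in this example.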

Source code in pnpxai/explainers/attention_rollout.py
class TransformerAttribution(AttentionRolloutBase):
    """
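    Implementation of the `TransformerAttribution` explainer.

    Supported Modules: `Attention`

    Parameters:
        model (Module): The PyTorch model for which attribution is to be computed.
        interpolate_mode (str): The interpolation mode used to upsample the patch-level attribution map to the input resolution. Available modes are: "bilinear" and "bicubic".
        head_fusion_method (str): Method used to fuse the attention heads of each layer. Available methods are: `"min"`, `"max"`, `"mean"`.
        discard_ratio (float): Fraction of the lowest fused attention values to discard (set to zero) in each layer's attention map.
        alpha (float): The alpha parameter of the alpha-beta LRP rule (weight of positive contributions), passed to `AlphaBeta` for `Linear` layers and to `CGWAttentionPropagation` for `MultiheadAttention` layers.
        beta (float): The beta parameter of the alpha-beta LRP rule (weight of negative contributions), passed to the same rules as `alpha`.
        stabilizer (float): Numerical stabilizer passed to the LRP rules.
        zennit_canonizers (Optional[List[Canonizer]]): Additional zennit canonizers, appended to the default attention converters.
        layer (Optional[Union[Module, Sequence[Module]]]): Optional layer(s) passed to the underlying `LayerGradient` attributor. If `None`, a plain `Gradient` attributor is used.
        forward_arg_extractor (Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]): Optional function to extract forward arguments from inputs.
        additional_forward_arg_extractor (Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]]): Optional function to extract additional forward arguments.
        n_classes (Optional[int]): Number of classes.

    Reference:
        Chefer H., Gur S., and Wolf L. Transformer Interpretability Beyond Attention Visualization.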
    """

    SUPPORTED_MODULES = [Attention]

    def __init__(
        self,
        model: Module,
        interpolate_mode: Literal['bilinear', 'bicubic']='bilinear',
        head_fusion_method: Literal['min', 'max', 'mean']='mean',
        discard_ratio: float=0.9,
        alpha: float=2.,
        beta: float=1.,
        stabilizer: float=1e-6,
        zennit_canonizers: Optional[List[Canonizer]]=None,
        layer: Optional[Union[Module, Sequence[Module]]]=None,
        forward_arg_extractor: Optional[ForwardArgumentExtractor]=None,
        additional_forward_arg_extractor: Optional[ForwardArgumentExtractor]=None,
        n_classes: Optional[int]=None
    ) -> None:
        super().__init__(
            model,
            interpolate_mode,
            head_fusion_method,
            discard_ratio,
            forward_arg_extractor,
            additional_forward_arg_extractor,
            n_classes
        )
        self.alpha = alpha
        self.beta = beta
        self.stabilizer = stabilizer
        self.zennit_canonizers = zennit_canonizers or []
        self.layer = layer

    @staticmethod
    def default_head_fusion_fn(attns):
        return attns.mean(dim=1)

    @property
    def zennit_composite(self):
        layer_map = [
            (MultiheadAttention, CGWAttentionPropagation(
                alpha=self.alpha,
                beta=self.beta,
                stabilizer=self.stabilizer,
                save_attn_output_weights=False,
            )),
            (Linear, AlphaBeta(
                alpha=self.alpha,
                beta=self.beta,
                stabilizer=self.stabilizer,
            )),
        ] + layer_map_base(stabilizer=self.stabilizer)
        canonizers = default_attention_converters + self.zennit_canonizers
        return LayerMapComposite(layer_map=layer_map, canonizers=canonizers)

    @property
    def _layer_gradient(self) -> LayerGradient:
        wrapped_model = captum_wrap_model_input(self.model)
        layers = [
            wrapped_model.input_maps[layer] if isinstance(layer, str)
            else layer for layer in self.layer
        ] if isinstance(self.layer, Sequence) else self.layer
        return LayerGradient(
            model=wrapped_model,
            layer=layers,
            composite=self.zennit_composite,
        )

    @property
    def _gradient(self) -> Gradient:
        return Gradient(
            model=self.model,
            composite=self.zennit_composite,
        )

    @property
    def attributor(self):
        if self.layer is None:
            return self._gradient
        return self._layer_gradient

    def collect_attention_map(self, inputs, targets):
        forward_args, additional_forward_args = self._extract_forward_args(inputs)
        with self.attributor as attributor:
            attributor.forward(
                forward_args=forward_args,
                targets=targets,
                additional_forward_args=additional_forward_args,
            )
            grads, rels = [], []
            for hook_ref in attributor.composite.hook_refs:
                if isinstance(hook_ref, CGWAttentionPropagation):
                    grads.append(hook_ref.stored_tensors["attn_grads"])
                    rels.append(hook_ref.stored_tensors["attn_rels"])
        return grads, rels

    def rollout(self, grads, rels):
        bsz, num_heads, tgt_len, src_len = grads[0].shape
        assert tgt_len == src_len, "Must be self-attention"
        rollout = torch.eye(tgt_len).repeat(bsz, 1, 1).to(self.device)
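        for grad, rel in zip(grads, rels):
            # weight each layer's attention relevance by its gradient
            grad_x_rel = grad * rel
            # fuse heads, then zero out the lowest `discard_ratio` fraction of values
            attn_map = self.head_fusion_function(grad_x_rel)
            attn_map = self._discard(attn_map)
            # account for residual connections, then re-normalize each row
            identity = torch.eye(tgt_len).repeat(bsz, 1, 1).to(self.device)
            attn_map = .5 * attn_map + .5 * identity
            attn_map /= attn_map.sum(dim=-1, keepdim=True)
            # accumulate the attention flow across layers
            rollout = torch.matmul(rollout, attn_map)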
        return rollout
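In `TransformerAttribution`, the attention gradients and relevances come from `CGWAttentionPropagation` hooks. The self-contained sketch below approximates one layer of that computation: it obtains the attention gradient with plain autograd on a toy attention layer and, lacking an LRP pass, uses the attention weights themselves as a stand-in for relevance (a "gradient times attention" simplification). Following the formulation of Chefer et al., it keeps only positive values before averaging over heads; `rollout` above instead applies the configured `head_fusion_function` without clamping. `grad_times_rel_map` and the toy target score are hypothetical.

```python
import torch


def grad_times_rel_map(grad: torch.Tensor, rel: torch.Tensor) -> torch.Tensor:
    """Fuse one layer's (batch, heads, tgt, src) gradient and relevance into (batch, tgt, src).

    Keeps positive values only, then averages over heads.
    """
    return (grad * rel).clamp(min=0).mean(dim=1)


torch.manual_seed(0)
bsz, heads, tokens, dim = 1, 2, 4, 8
q, k, v = (torch.randn(bsz, heads, tokens, dim) for _ in range(3))
# Treat the attention weights as a leaf tensor so autograd stores their gradient.
attn = torch.softmax(q @ k.transpose(-1, -2) / dim ** 0.5, dim=-1).requires_grad_()
out = attn @ v
score = out[:, :, 0].sum()  # hypothetical scalar target read from token 0 (CLS)
score.backward()

# Stand-in relevance: the attention weights themselves (NOT an LRP relevance).
layer_map = grad_times_rel_map(attn.grad, attn.detach())
```

Because the toy target only reads token 0, only row 0 of the attention receives a non-zero gradient; in a real model the per-layer maps produced this way would then be fed through the same residual-mixing and matrix-product loop as `rollout`.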