
Base

pnpxai.explainers.base

ABC = abc.ABC if sys.version_info >= (3, 4) else abc.ABCMeta(str('ABC'), (), {}) module-attribute
NON_DISPLAYED_ATTRS = ['model', 'forward_arg_extractor', 'additional_forward_arg_extractor', 'device', 'n_classes', 'zennit_composite'] module-attribute
Explainer

Bases: ABC

Abstract base class for implementing attribution explanations for machine learning models.

This class provides methods for extracting forward arguments, loading baseline and feature mask functions, and applying them during attribution.

Parameters:

  • model (Module, required): The PyTorch model for which attribution is to be computed.
  • forward_arg_extractor (Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]], default None): Optional function to extract forward arguments from the inputs.
  • additional_forward_arg_extractor (Optional[Callable[[Tuple[Tensor]], Union[Tensor, Tuple[Tensor]]]], default None): Optional function to extract additional forward arguments.
  • **kwargs (default {}): Additional keyword arguments passed to the constructor.
Notes
  • Subclasses must implement the attribute method to define how attributions are computed.
  • The forward_arg_extractor and additional_forward_arg_extractor functions let you customize how forward arguments are extracted from the inputs.
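The contract in these notes can be illustrated without PyTorch. The following is a minimal, dependency-free sketch, not pnpxai's implementation: ToyExplainer, GradTimesInput, and linear_model are hypothetical stand-ins, inputs are plain Python lists instead of tensors, and attribute takes a single integer target rather than a targets tensor. It shows that the abstract attribute method must be overridden before a class can be instantiated, and how the two extractors might split an input tuple into the argument being explained and extra arguments passed through unchanged (that default split is an assumption). The concrete subclass uses gradient × input, with the gradient estimated by central finite differences in place of autograd.

```python
import abc
from typing import Callable, List, Optional, Tuple

Vector = List[float]


class ToyExplainer(abc.ABC):
    """Hypothetical stand-in that mirrors only the documented shape of Explainer."""

    def __init__(
        self,
        model: Callable[..., Vector],
        forward_arg_extractor: Optional[Callable[[tuple], Vector]] = None,
        additional_forward_arg_extractor: Optional[Callable[[tuple], tuple]] = None,
    ) -> None:
        self.model = model
        self.forward_arg_extractor = forward_arg_extractor
        self.additional_forward_arg_extractor = additional_forward_arg_extractor

    def split_inputs(self, inputs: tuple) -> Tuple[Vector, tuple]:
        # Assumption: by default the first input is explained and the rest pass through.
        if self.forward_arg_extractor is not None:
            x = self.forward_arg_extractor(inputs)
        else:
            x = inputs[0]
        if self.additional_forward_arg_extractor is not None:
            extra = self.additional_forward_arg_extractor(inputs)
        else:
            extra = tuple(inputs[1:])
        return x, extra

    @abc.abstractmethod
    def attribute(self, inputs: tuple, target: int) -> Vector:
        """Subclasses define how attributions are computed."""


class GradTimesInput(ToyExplainer):
    """Gradient x input, with the gradient estimated by central finite differences."""

    def attribute(self, inputs: tuple, target: int, eps: float = 1e-4) -> Vector:
        x, extra = self.split_inputs(inputs)
        scores = []
        for i in range(len(x)):
            up, down = list(x), list(x)
            up[i] += eps
            down[i] -= eps
            # d(logit[target]) / d(x[i]) via a central difference.
            grad_i = (self.model(up, *extra)[target] - self.model(down, *extra)[target]) / (2 * eps)
            scores.append(grad_i * x[i])
        return scores


def linear_model(x: Vector, scale: float) -> Vector:
    # Hypothetical two-class model: logits = scale * W @ x.
    weights = [[1.0, -2.0, 0.5], [0.0, 3.0, 1.0]]
    return [scale * sum(w * xi for w, xi in zip(row, x)) for row in weights]
```

For example, GradTimesInput(linear_model).attribute(([2.0, 1.0, -4.0], 0.5), target=1) explains the list while forwarding 0.5 as the scale argument. Trying to instantiate ToyExplainer directly raises TypeError, because attribute is still abstract.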
EXPLANATION_TYPE: ExplanationType = 'attribution' class-attribute instance-attribute
SUPPORTED_MODULES = [] class-attribute instance-attribute
TUNABLES = {} class-attribute instance-attribute
model = model.eval() instance-attribute
forward_arg_extractor = forward_arg_extractor instance-attribute
additional_forward_arg_extractor = additional_forward_arg_extractor instance-attribute
device property
__init__(model: Module, forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, additional_forward_arg_extractor: Optional[ForwardArgumentExtractor] = None, **kwargs) -> None
__repr__()
copy()
set_kwargs(**kwargs)
attribute(inputs: Union[Tensor, Tuple[Tensor]], targets: Tensor) -> Union[Tensor, Tuple[Tensor]] abstractmethod

Computes attributions for the given inputs and targets.

Parameters:

  • inputs (Union[Tensor, Tuple[Tensor]], required): The inputs to the model.
  • targets (Tensor, required): The target labels.

Returns:

  • Union[Tensor, Tuple[Tensor]]: The computed attributions.
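Both inputs and the return value are typed Union[Tensor, Tuple[Tensor]]. A common convention, assumed here and not confirmed by this page, is that attributions mirror the structure of the inputs: a single tensor in gives a single tensor out, and a tuple in gives a tuple out. The sketch below shows that normalization pattern with plain Python lists standing in for tensors; attribute_preserving_structure and per_input are hypothetical names.

```python
from typing import Callable, List, Tuple, Union

Vector = List[float]
Inputs = Union[Vector, Tuple[Vector, ...]]


def attribute_preserving_structure(
    inputs: Inputs, per_input: Callable[[Vector], Vector]
) -> Inputs:
    # Normalize a single input to a 1-tuple so the core loop handles one case.
    is_tuple = isinstance(inputs, tuple)
    as_tuple = inputs if is_tuple else (inputs,)
    attrs = tuple(per_input(x) for x in as_tuple)
    # Hand back the same structure the caller passed in.
    return attrs if is_tuple else attrs[0]
```

Normalizing once at the boundary keeps the attribution logic itself free of single-versus-tuple branching.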

get_tunables() -> Dict[str, Tuple[type, dict]]

Returns a dictionary of tunable parameters for the explainer.

Returns:

  • Dict[str, Tuple[type, dict]]: Dictionary of tunable parameters.
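This page documents only the shape of the return value, Dict[str, Tuple[type, dict]]. As a hedged illustration, assuming each entry maps a parameter name to its Python type plus a dict describing a search range, a tuner might sample candidate values as below. The parameter names (n_steps, noise_level), the "low"/"high" keys, and sample_params are all hypothetical.

```python
import random
from typing import Any, Dict, Tuple

# Hypothetical tunables in the documented Dict[str, Tuple[type, dict]] shape.
# The names and the "low"/"high" range convention are assumptions.
TUNABLES: Dict[str, Tuple[type, dict]] = {
    "n_steps": (int, {"low": 10, "high": 100}),
    "noise_level": (float, {"low": 0.05, "high": 0.3}),
}


def sample_params(
    tunables: Dict[str, Tuple[type, dict]], rng: random.Random
) -> Dict[str, Any]:
    """Draw one candidate value per tunable from its (assumed) range."""
    params: Dict[str, Any] = {}
    for name, (kind, space) in tunables.items():
        if kind is int:
            params[name] = rng.randint(space["low"], space["high"])  # inclusive bounds
        elif kind is float:
            params[name] = rng.uniform(space["low"], space["high"])
        else:
            raise TypeError(f"unsupported tunable type for {name!r}: {kind}")
    return params
```

A sampled dict could then be passed on as keyword arguments, for example to set_kwargs, though this page does not document what that method does with them.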