PnPXAI: Plug-and-Play Explainable AI


PnPXAI is a Python package that provides a modular and easy-to-use framework for explainable artificial intelligence (XAI). It allows users to apply various XAI methods to their own models and datasets, and visualize the results in an interactive and intuitive way.

Features

  • Detector: The detector module provides automatic detection of AI models implemented in PyTorch.
  • Evaluator: The evaluator module provides various ways to evaluate and compare the performance and explainability of AI models with the categorized evaluation properties of correctness (fidelity, area between perturbation curves), continuity (sensitivity), and compactness (complexity).
  • Explainers: The explainers module contains a collection of state-of-the-art XAI methods that can generate global or local explanations for any AI model.
  • Recommender: The recommender module offers a recommender system that can suggest the most suitable XAI methods for a given model and dataset, based on the user’s preferences and goals.
  • Optimizer: The optimizer module finds the best hyperparameter options for a user-specified metric.
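
To make the compactness (complexity) property from the Evaluator entry concrete, here is a minimal sketch of one common formulation: the Shannon entropy of the normalized absolute attributions. It is an illustration only and not necessarily how pnpxai computes its metric; the function name `complexity` and the toy attribution vectors are assumptions.

```python
import math

def complexity(attributions):
    """Shannon entropy of the normalized absolute attributions.

    Lower values mean the explanation concentrates on fewer features.
    """
    magnitudes = [abs(a) for a in attributions]
    total = sum(magnitudes)
    if total == 0:
        return 0.0  # no attribution mass at all: treat as maximally compact
    probs = [m / total for m in magnitudes]
    return -sum(p * math.log(p) for p in probs if p > 0)

# Hypothetical attribution vectors for a four-feature input.
sparse = complexity([0.0, 0.0, 5.0, 0.0])     # all mass on one feature
diffuse = complexity([1.0, -1.0, 1.0, -1.0])  # mass spread evenly
```

Under this formulation an explanation that puts all of its weight on one feature scores lower (more compact) than one that spreads weight evenly across all features.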

Project Core API

  • Experiment: the module responsible for data manipulation, model explanation, and evaluation of explanations
  • Auto Explanation: the module responsible for data manipulation, model explanation, and evaluation of explanations

Installation

To install pnpxai, run the following command from the root of a local copy of the repository:

# Command lines for installation
pip install -e .

Getting Started

Auto Explanation

To use pnpxai, import the package and its modules in your Python script. Before running an explanation, you need to set up a model, test data, and an instance of pnpxai.AutoExplanationForImageClassification.

The explanation class can be chosen according to the modality that best fits the user's task. Specifically, pnpxai offers AutoExplanationForImageClassification, AutoExplanationForTextClassification, AutoExplanationForTSClassification, and AutoExplanationForVisualQuestionAnswering, which cover image, text, time series, and combined image and text modalities, respectively.

import torch
from torch.utils.data import DataLoader

from pnpxai import AutoExplanationForImageClassification

# Bring your model
model = ...

# Prepare your data
dataset = ...
loader = DataLoader(dataset, batch_size=...)

In addition to the regular experiment setup, the library requires an input_extractor, a label_extractor, and a target_extractor. The first two pull the inputs and the labels out of each batch of test data, and the third converts the model outputs into the targets to explain. The example below shows a naive implementation, which assumes that every iteration of loader returns a tuple of input and label.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def input_extractor(x): return x[0].to(device)
def label_extractor(x): return x[1].to(device)
def target_extractor(outputs): return outputs.argmax(-1).to(device)
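
Because the extractors are plain callables, they can adapt to other batch layouts as well. The sketch below is framework-free and only illustrates the roles of the three functions: it assumes a loader that yields dicts with the hypothetical keys "pixels" and "label", uses Python lists in place of tensors, and uses `toy_model` as a stand-in for a real network. It does not reproduce pnpxai internals.

```python
def input_extractor(batch):
    # Pick the model inputs out of a dict-style batch.
    return batch["pixels"]

def label_extractor(batch):
    # Pick the ground-truth labels out of the same batch.
    return batch["label"]

def target_extractor(outputs):
    # Reduce per-class scores to the predicted class index (argmax per sample).
    return [max(range(len(scores)), key=scores.__getitem__) for scores in outputs]

def toy_model(inputs):
    # Hypothetical stand-in for a classifier: two class scores per sample.
    return [[sum(x), -sum(x)] for x in inputs]

# One hypothetical batch with two samples.
batch = {"pixels": [[0.2, 0.5], [-1.0, 0.1]], "label": [0, 1]}
inputs = input_extractor(batch)
labels = label_extractor(batch)
targets = target_extractor(toy_model(inputs))
```

With real data, the same three functions would return tensors moved to the right device, as in the tuple-based example above.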

The final setup step is to initialize AutoExplanationForImageClassification with the values above and start the optimization process. To start the optimization, choose the desired data ids, as well as an explainer and a metric from the lists of suggested ones.

experiment = AutoExplanationForImageClassification(
    model,
    loader,
    input_extractor=input_extractor,
    label_extractor=label_extractor,
    target_extractor=target_extractor,
    target_labels=False,
)
optimized = experiment.optimize(
    data_ids=range(16),
    explainer_id=2,
    metric_id=1,
    direction='maximize', # 'maximize' if a higher metric value is better, 'minimize' if a lower one is
    sampler='tpe', # Literal['tpe','random']
    n_trials=50, # by default, 50 for sampler in ['random', 'tpe'], None for ['grid']
    seed=42, # seed for sampler: by default, None
)
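
To show what the direction, n_trials, and seed arguments control, here is a minimal sketch of a random-sampling hyperparameter search in plain Python. It is not pnpxai's optimizer: `random_search`, `toy_metric`, and the `epsilon` search space are hypothetical names used only for illustration.

```python
import random

def random_search(objective, space, direction="minimize", n_trials=50, seed=None):
    """Sample hyperparameters uniformly and keep the best trial for `direction`."""
    if direction not in ("minimize", "maximize"):
        raise ValueError(f"unknown direction: {direction!r}")
    rng = random.Random(seed)  # local RNG so a fixed seed reproduces the search
    sign = 1.0 if direction == "minimize" else -1.0
    best_params, best_value = None, None
    for _ in range(n_trials):
        params = {name: rng.uniform(low, high) for name, (low, high) in space.items()}
        value = objective(params)
        # Comparing sign * value turns maximization into minimization.
        if best_value is None or sign * value < sign * best_value:
            best_params, best_value = params, value
    return best_params, best_value

# Hypothetical metric that peaks at epsilon = 0.25, so higher is better.
def toy_metric(params):
    return -(params["epsilon"] - 0.25) ** 2

space = {"epsilon": (0.0, 1.0)}
best_params, best_value = random_search(
    toy_metric, space, direction="maximize", n_trials=50, seed=42
)
worst_params, worst_value = random_search(
    toy_metric, space, direction="minimize", n_trials=50, seed=42
)
```

Picking the wrong direction returns the worst trial instead of the best one, so the direction should match whether higher or lower values of the chosen metric are better.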

The complete backbone of the initialization code looks as follows:

import torch
from torch.utils.data import DataLoader

from pnpxai import AutoExplanationForImageClassification

# Bring your model
model = ...

# Prepare your data
dataset = ...
loader = DataLoader(dataset, batch_size=...)
def input_extractor(x):
    ...
def label_extractor(x):
    ...
def target_extractor(x):
    ...

# Auto-explanation
experiment = AutoExplanationForImageClassification(
    model,
    loader,
    input_extractor=input_extractor,
    label_extractor=label_extractor,
    target_extractor=target_extractor,
    target_labels=False,
)
optimized = experiment.optimize(
    data_ids=range(16),
    explainer_id=2,
    metric_id=1,
    direction='maximize', # 'maximize' if a higher metric value is better, 'minimize' if a lower one is
    sampler='tpe', # Literal['tpe','random']
    n_trials=50, # by default, 50 for sampler in ['random', 'tpe'], None for ['grid']
    seed=42, # seed for sampler: by default, None
)

Manual Setup

The AutoExplanationForImageClassification class is guided by pnpxai.XaiRecommender to select the most applicable explainers and metrics for an experiment. However, pnpxai also provides an API for defining the explainers and metrics manually.

Here, users define the modality manually, which enables modality-dependent control flow. The pnpxai package comes with a set of predefined modalities, namely ImageModality, TextModality, and TimeSeriesModality. The API can also be extended with custom modalities through the Modality base class.

import torch
from torch.utils.data import DataLoader

from pnpxai import Experiment
from pnpxai.core.modality import ImageModality
from pnpxai.explainers import LRPEpsilonPlus
from pnpxai.evaluator.metrics import MuFidelity
from pnpxai.explainers.utils.postprocess import Identity

# Bring your model
model = ...

# Prepare your data
dataset = ...
loader = DataLoader(dataset, batch_size=...)
def input_extractor(x):
    ...
def label_extractor(x):
    ...
def target_extractor(x):
    ...

# Manual setup
explainer = LRPEpsilonPlus(model)
metric = MuFidelity(model, explainer)
postprocessor = Identity()
modality = ImageModality()

experiment = Experiment(
    model=model,
    data=loader,
    modality=modality,
    explainers=[explainer],
    postprocessors=[postprocessor],
    metrics=[metric],
    input_extractor=input_extractor,
    label_extractor=label_extractor,
    target_extractor=target_extractor,
)