Skip to content

Detector [source]

detect_model_architecture traces a model: torch.nn.Module's graph by using torch.fx.symbolic_trace and model.named_modules() to extract information to be applied to recommend available explainers. It returns a Set[torch.nn.Module] which describes layer informations and representative type of architecture.

Basic Usage

import torchvision
from pnpxai.detector import detect_model_architecture

model = torchvision.models.get_model("resnet18").eval()
detected_modules = detect_model_architecture(model)
print(detected_modules)

Notes

Since detect_model_architecture is fully depending on torch.fx.symbolic_trace and (torch.nn.Module()).named_modules(), the detected model architecture omits dynamic control flows and non-torch functions in model.forward, as the official document of torch.fx describing. On the contrary, we can say that all operations in static control flow, torch function or python-builtin function are fully detected. If your model.forward has dynamic control flows or non-torch functions, please change them to static flows or torch functions in order to let the following procedures of pnpxai work in the proper way with your model.