Experiment

Bases: Observable

A class representing an experiment for model interpretability.

Parameters:

    model (Model): The machine learning model to be interpreted. Required.
    data (DataSource): The data used for the experiment. Required.
    modality (Modality): The type of modality (image, tabular, text, time series) the model is designed for. Required.
    explainers (Sequence[Explainer]): Explainer objects or their arguments for interpreting the model. Required.
    postprocessors (Optional[Sequence[Callable]]): Postprocessing functions to apply to explanations. Required.
    metrics (Optional[Sequence[Metric]]): Evaluation metrics used to assess model interpretability. Required.
    input_extractor (Optional[Callable[[Any], Any]]): Function to extract inputs from data. Default: None.
    label_extractor (Optional[Callable[[Any], Any]]): Function to extract labels from data. Default: None.
    target_extractor (Optional[Callable[[Any], Any]]): Function to extract targets from data. Default: None.
    input_visualizer (Optional[Callable[[Any], Any]]): Function to visualize input data. Default: None.
    target_visualizer (Optional[Callable[[Any], Any]]): Function to visualize target data. Default: None.
    cache_device (Optional[Union[device, str]]): Device to cache data and results. Default: None.
    target_labels (bool): True if the target is a label, False otherwise. Default: False.

Attributes:

    modality (Modality): Object defining the modality-specific control flow of the experiment.
    manager (ExperimentManager): Manager object for the experiment.
    all_explainers (Sequence[Explainer]): All explainer objects used in the experiment.
    all_metrics (Sequence[Metric]): All evaluation metrics used in the experiment.
    errors (Sequence[Error]): Errors collected while running the experiment.
    is_image_task (bool): True if the modality is an image-related modality, False otherwise.
    has_explanations (bool): True if the experiment has explanations, False otherwise.
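
Example: a minimal construction sketch, not taken from the library's own examples. The placeholders my_modality, my_explainer, my_postprocessor and my_metric stand for pnpxai Modality, Explainer, postprocessor and Metric objects assumed to be built elsewhere; only the Experiment call itself follows the signature documented above.

# Sketch only: my_modality, my_explainer, my_postprocessor and my_metric are
# hypothetical placeholders for objects constructed elsewhere. `Experiment`
# is the class documented on this page.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
dataset = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
loader = DataLoader(dataset, batch_size=4)

experiment = Experiment(
    model=model,
    data=loader,
    modality=my_modality,
    explainers=[my_explainer],
    postprocessors=[my_postprocessor],
    metrics=[my_metric],
    input_extractor=lambda batch: batch[0],
    label_extractor=lambda batch: batch[1],
    target_labels=False,  # use model predictions rather than labels as targets
)

# Run one explainer / postprocessor / metric combination on a subset of data;
# data_ids index entries of the data source (values here are illustrative).
results = experiment.run_batch(
    data_ids=[0, 1],
    explainer_id=0,
    postprocessor_id=0,
    metric_id=0,
)
print(sorted(results.keys()))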

Source code in pnpxai/core/experiment/experiment.py
class Experiment(Observable):
    """
    A class representing an experiment for model interpretability.

    Args:
        model (Model): The machine learning model to be interpreted.
        data (DataSource): The data used for the experiment.
        modality (Modality): The type of modality (image, tabular, text, time series) the model is designed for.
        explainers (Sequence[Explainer]): Explainer objects or their arguments for interpreting the model.
        postprocessors (Optional[Sequence[Callable]]): Postprocessing functions to apply to explanations.
        metrics (Optional[Sequence[Metric]]): Evaluation metrics used to assess model interpretability.
        input_extractor (Optional[Callable[[Any], Any]]): Function to extract inputs from data.
        label_extractor (Optional[Callable[[Any], Any]]): Function to extract labels from data.
        target_extractor (Optional[Callable[[Any], Any]]): Function to extract targets from data.
        input_visualizer (Optional[Callable[[Any], Any]]): Function to visualize input data.
        target_visualizer (Optional[Callable[[Any], Any]]): Function to visualize target data.
        cache_device (Optional[Union[torch.device, str]]): Device to cache data and results.
        target_labels (bool): True if the target is a label, False otherwise.

    Attributes:
        modality (Modality): Object defining the modality-specific control flow of the experiment.
        manager (ExperimentManager): Manager object for the experiment.
        all_explainers (Sequence[Explainer]): All explainer objects used in the experiment.
        all_metrics (Sequence[Metric]): All evaluation metrics used in the experiment.
        errors (Sequence[Error]): Errors collected while running the experiment.
        is_image_task (bool): True if the modality is an image-related modality, False otherwise.
        has_explanations (bool): True if the experiment has explanations, False otherwise.
    """

    def __init__(
        self,
        model: Model,
        data: DataSource,
        modality: Modality,
        explainers: Sequence[Explainer],
        postprocessors: Sequence[Callable],
        metrics: Sequence[Metric],
        input_extractor: Optional[Callable[[Any], Any]] = None,
        label_extractor: Optional[Callable[[Any], Any]] = None,
        target_extractor: Optional[Callable[[Any], Any]] = None,
        input_visualizer: Optional[Callable[[Any], Any]] = None,
        target_visualizer: Optional[Callable[[Any], Any]] = None,
        cache_device: Optional[Union[torch.device, str]] = None,
        target_labels: bool = False,
    ):
        super(Experiment, self).__init__()
        self.model = model
        self.model_device = next(self.model.parameters()).device

        self.manager = ExperimentManager(data=data, cache_device=cache_device)
        for explainer in explainers:
            self.manager.add_explainer(explainer)
        for postprocessor in postprocessors:
            self.manager.add_postprocessor(postprocessor)
        for metric in metrics:
            self.manager.add_metric(metric)

        self.input_extractor = input_extractor \
            if input_extractor is not None \
            else default_input_extractor
        self.label_extractor = label_extractor \
            if label_extractor is not None \
            else default_target_extractor
        self.target_extractor = target_extractor \
            if target_extractor is not None \
            else default_target_extractor
        self.input_visualizer = input_visualizer
        self.target_visualizer = target_visualizer
        self.target_labels = target_labels
        self.modality = modality
        self.reset_errors()

    def reset_errors(self):
        self._errors: List[BaseException] = []

    @property
    def errors(self):
        return self._errors

    def to_device(self, x):
        return to_device(x, self.model_device)

    def run_batch(
        self,
        data_ids: Sequence[int],
        explainer_id: int,
        postprocessor_id: int,
        metric_id: int,
    ) -> dict:
        """
        Runs the experiment for selected batch of data, explainer, postprocessor and metric.

        Args:
            data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to process.
            explainer_id (int): ID of explainer to use for the run.
            postprocessor_id (int): ID of postprocessor to use for the run.
            metric_id (int): ID of the metric to use for the run.

        Returns:
            The dictionary of inputs, labels, outputs, targets, explainer, explanation, postprocessor, postprocessed, metric, and evaluation.

        This method orchestrates the experiment by configuring the manager, obtaining explainer and metric instances,
        processing data, generating explanations, and evaluating metrics. It then caches the results in the manager, and returns back to the user.

        Note: The input parameters allow for flexibility in specifying subset of data, explainer, postprocessor and metric to process.
        """

        self.predict_batch(data_ids)
        self.explain_batch(data_ids, explainer_id)
        self.evaluate_batch(
            data_ids, explainer_id, postprocessor_id, metric_id)
        data = self.manager.batch_data_by_ids(data_ids)
        return {
            'inputs': self.input_extractor(data),
            'labels': self.label_extractor(data),
            'outputs': self.manager.batch_outputs_by_ids(data_ids),
            'targets': self._get_targets(data_ids),
            'explainer': self.manager.get_explainer_by_id(explainer_id),
            'explanation': self.manager.batch_explanations_by_ids(data_ids, explainer_id),
            'postprocessor': self.manager.get_postprocessor_by_id(postprocessor_id),
            'postprocessed': self.postprocess_batch(data_ids, explainer_id, postprocessor_id),
            'metric': self.manager.get_metric_by_id(metric_id),
            'evaluation': self.manager.batch_evaluations_by_ids(data_ids, explainer_id, postprocessor_id, metric_id),
        }

    def predict_batch(
        self,
        data_ids: Sequence[int],
    ):
        """
        Predicts results of the experiment for selected batch of data.

        Args:
            data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to process.

        Returns:
            Batched model outputs corresponding to data ids.

        This method orchestrates the experiment by configuring the manager and processing data. It then caches the results in the manager, and returns back to the user.
        """
        data_ids_pred = [
            idx for idx in data_ids
            if self.manager.get_output_by_id(idx) is None
        ]
        if len(data_ids_pred) > 0:
            data = self.manager.batch_data_by_ids(data_ids_pred)
            outputs = self.model(
                *format_into_tuple(self.input_extractor(data)))
            self.manager.cache_outputs(data_ids_pred, outputs)
        return self.manager.batch_outputs_by_ids(data_ids)

    def explain_batch(
        self,
        data_ids: Sequence[int],
        explainer_id: int,
    ):
        """
        Explains selected batch of data within experiment.

        Args:
            data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to process.

        Returns:
            Batched model explanations corresponding to data ids.

        This method orchestrates the experiment by configuring the manager, obtaining explainer instance,
        processing data, and generating explanations. It then caches the results in the manager, and returns back to the user.
        """
        data_ids_expl = [
            idx for idx in data_ids
            if self.manager.get_explanation_by_id(idx, explainer_id) is None
        ]
        if len(data_ids_expl):
            data = self.manager.batch_data_by_ids(data_ids_expl)
            inputs = self.input_extractor(data)
            targets = self._get_targets(data_ids_expl)
            explainer = self.manager.get_explainer_by_id(explainer_id)
            explanations = explainer.attribute(inputs, targets)
            self.manager.cache_explanations(
                explainer_id, data_ids_expl, explanations)
        return self.manager.batch_explanations_by_ids(data_ids, explainer_id)

    def postprocess_batch(
        self,
        data_ids: List[int],
        explainer_id: int,
        postprocessor_id: int,
    ):
        """
        Postprocesses selected batch of data within experiment.

        Args:
            data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to postprocess.
            explainer_id (int): An explainer ID to specify the explainer to use.
            postprocessor_id (int): A postprocessor ID to specify the postprocessor to use.

        Returns:
            Batched postprocessed model explanations corresponding to data ids.

        This method orchestrates the experiment by configuring the manager, obtaining explainer instance,
        processing data, and generating explanations. It then caches the results in the manager, and returns back to the user.
        """
        explanations = self.manager.batch_explanations_by_ids(
            data_ids, explainer_id)
        postprocessor = self.manager.get_postprocessor_by_id(postprocessor_id)

        modalities = format_into_tuple(self.modality)
        explanations = format_into_tuple(explanations)
        postprocessors = format_into_tuple(postprocessor)

        batch = []
        explainer = self.manager.get_explainer_by_id(explainer_id)
        for mod, attr, pp in zip(modalities, explanations, postprocessors):
            if (
                isinstance(explainer, (Lime, KernelShap))
                and isinstance(mod, TextModality)
                and not isinstance(pp.pooling_fn, Identity)
            ):
                raise ValueError(f'postprocessor {postprocessor_id} does not support explainer {explainer_id}.')
            batch.append(pp(attr))
        return format_out_tuple_if_single(batch)

    def evaluate_batch(
        self,
        data_ids: List[int],
        explainer_id: int,
        postprocessor_id: int,
        metric_id: int,
    ):
        """
        Evaluates selected batch of data within experiment.

        Args:
            data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to postprocess.
            explainer_id (int): An explainer ID to specify the explainer to use.
            postprocessor_id (int): A postprocessor ID to specify the postprocessor to use.
            metric_id (int): A metric ID to evaluate the model explanations.

        Returns:
            Batched model evaluations corresponding to data ids.

        This method orchestrates the experiment by configuring the manager, obtaining explainer instance,
        processing data, generating explanations, and evaluating results. It then caches the results in the manager, and returns back to the user.
        """
        data_ids_eval = [
            idx for idx in data_ids
            if self.manager.get_evaluation_by_id(
                idx, explainer_id, postprocessor_id, metric_id) is None
        ]
        if len(data_ids_eval):
            data = self.manager.batch_data_by_ids(data_ids_eval)
            inputs = self.input_extractor(data)
            targets = self._get_targets(data_ids_eval)
            postprocessed = self.postprocess_batch(
                data_ids_eval, explainer_id, postprocessor_id)
            explainer = self.manager.get_explainer_by_id(explainer_id)
            metric = self.manager.get_metric_by_id(metric_id)
            evaluations = metric.set_explainer(explainer).evaluate(
                inputs, targets, postprocessed)
            self.manager.cache_evaluations(
                explainer_id, postprocessor_id, metric_id,
                data_ids_eval, evaluations
            )
        return self.manager.batch_evaluations_by_ids(
            data_ids, explainer_id, postprocessor_id, metric_id
        )

    def _get_targets(self, data_ids):
        if self.target_labels:
            return self.label_extractor(self.manager.batch_data_by_ids(data_ids))
        outputs = self.manager.batch_outputs_by_ids(data_ids)
        return self.target_extractor(outputs)

    def optimize(
        self,
        data_ids: Union[int, Sequence[int]],
        explainer_id: int,
        metric_id: int,
        direction: Literal['minimize', 'maximize'] = 'maximize',
        sampler: Literal['grid', 'random', 'tpe'] = 'tpe',
        n_trials: Optional[int] = None,
        timeout: Optional[float] = None,
        **kwargs,  # sampler kwargs
    ):
        """
        Optimize experiment hyperparameters by processing data, generating explanations, evaluating with metrics, caching and retrieving the data.

        Args:
            data_ids (Union[int, Sequence[int]]): A single data ID or sequence of data IDs to specify the subset of data to process.
            explainer_id (int): An explainer ID to specify the explainer to use.
            metric_id (int): A metric ID to evaluate optimizer decisions.
            direction (Literal['minimize', 'maximize']): A string to specify the direction of optimization.
            sampler (Literal['grid', 'random', 'tpe']): A string to specify the sampler to use for optimization.
            n_trials (Optional[int]): An integer to specify the number of trials for optimization. If none passed, the number of trials is inferred from `timeout`.
            timeout (Optional[float]): A float to specify the timeout for optimization. Ignored, if `n_trials` is specified.

        Returns:
            An OptimizationOutput object holding the optimized explainer and postprocessor, along with the underlying optuna study.

        This method orchestrates the search by obtaining the explainer and metric instances, building an optimization
        objective over the selected data, and running an optuna study to tune explainer and postprocessor hyperparameters.
        """
        data_ids = [data_ids] if isinstance(data_ids, int) else data_ids
        data = self.manager.batch_data_by_ids(data_ids)
        explainer = self.manager.get_explainer_by_id(explainer_id)
        postprocessor = self.manager.get_postprocessor_by_id(
            0)  # sample postprocessor to ensure channel_dim
        metric = self.manager.get_metric_by_id(metric_id)

        objective = Objective(
            explainer=explainer,
            postprocessor=postprocessor,
            metric=metric,
            modality=self.modality,
            inputs=self.input_extractor(data),
            targets=self._get_targets(data_ids),
        )
        # TODO: grid search
        if timeout is None:
            n_trials = n_trials or get_default_n_trials(sampler)

        # optimize
        study = optuna.create_study(
            sampler=load_sampler(sampler, **kwargs),
            direction=direction,
        )
        study.optimize(
            objective,
            n_trials=n_trials,
            timeout=timeout,
            n_jobs=1,
        )
        opt_explainer = study.best_trial.user_attrs['explainer']
        opt_postprocessor = study.best_trial.user_attrs['postprocessor']
        return OptimizationOutput(
            explainer=opt_explainer,
            postprocessor=opt_postprocessor,
            study=study,
        )

    def run(
        self,
        data_ids: Optional[Sequence[int]] = None,
        explainer_ids: Optional[Sequence[int]] = None,
        postprocessor_ids: Optional[Sequence[int]] = None,
        metric_ids: Optional[Sequence[int]] = None,
    ) -> 'Experiment':
        """
        Run the experiment by processing data, generating explanations, evaluating with metrics, caching and retrieving the data.

        Args:
            data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.
            explainer_ids (Optional[Sequence[int]]): A sequence of explainer IDs to specify the subset of explainers to use.
            postprocessor_ids (Optional[Sequence[int]]): A sequence of postprocessor IDs to specify the subset of postprocessors to use.
            metric_ids (Optional[Sequence[int]]): A sequence of metric IDs to specify the subset of metrics to evaluate.

        Returns:
            The Experiment instance with updated results and state.

        This method orchestrates the experiment by configuring the manager, obtaining explainer and metric instances,
        processing data, generating explanations, and evaluating metrics. It then saves the results in the manager.

        Note: The input parameters allow for flexibility in specifying subsets of data, explainers, and metrics to process.
        If not provided, the method processes all available data, explainers, postprocessors, and metrics.
        """
        self.reset_errors()
        # self.manager.set_config(data_ids, explainer_ids, metrics_ids)

        # inference

        # data_ids is filtered out data indices whose output is saved in cache
        data, data_ids_pred = self.manager.get_data_to_predict(data_ids)
        outputs = self._predict(data)
        self.manager.save_outputs(outputs, data, data_ids_pred)

        # explain
        explainers, explainer_ids = self.manager.get_explainers(explainer_ids)
        postprocessors, postprocessor_ids = self.manager.get_postprocessors(postprocessor_ids)
        metrics, metric_ids = self.manager.get_metrics(metric_ids)

        for explainer, explainer_id in zip(explainers, explainer_ids):
            explainer_name = class_to_string(explainer)

            # data_ids is filtered out data indices whose explanation is saved in cache
            data, data_ids_expl = self.manager.get_data_to_process_for_explainer(
                explainer_id, data_ids)
            explanations = self._explain(data, data_ids_expl, explainer)
            self.manager.save_explanations(
                explanations, data, data_ids_expl, explainer_id
            )
            message = get_message(
                'experiment.event.explainer', explainer=explainer_name
            )
            print(f"[Experiment] {message}")
            self.fire(ExperimentObservableEvent(
                self.manager, message, explainer))

            for postprocessor, postprocessor_id in zip(postprocessors, postprocessor_ids):
                for metric, metric_id in zip(metrics, metric_ids):
                    metric_name = class_to_string(metric)
                    data, data_ids_eval = self.manager.get_data_to_process_for_metric(
                        explainer_id, postprocessor_id, metric_id, data_ids)
                    explanations, data_ids_eval = self.manager.get_valid_explanations(
                        explainer_id, data_ids_eval)
                    data, _ = self.manager.get_data(data_ids_eval)
                    evaluations = self._evaluate(
                        data, data_ids_eval, explanations, explainer, postprocessor, metric)
                    self.manager.save_evaluations(
                        evaluations, data, data_ids_eval, explainer_id, postprocessor_id, metric_id)

                    message = get_message(
                        'experiment.event.explainer.metric', explainer=explainer_name, metric=metric_name)
                    print(f"[Experiment] {message}")
                    self.fire(ExperimentObservableEvent(
                        self.manager, message, explainer, metric))
        return self

    def _predict(self, data: DataSource):
        """
        Predict input data with experiment model.

        Args:
            data (DataSource): A data to be explained.

        Returns:
            Predictions corresponding to input data.
        """
        outputs = [
            self.model(*format_into_tuple(
                self.to_device(self.input_extractor(datum)))
            ) for datum in data
        ]
        return outputs

    def _explain(self, data: DataSource, data_ids: Sequence[int], explainer: Explainer):
        """
        Explain input data with an explainer.

        Args:
            data (DataSource): A data to be explained.
            data_ids (Sequence[int]): A sequence of data IDs corresponding to the provided data source. Data IDs are used to cache results.
            explainer (Explainer): A explainer object to use for explanation.

        Returns:
            Explanations corresponding to input data, generated by the explainer.

        This method explains data with a specified explainer. Produced results are cached by the manager.
        """
        explanations = [None] * len(data)
        explainer_name = class_to_string(explainer)
        for i, (datum, data_id) in enumerate(zip(data, data_ids)):
            try:
                datum = self.to_device(datum)
                inputs = format_into_tuple(self.input_extractor(datum))
                targets = self.label_extractor(datum) if self.target_labels \
                    else self._get_targets(data_ids)
                explanations[i] = explainer.attribute(
                    inputs=inputs,
                    targets=targets,
                )
            except NotImplementedError as error:
                warnings.warn(
                    f"\n[Experiment] {get_message('experiment.errors.explainer_unsupported', explainer=explainer_name)}")
                raise error
            except Exception as e:
                warnings.warn(
                    f"\n[Experiment] {get_message('experiment.errors.explanation', explainer=explainer_name, error=e)}")
                self._errors.append(e)
        return explanations

    # def _evaluate(self, data: DataSource, data_ids: List[int], explanations: DataSource, explainer: Explainer, postprocessor: Callable, metric: Metric):
    #     if explanations is None:
    #         return None
    #     started_at = time.time()
    #     metric_name = class_to_string(metric)
    #     explainer_name = class_to_string(explainer)

    #     evaluations = [None] * len(data)
    #     for i, (datum, data_id, explanation) in enumerate(zip(data, data_ids, explanations)):
    #         if explanation is None:
    #             continue
    #         datum = self.to_device(datum)
    #         explanation = self.to_device(explanation)
    #         inputs = self.input_extractor(datum)
    #         targets = self.label_extractor(datum) if self.target_labels \
    #             else self._get_targets(data_ids)
    #         try:
    #             metric = metric.set_explainer(explainer).set_postprocessor(postprocessor)
    #             evaluations[i] = metric.evaluate(
    #                 inputs=inputs,
    #                 targets=targets,
    #                 attributions=explanation,
    #             )
    #         except Exception as e:
    #             import pdb; pdb.set_trace()
    #             warnings.warn(
    #                 f"\n[Experiment] {get_message('experiment.errors.evaluation', explainer=explainer_name, metric=metric_name, error=e)}")
    #             self._errors.append(e)
    #     elapsed_time = time.time() - started_at
    #     print(
    #         f"[Experiment] {get_message('elapsed', modality=metric_name, elapsed=elapsed_time)}")

    #     return evaluations

    def _get_targets(self, data_ids: Sequence[int]):
        """
        Retrieve and flatten last run target (output) data.

        Args:
            data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

        Returns:
            Flattened target (output) data.

        This method retrieves target (output) data using the target extractor and flattens it for further processing.

        Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
        """
        # predict if not cached
        not_predicted = [
            data_id for data_id in data_ids
            if self.manager._cache.get_output(data_id) is None
        ]
        if len(not_predicted) > 0:
            self.predict_batch(not_predicted)

        # load outputs from cache
        outputs = torch.stack([
            self.manager._cache.get_output(data_id)
            for data_id in data_ids
        ])
        return self.target_extractor(outputs)

    # def get_visualizations_flattened(self) -> Sequence[Sequence[Figure]]:
    #     """
    #     Generate flattened visualizations for each data point based on explanations.

    #     Returns:
    #         List of visualizations for each data point across explainers.

    #     This method retrieves valid explanations for each explainer, formats them for visualization,
    #     and generates visualizations using Plotly Express. The results are flattened based on the data points
    #     and returned as a list of lists of figures.
    #     """
    #     assert self.modality == 'image', f"Visualization for '{self.modality} is not supported yet"
    #     explainers, explainer_ids = self.manager.get_explainers()
    #     # Get all data ids
    #     experiment_data_ids = self.manager.get_data_ids()
    #     visualizations = []

    #     for explainer, explainer_id in zip(explainers, explainer_ids):
    #         # Get all valid explanations and data ids for this explainer
    #         explanations, data_ids = self.manager.get_valid_explanations(
    #             explainer_id)
    #         data = self.manager.get_data(data_ids)[0]
    #         explainer_visualizations = []
    #         for explanation in explanations:
    #             figs = []
    #             for attr in explanation:
    #                 postprocessed = postprocess_attr(
    #                     attr,
    #                     channel_dim=0,
    #                     pooling_method='l2normsq',
    #                     normalization_method='minmax'
    #                 ).detach().numpy()
    #                 fig = px.imshow(postprocessed, color_continuous_scale='RdBu_R', color_continuous_midpoint=.5)
    #                 figs.append(fig)
    #             explainer_visualizations.append(figs)

    #         # explainer_visualizations = [[
    #         #     px.imshow(
    #         #         postprocess_attr(
    #         #             attr, # C x H x W
    #         #             channel_dim=0,
    #         #             pooling_method='l2normsq',
    #         #             normalization_method='minmax',
    #         #         ).detach().numpy(),
    #         #         color_continuous_scale='RdBu_R',
    #         #         color_continuous_midpoint=.5,
    #         #     ) for attr in explanation]
    #         #     for explanation in explanations
    #         # ]
    #         # # Visualize each valid explanataion
    #         # for datum, explanation in zip(data, explanations):
    #         #     inputs = self.input_extractor(datum)
    #         #     targets = self.target_extractor(datum)
    #         #     formatted = explainer.format_outputs_for_visualization(
    #         #         inputs=inputs,
    #         #         targets=targets,
    #         #         explanations=explanation,
    #         #         modality=self.modality
    #         #     )

    #         #     if not self.manager.is_batched:
    #         #         formatted = [formatted]
    #         #     formatted_visualizations = [
    #         #         px.imshow(explanation, color_continuous_scale="RdBu_r", color_continuous_midpoint=0.0) for explanation in formatted
    #         #     ]
    #         #     if not self.manager.is_batched:
    #         #         formatted_visualizations = formatted_visualizations[0]
    #         #     explainer_visualizations.append(formatted_visualizations)

    #         flat_explainer_visualizations = self.manager.flatten_if_batched(
    #             explainer_visualizations, data)
    #         # Set visualizaions of all data ids as None
    #         explainer_visualizations = {
    #             idx: None for idx in experiment_data_ids}
    #         # Fill all valid visualizations
    #         for visualization, data_id in zip(flat_explainer_visualizations, data_ids):
    #             explainer_visualizations[data_id] = visualization
    #         visualizations.append(list(explainer_visualizations.values()))

    #     return visualizations

    def get_inputs_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
        """
        Retrieve and flatten last run input data.

        Args:
            data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

        Returns:
            Flattened input data.

        This method retrieves input data using the input extractor and flattens it for further processing.

        Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
        """
        data, _ = self.manager.get_data(data_ids)
        data = [self.input_extractor(datum) for datum in data]
        return self.manager.flatten_if_batched(data, data)

    def get_all_inputs_flattened(self) -> Sequence[Tensor]:
        """
        Retrieve and flatten all input data.

        Returns:
            Flattened input data from all available data.

        This method retrieves input data from all available data points using the input extractor and flattens it.
        """
        data = self.manager.get_all_data()
        data = [self.input_extractor(datum) for datum in data]
        return self.manager.flatten_if_batched(data, data)

    def get_labels_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
        """
        Retrieve and flatten labels data.

        Args:
            data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

        Returns:
            Flattened labels data.

        This method retrieves label data using the label extractor and flattens it for further processing.
        """
        data, _ = self.manager.get_data(data_ids)
        labels = [self.label_extractor(datum) for datum in data]
        return self.manager.flatten_if_batched(labels, data)

    def get_targets_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
        """
        Retrieve and flatten target data.

        Args:
            data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

        Returns:
            Flattened target data.

        This method retrieves target data using the target extractor and flattens it for further processing.

        Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
        """
        if self.target_labels:
            return self.get_labels_flattened(data_ids)
        data, _ = self.manager.get_data(data_ids)
        targets = [self._get_targets(data_ids)]
        return self.manager.flatten_if_batched(targets, data)
        # targets = [self.label_extractor(datum) for datum in data] \
        #     if self.target_labels else [self._get_targets(data_ids)]
        # return self.manager.flatten_if_batched(targets, data)

    def get_outputs_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
        """
        Retrieve and flatten model outputs.

        Args:
            data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

        Returns:
            Flattened model outputs.

        This method retrieves flattened model outputs using the manager's get_flat_outputs method.

        Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
        """
        return self.manager.get_flat_outputs(data_ids)

    def get_explanations_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Sequence[Tensor]]:
        """
        Retrieve and flatten explanations from all explainers.

        Args:
            data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

        Returns:
            Flattened explanations from all explainers.

        This method retrieves flattened explanations for each explainer using the manager's `get_flat_explanations` method.

        Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
        """
        _, explainer_ids = self.manager.get_explainers()
        return [
            self.manager.get_flat_explanations(explainer_id, data_ids)
            for explainer_id in explainer_ids
        ]

    def get_evaluations_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Sequence[Sequence[Tensor]]]:
        """
        Retrieve and flatten evaluations for all explainers and metrics.

        Args:
            data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

        Returns:
            Flattened evaluations for all explainers and metrics.

        This method retrieves flattened evaluations for each explainer and metric using the manager's
        get_flat_evaluations method.

        Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
        """
        _, explainer_ids = self.manager.get_explainers()
        _, postprocessor_ids = self.manager.get_postprocessors()
        _, metric_ids = self.manager.get_metrics()

        formatted = [[[
            self.manager.get_flat_evaluations(
                explainer_id, postprocessor_id, metric_id, data_ids)
            for metric_id in metric_ids
        ] for postprocessor_id in postprocessor_ids]
            for explainer_id in explainer_ids]

        return formatted

    def get_explainers_ranks(self) -> Optional[Sequence[Sequence[int]]]:
        """
        Calculate and return rankings for explainers based on evaluations.

        Returns:
            Rankings of explainers. Returns None if rankings cannot be calculated.

        This method calculates rankings for explainers based on evaluations and metric scores. It considers
        metric priorities and sorting preferences to produce rankings.
        """
        evaluations = [[[
            data_evaluation.detach().cpu() if data_evaluation is not None else None
            for data_evaluation in metric_data
        ]for metric_data in explainer_data
        ]for explainer_data in self.get_evaluations_flattened()]
        # (explainers, metrics, data)
        evaluations = np.array(evaluations, dtype=float)
        evaluations = np.nan_to_num(evaluations, nan=-np.inf)
        if evaluations.ndim != 3:
            return None

        evaluations = evaluations.argsort(axis=-3).argsort(axis=-3) + 1
        n_explainers = evaluations.shape[0]
        metric_name_to_idx = {}

        for idx, metric in enumerate(self.get_current_metrics()):
            metric_name = class_to_string(metric)
            evaluations[:, idx, :] = evaluations[:, idx, :]
            if EVALUATION_METRIC_REVERSE_SORT.get(metric_name, False):
                evaluations[:, idx, :] = \
                    n_explainers - evaluations[:, idx, :] + 1

            metric_name_to_idx[metric_name] = idx
        # (explainers, data)
        scores: np.ndarray = evaluations.sum(axis=-2)

        for metric_name in EVALUATION_METRIC_SORT_PRIORITY:
            if metric_name not in metric_name_to_idx:
                continue

            idx = metric_name_to_idx[metric_name]
            scores = scores * n_explainers + evaluations[:, idx, :]

        return scores.argsort(axis=-2).argsort(axis=-2).tolist()

    @property
    def has_explanations(self):
        return self.manager.has_explanations

evaluate_batch(data_ids, explainer_id, postprocessor_id, metric_id)

Evaluates a selected batch of data within the experiment.

Parameters:

    data_ids (Sequence[int]): A sequence of data IDs specifying the subset of data to evaluate. Required.
    explainer_id (int): An explainer ID specifying the explainer to use. Required.
    postprocessor_id (int): A postprocessor ID specifying the postprocessor to use. Required.
    metric_id (int): A metric ID specifying the metric used to evaluate the explanations. Required.

Returns:

    Batched model evaluations corresponding to the data IDs.

This method configures the manager, obtains the explainer and metric instances, processes the data, generates explanations, and evaluates the results. It then caches the results in the manager and returns them to the user.
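
An illustrative call, assuming an Experiment instance named experiment configured with at least one explainer, postprocessor and metric (the IDs below are positions in those sequences):

evaluations = experiment.evaluate_batch(
    data_ids=[0, 1, 2, 3],   # illustrative IDs into the data source
    explainer_id=0,          # first registered explainer
    postprocessor_id=0,      # first registered postprocessor
    metric_id=0,             # first registered metric
)
# Repeating the call with the same IDs reuses the cached evaluations.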

Source code in pnpxai/core/experiment/experiment.py
def evaluate_batch(
    self,
    data_ids: List[int],
    explainer_id: int,
    postprocessor_id: int,
    metric_id: int,
):
    """
    Evaluates selected batch of data within experiment.

    Args:
        data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to postprocess.
        explainer_id (int): An explainer ID to specify the explainer to use.
        postprocessor_id (int): A postprocessor ID to specify the postprocessor to use.
        metric_id (int): A metric ID to evaluate the model explanations.

    Returns:
        Batched model evaluations corresponding to data ids.

    This method orchestrates the experiment by configuring the manager, obtaining explainer instance,
    processing data, generating explanations, and evaluating results. It then caches the results in the manager, and returns back to the user.
    """
    data_ids_eval = [
        idx for idx in data_ids
        if self.manager.get_evaluation_by_id(
            idx, explainer_id, postprocessor_id, metric_id) is None
    ]
    if len(data_ids_eval):
        data = self.manager.batch_data_by_ids(data_ids_eval)
        inputs = self.input_extractor(data)
        targets = self._get_targets(data_ids_eval)
        postprocessed = self.postprocess_batch(
            data_ids_eval, explainer_id, postprocessor_id)
        explainer = self.manager.get_explainer_by_id(explainer_id)
        metric = self.manager.get_metric_by_id(metric_id)
        evaluations = metric.set_explainer(explainer).evaluate(
            inputs, targets, postprocessed)
        self.manager.cache_evaluations(
            explainer_id, postprocessor_id, metric_id,
            data_ids_eval, evaluations
        )
    return self.manager.batch_evaluations_by_ids(
        data_ids, explainer_id, postprocessor_id, metric_id
    )

explain_batch(data_ids, explainer_id)

Explains a selected batch of data within the experiment.

Parameters:

    data_ids (Sequence[int]): A sequence of data IDs specifying the subset of data to explain. Required.
    explainer_id (int): An explainer ID specifying the explainer to use. Required.

Returns:

    Batched model explanations corresponding to the data IDs.

This method configures the manager, obtains the explainer instance, processes the data, and generates explanations. It then caches the results in the manager and returns them to the user.
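
A short sketch under the same assumptions as above (an existing Experiment instance named experiment):

# Attributions from the first registered explainer; results are cached per
# data ID, so repeated calls with the same IDs return the cached explanations.
attributions = experiment.explain_batch(data_ids=[0, 1, 2, 3], explainer_id=0)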

Source code in pnpxai/core/experiment/experiment.py
def explain_batch(
    self,
    data_ids: Sequence[int],
    explainer_id: int,
):
    """
    Explains selected batch of data within experiment.

    Args:
        data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to process.

    Returns:
        Batched model explanations corresponding to data ids.

    This method orchestrates the experiment by configuring the manager, obtaining explainer instance,
    processing data, and generating explanations. It then caches the results in the manager, and returns back to the user.
    """
    data_ids_expl = [
        idx for idx in data_ids
        if self.manager.get_explanation_by_id(idx, explainer_id) is None
    ]
    if len(data_ids_expl):
        data = self.manager.batch_data_by_ids(data_ids_expl)
        inputs = self.input_extractor(data)
        targets = self._get_targets(data_ids_expl)
        explainer = self.manager.get_explainer_by_id(explainer_id)
        explanations = explainer.attribute(inputs, targets)
        self.manager.cache_explanations(
            explainer_id, data_ids_expl, explanations)
    return self.manager.batch_explanations_by_ids(data_ids, explainer_id)

get_all_inputs_flattened()

Retrieve and flatten all input data.

Returns:

    Sequence[Tensor]: Flattened input data from all available data.

This method retrieves input data from all available data points using the input extractor and flattens it.

Source code in pnpxai/core/experiment/experiment.py
def get_all_inputs_flattened(self) -> Sequence[Tensor]:
    """
    Retrieve and flatten all input data.

    Returns:
        Flattened input data from all available data.

    This method retrieves input data from all available data points using the input extractor and flattens it.
    """
    data = self.manager.get_all_data()
    data = [self.input_extractor(datum) for datum in data]
    return self.manager.flatten_if_batched(data, data)

get_evaluations_flattened(data_ids=None)

Retrieve and flatten evaluations for all explainers and metrics.

Parameters:

    data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process. Default: None.

Returns:

    Sequence[Sequence[Sequence[Tensor]]]: Flattened evaluations for all explainers and metrics.

This method retrieves flattened evaluations for each explainer and metric using the manager's get_flat_evaluations method.

Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
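
The nesting order follows the registered objects: explainer first, then postprocessor, then metric. A hypothetical indexing sketch:

evals = experiment.get_evaluations_flattened()
# evals[i][j][k] holds the flat per-datum evaluations for explainer i,
# postprocessor j and metric k, in registration order.
first = evals[0][0][0]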

Source code in pnpxai/core/experiment/experiment.py
def get_evaluations_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Sequence[Sequence[Tensor]]]:
    """
    Retrieve and flatten evaluations for all explainers and metrics.

    Args:
        data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

    Returns:
        Flattened evaluations for all explainers and metrics.

    This method retrieves flattened evaluations for each explainer and metric using the manager's
    get_flat_evaluations method.

    Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
    """
    _, explainer_ids = self.manager.get_explainers()
    _, postprocessor_ids = self.manager.get_postprocessors()
    _, metric_ids = self.manager.get_metrics()

    formatted = [[[
        self.manager.get_flat_evaluations(
            explainer_id, postprocessor_id, metric_id, data_ids)
        for metric_id in metric_ids
    ] for postprocessor_id in postprocessor_ids]
        for explainer_id in explainer_ids]

    return formatted

get_explainers_ranks()

Calculate and return rankings for explainers based on evaluations.

Returns:

    Optional[Sequence[Sequence[int]]]: Rankings of explainers. Returns None if rankings cannot be calculated.

This method calculates rankings for explainers based on evaluations and metric scores. It considers metric priorities and sorting preferences to produce rankings.
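
A small consumption sketch; the result is one list of per-datum ranks for each registered explainer, or None when the evaluations cannot be arranged into that shape:

ranks = experiment.get_explainers_ranks()
if ranks is not None:
    for explainer_idx, per_datum_ranks in enumerate(ranks):
        # one rank per data point for this explainer
        print(explainer_idx, per_datum_ranks[:5])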

Source code in pnpxai/core/experiment/experiment.py
def get_explainers_ranks(self) -> Optional[Sequence[Sequence[int]]]:
    """
    Calculate and return rankings for explainers based on evaluations.

    Returns:
        Rankings of explainers. Returns None if rankings cannot be calculated.

    This method calculates rankings for explainers based on evaluations and metric scores. It considers
    metric priorities and sorting preferences to produce rankings.
    """
    evaluations = [[[
        data_evaluation.detach().cpu() if data_evaluation is not None else None
        for data_evaluation in metric_data
    ]for metric_data in explainer_data
    ]for explainer_data in self.get_evaluations_flattened()]
    # (explainers, metrics, data)
    evaluations = np.array(evaluations, dtype=float)
    evaluations = np.nan_to_num(evaluations, nan=-np.inf)
    if evaluations.ndim != 3:
        return None

    evaluations = evaluations.argsort(axis=-3).argsort(axis=-3) + 1
    n_explainers = evaluations.shape[0]
    metric_name_to_idx = {}

    for idx, metric in enumerate(self.get_current_metrics()):
        metric_name = class_to_string(metric)
        evaluations[:, idx, :] = evaluations[:, idx, :]
        if EVALUATION_METRIC_REVERSE_SORT.get(metric_name, False):
            evaluations[:, idx, :] = \
                n_explainers - evaluations[:, idx, :] + 1

        metric_name_to_idx[metric_name] = idx
    # (explainers, data)
    scores: np.ndarray = evaluations.sum(axis=-2)

    for metric_name in EVALUATION_METRIC_SORT_PRIORITY:
        if metric_name not in metric_name_to_idx:
            continue

        idx = metric_name_to_idx[metric_name]
        scores = scores * n_explainers + evaluations[:, idx, :]

    return scores.argsort(axis=-2).argsort(axis=-2).tolist()

get_explanations_flattened(data_ids=None)

Retrieve and flatten explanations from all explainers.

Parameters:

    data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process. Default: None.

Returns:

    Sequence[Sequence[Tensor]]: Flattened explanations from all explainers.

This method retrieves flattened explanations for each explainer using the manager's get_flat_explanations method.

Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.

Source code in pnpxai/core/experiment/experiment.py
def get_explanations_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Sequence[Tensor]]:
    """
    Retrieve and flatten explanations from all explainers.

    Args:
        data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

    Returns:
        Flattened explanations from all explainers.

    This method retrieves flattened explanations for each explainer using the manager's `get_flat_explanations` method.

    Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
    """
    _, explainer_ids = self.manager.get_explainers()
    return [
        self.manager.get_flat_explanations(explainer_id, data_ids)
        for explainer_id in explainer_ids
    ]

get_inputs_flattened(data_ids=None)

Retrieve and flatten last run input data.

Parameters:

    data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process. Default: None.

Returns:

    Sequence[Tensor]: Flattened input data.

This method retrieves input data using the input extractor and flattens it for further processing.

Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
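
The flattened getters share the same calling pattern; a sketch assuming the experiment has already been run on the requested IDs:

inputs = experiment.get_inputs_flattened(data_ids=[0, 1])
targets = experiment.get_targets_flattened(data_ids=[0, 1])
outputs = experiment.get_outputs_flattened(data_ids=[0, 1])
# each returns a flat sequence of per-datum values for the requested data IDs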

Source code in pnpxai/core/experiment/experiment.py
def get_inputs_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
    """
    Retrieve and flatten last run input data.

    Args:
        data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

    Returns:
        Flattened input data.

    This method retrieves input data using the input extractor and flattens it for further processing.

    Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
    """
    data, _ = self.manager.get_data(data_ids)
    data = [self.input_extractor(datum) for datum in data]
    return self.manager.flatten_if_batched(data, data)

get_labels_flattened(data_ids=None)

Retrieve and flatten labels data.

Parameters:

    data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process. Default: None.

Returns:

    Sequence[Tensor]: Flattened label data.

This method retrieves label data using the label extractor and flattens it for further processing.

Source code in pnpxai/core/experiment/experiment.py
def get_labels_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
    """
    Retrieve and flatten labels data.

    Args:
        data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

    Returns:
        Flattened labels data.

    This method retrieves label data using the label extractor and flattens it for further processing.
    """
    data, _ = self.manager.get_data(data_ids)
    labels = [self.label_extractor(datum) for datum in data]
    return self.manager.flatten_if_batched(labels, data)

get_outputs_flattened(data_ids=None)

Retrieve and flatten model outputs.

Parameters:

    data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process. Default: None.

Returns:

    Sequence[Tensor]: Flattened model outputs.

This method retrieves flattened model outputs using the manager's get_flat_outputs method.

Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.

Source code in pnpxai/core/experiment/experiment.py
def get_outputs_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
    """
    Retrieve and flatten model outputs.

    Args:
        data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

    Returns:
        Flattened model outputs.

    This method retrieves flattened model outputs using the manager's get_flat_outputs method.

    Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
    """
    return self.manager.get_flat_outputs(data_ids)

get_targets_flattened(data_ids=None)

Retrieve and flatten target data.

Parameters:

    data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process. Default: None.

Returns:

    Sequence[Tensor]: Flattened target data.

This method retrieves target data using the target extractor and flattens it for further processing.

Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.

Source code in pnpxai/core/experiment/experiment.py
def get_targets_flattened(self, data_ids: Optional[Sequence[int]] = None) -> Sequence[Tensor]:
    """
    Retrieve and flatten target data.

    Args:
        data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.

    Returns:
        Flattened target data.

    This method retrieves target data using the target extractor and flattens it for further processing.

    Note: The input parameters allow for flexibility in specifying subsets of data to process. If not provided, the method processes all available data.
    """
    if self.target_labels:
        return self.get_labels_flattened(data_ids)
    data, _ = self.manager.get_data(data_ids)
    targets = [self._get_targets(data_ids)]
    return self.manager.flatten_if_batched(targets, data)
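
Example (illustrative sketch): the four flattened getters are typically read back together after a run. Here `experiment` is assumed to be an already-constructed Experiment on which run() has been executed, and the data IDs are placeholders.

data_ids = [0, 1, 2]  # hypothetical subset of data IDs

inputs = experiment.get_inputs_flattened(data_ids)    # per-sample inputs
labels = experiment.get_labels_flattened(data_ids)    # per-sample labels
outputs = experiment.get_outputs_flattened(data_ids)  # per-sample model outputs
targets = experiment.get_targets_flattened(data_ids)  # same as labels when target_labels=True

# The sequences are aligned element-wise, so they can be iterated per sample.
for x, y, out in zip(inputs, labels, outputs):
    pass  # e.g. compare out.argmax() against the label y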

optimize(data_ids, explainer_id, metric_id, direction='maximize', sampler='tpe', n_trials=None, timeout=None, **kwargs)

Optimize the hyperparameters of a selected explainer by generating explanations for the selected data and scoring them with a selected metric over repeated trials.

Parameters:

Name Type Description Default
data_ids Union[int, Sequence[int]]

A single data ID or sequence of data IDs to specify the subset of data to process.

required
explainer_id int

An explainer ID to specify the explainer to use.

required
metric_id int

A metric ID to evaluate optimizer decisions.

required
direction Literal['minimize', 'maximize']

A string to specify the direction of optimization.

'maximize'
sampler Literal['grid', 'random', 'tpe']

A string to specify the sampler to use for optimization.

'tpe'
n_trials Optional[int]

An integer to specify the number of trials for optimization. If not provided, optimization runs until timeout; if timeout is also not provided, a sampler-specific default number of trials is used.

None
timeout Optional[float]

A float to specify the timeout for optimization, in seconds. If both n_trials and timeout are given, optimization stops when either limit is reached.

None

Returns:

Type Description
OptimizationOutput

The OptimizationOutput bundling the best explainer, the best postprocessor, and the optuna study.

This method batches the selected data, builds an optimization objective from the chosen explainer, postprocessor, metric, and modality, and runs an optuna study over the explainer's hyperparameters. The best trial's explainer and postprocessor are returned together with the study.

Note: data_ids may be a single ID or a sequence of IDs; a single ID is treated as a one-element batch.

Source code in pnpxai/core/experiment/experiment.py
def optimize(
    self,
    data_ids: Union[int, Sequence[int]],
    explainer_id: int,
    metric_id: int,
    direction: Literal['minimize', 'maximize'] = 'maximize',
    sampler: Literal['grid', 'random', 'tpe'] = 'tpe',
    n_trials: Optional[int] = None,
    timeout: Optional[float] = None,
    **kwargs,  # sampler kwargs
):
    """
    Optimize the hyperparameters of a selected explainer by generating explanations for the selected data and scoring them with a selected metric over repeated trials.

    Args:
        data_ids (Union[int, Sequence[int]]): A single data ID or sequence of data IDs to specify the subset of data to process.
        explainer_id (int): An explainer ID to specify the explainer to use.
        metric_id (int): A metric ID to evaluate optimizer decisions.
        direction (Literal['minimize', 'maximize']): A string to specify the direction of optimization.
        sampler (Literal['grid', 'random', 'tpe']): A string to specify the sampler to use for optimization.
        n_trials (Optional[int]): An integer to specify the number of trials for optimization. If not provided, optimization runs until `timeout`; if `timeout` is also not provided, a sampler-specific default number of trials is used.
        timeout (Optional[float]): A float to specify the timeout for optimization, in seconds. If both `n_trials` and `timeout` are given, optimization stops when either limit is reached.

    Returns:
        The OptimizationOutput bundling the best explainer, the best postprocessor, and the optuna study.

    This method batches the selected data, builds an optimization objective from the chosen explainer, postprocessor,
    metric, and modality, and runs an optuna study over the explainer's hyperparameters.

    Note: `data_ids` may be a single ID or a sequence of IDs; a single ID is treated as a one-element batch.
    """
    data_ids = [data_ids] if isinstance(data_ids, int) else data_ids
    data = self.manager.batch_data_by_ids(data_ids)
    explainer = self.manager.get_explainer_by_id(explainer_id)
    postprocessor = self.manager.get_postprocessor_by_id(
        0)  # sample postprocessor to ensure channel_dim
    metric = self.manager.get_metric_by_id(metric_id)

    objective = Objective(
        explainer=explainer,
        postprocessor=postprocessor,
        metric=metric,
        modality=self.modality,
        inputs=self.input_extractor(data),
        targets=self._get_targets(data_ids),
    )
    # TODO: grid search
    if timeout is None:
        n_trials = n_trials or get_default_n_trials(sampler)

    # optimize
    study = optuna.create_study(
        sampler=load_sampler(sampler, **kwargs),
        direction=direction,
    )
    study.optimize(
        objective,
        n_trials=n_trials,
        timeout=timeout,
        n_jobs=1,
    )
    opt_explainer = study.best_trial.user_attrs['explainer']
    opt_postprocessor = study.best_trial.user_attrs['postprocessor']
    return OptimizationOutput(
        explainer=opt_explainer,
        postprocessor=opt_postprocessor,
        study=study,
    )
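
Example (illustrative sketch): a typical call to optimize and how the returned OptimizationOutput can be used. `experiment` is assumed to be an already-constructed Experiment; the IDs and trial budget are placeholders.

opt = experiment.optimize(
    data_ids=[0, 1],        # subset of data to optimize on
    explainer_id=0,         # explainer whose hyperparameters are tuned
    metric_id=0,            # metric that scores each trial
    direction='maximize',
    sampler='tpe',
    n_trials=20,            # alternatively pass timeout=... (seconds)
)

best_explainer = opt.explainer          # tuned explainer from the best trial
best_postprocessor = opt.postprocessor  # matching postprocessor
print(opt.study.best_trial.value)       # best metric value observed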

postprocess_batch(data_ids, explainer_id, postprocessor_id)

Postprocesses the explanations for a selected batch of data within the experiment.

Parameters:

Name Type Description Default
data_ids Sequence[int]

A sequence of data IDs to specify the subset of data to postprocess.

required
explainer_id int

An explainer ID to specify the explainer to use.

required
postprocessor_id int

A postprocessor ID to specify the postprocessor to use.

required

Returns:

Type Description

Batched postprocessed model explanations corresponding to data ids.

This method retrieves the cached explanations for the given data IDs, applies the selected postprocessor to each modality, and returns the postprocessed batch to the caller.

Source code in pnpxai/core/experiment/experiment.py
def postprocess_batch(
    self,
    data_ids: List[int],
    explainer_id: int,
    postprocessor_id: int,
):
    """
    Postprocesses the explanations for a selected batch of data within the experiment.

    Args:
        data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to postprocess.
        explainer_id (int): An explainer ID to specify the explainer to use.
        postprocessor_id (int): A postprocessor ID to specify the postprocessor to use.

    Returns:
        Batched postprocessed model explanations corresponding to data ids.

    This method retrieves the cached explanations for the given data IDs, applies the selected postprocessor
    to each modality, and returns the postprocessed batch to the caller.
    """
    explanations = self.manager.batch_explanations_by_ids(
        data_ids, explainer_id)
    postprocessor = self.manager.get_postprocessor_by_id(postprocessor_id)

    modalities = format_into_tuple(self.modality)
    explanations = format_into_tuple(explanations)
    postprocessors = format_into_tuple(postprocessor)

    batch = []
    explainer = self.manager.get_explainer_by_id(explainer_id)
    for mod, attr, pp in zip(modalities, explanations, postprocessors):
        if (
            isinstance(explainer, (Lime, KernelShap))
            and isinstance(mod, TextModality)
            and not isinstance(pp.pooling_fn, Identity)
        ):
            raise ValueError(f'postprocessor {postprocessor_id} does not support explainer {explainer_id}.')
        batch.append(pp(attr))
    return format_out_tuple_if_single(batch)
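
Example (illustrative sketch): applying a postprocessor to explanations that were already produced for a batch, e.g. via explain_batch or run_batch. `experiment` and the IDs are placeholders.

postprocessed = experiment.postprocess_batch(
    data_ids=[0, 1],
    explainer_id=0,
    postprocessor_id=0,
)
# For a single-modality experiment this is one batched attribution tensor;
# for multi-modal setups a tuple with one entry per modality is returned.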

predict_batch(data_ids)

Predicts model outputs for a selected batch of data.

Parameters:

Name Type Description Default
data_ids Sequence[int]

A sequence of data IDs to specify the subset of data to process.

required

Returns:

Type Description

Batched model outputs corresponding to data ids.

This method runs the model only on the data IDs whose outputs are not yet cached, stores the new outputs in the manager, and returns the batched outputs for all requested IDs.

Source code in pnpxai/core/experiment/experiment.py
def predict_batch(
    self,
    data_ids: Sequence[int],
):
    """
    Predicts model outputs for a selected batch of data.

    Args:
        data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to process.

    Returns:
        Batched model outputs corresponding to data ids.

    This method runs the model only on the data IDs whose outputs are not yet cached, stores the new outputs in the manager, and returns the batched outputs for all requested IDs.
    """
    data_ids_pred = [
        idx for idx in data_ids
        if self.manager.get_output_by_id(idx) is None
    ]
    if len(data_ids_pred) > 0:
        data = self.manager.batch_data_by_ids(data_ids_pred)
        outputs = self.model(
            *format_into_tuple(self.input_extractor(data)))
        self.manager.cache_outputs(data_ids_pred, outputs)
    return self.manager.batch_outputs_by_ids(data_ids)
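
Example (illustrative sketch): `experiment` is assumed to be an already-constructed Experiment. Outputs are computed only for IDs that are not yet cached, so repeated calls are cheap.

outputs = experiment.predict_batch(data_ids=[0, 1, 2, 3])   # runs the model and caches outputs
outputs_again = experiment.predict_batch(data_ids=[2, 3])   # served from the cache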

run(data_ids=None, explainer_ids=None, postprocessor_ids=None, metric_ids=None)

Run the experiment by processing data, generating explanations, evaluating with metrics, caching and retrieving the data.

Parameters:

Name Type Description Default
data_ids Optional[Sequence[int]]

A sequence of data IDs to specify the subset of data to process.

None
explainer_ids Optional[Sequence[int]]

A sequence of explainer IDs to specify the subset of explainers to use.

None
postprocessor_ids Optional[Sequence[int]]

A sequence of postprocessor IDs to specify the subset of postprocessors to use.

None
metric_ids Optional[Sequence[int]]

A sequence of metric IDs to specify the subset of metrics to evaluate.

None

Returns:

Type Description
Experiment

The Experiment instance with updated results and state.

This method orchestrates the experiment by configuring the manager, obtaining explainer and metric instances, processing data, generating explanations, and evaluating metrics. It then saves the results in the manager.

Note: The input parameters allow for flexibility in specifying subsets of data, explainers, and metrics to process. If not provided, the method processes all available data, explainers, postprocessors, and metrics.

Source code in pnpxai/core/experiment/experiment.py
def run(
    self,
    data_ids: Optional[Sequence[int]] = None,
    explainer_ids: Optional[Sequence[int]] = None,
    postprocessor_ids: Optional[Sequence[int]] = None,
    metric_ids: Optional[Sequence[int]] = None,
) -> 'Experiment':
    """
    Run the experiment by processing data, generating explanations, evaluating with metrics, caching and retrieving the data.

    Args:
        data_ids (Optional[Sequence[int]]): A sequence of data IDs to specify the subset of data to process.
        explainer_ids (Optional[Sequence[int]]): A sequence of explainer IDs to specify the subset of explainers to use.
        postprocessor_ids (Optional[Sequence[int]]): A sequence of postprocessor IDs to specify the subset of postprocessors to use.
        metric_ids (Optional[Sequence[int]]): A sequence of metric IDs to specify the subset of metrics to evaluate.

    Returns:
        The Experiment instance with updated results and state.

    This method orchestrates the experiment by configuring the manager, obtaining explainer and metric instances,
    processing data, generating explanations, and evaluating metrics. It then saves the results in the manager.

    Note: The input parameters allow for flexibility in specifying subsets of data, explainers, and metrics to process.
    If not provided, the method processes all available data, explainers, postprocessors, and metrics.
    """
    self.reset_errors()
    # self.manager.set_config(data_ids, explainer_ids, metrics_ids)

    # inference

    # data_ids_pred excludes data indices whose outputs are already cached
    data, data_ids_pred = self.manager.get_data_to_predict(data_ids)
    outputs = self._predict(data)
    self.manager.save_outputs(outputs, data, data_ids_pred)

    # explain
    explainers, explainer_ids = self.manager.get_explainers(explainer_ids)
    postprocessors, postprocessor_ids = self.manager.get_postprocessors(postprocessor_ids)
    metrics, metric_ids = self.manager.get_metrics(metric_ids)

    for explainer, explainer_id in zip(explainers, explainer_ids):
        explainer_name = class_to_string(explainer)

        # data_ids_expl excludes data indices whose explanations are already cached
        data, data_ids_expl = self.manager.get_data_to_process_for_explainer(
            explainer_id, data_ids)
        explanations = self._explain(data, data_ids_expl, explainer)
        self.manager.save_explanations(
            explanations, data, data_ids_expl, explainer_id
        )
        message = get_message(
            'experiment.event.explainer', explainer=explainer_name
        )
        print(f"[Experiment] {message}")
        self.fire(ExperimentObservableEvent(
            self.manager, message, explainer))

        for postprocessor, postprocessor_id in zip(postprocessors, postprocessor_ids):
            for metric, metric_id in zip(metrics, metric_ids):
                metric_name = class_to_string(metric)
                data, data_ids_eval = self.manager.get_data_to_process_for_metric(
                    explainer_id, postprocessor_id, metric_id, data_ids)
                explanations, data_ids_eval = self.manager.get_valid_explanations(
                    explainer_id, data_ids_eval)
                data, _ = self.manager.get_data(data_ids_eval)
                evaluations = self._evaluate(
                    data, data_ids_eval, explanations, explainer, postprocessor, metric)
                self.manager.save_evaluations(
                    evaluations, data, data_ids_eval, explainer_id, postprocessor_id, metric_id)

                message = get_message(
                    'experiment.event.explainer.metric', explainer=explainer_name, metric=metric_name)
                print(f"[Experiment] {message}")
                self.fire(ExperimentObservableEvent(
                    self.manager, message, explainer, metric))
    return self
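
Example (illustrative sketch): a full sweep over one explainer/postprocessor/metric combination. `experiment` and the IDs are placeholders; omitting an argument processes everything registered for that role.

experiment.run(
    data_ids=[0, 1, 2, 3],
    explainer_ids=[0],
    postprocessor_ids=[0],
    metric_ids=[0],
)
# run() returns the Experiment itself, so results can be read back afterwards,
# for example via the flattened getters documented above.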

run_batch(data_ids, explainer_id, postprocessor_id, metric_id)

Runs the experiment for a selected batch of data with a selected explainer, postprocessor, and metric.

Parameters:

Name Type Description Default
data_ids Sequence[int]

A sequence of data IDs to specify the subset of data to process.

required
explainer_id int

ID of explainer to use for the run.

required
postprocessor_id int

ID of postprocessor to use for the run.

required
metric_id int

ID of metric to use for the run.

required

Returns:

Type Description
dict

The dictionary of inputs, labels, outputs, targets, explainer, explanation, postprocessor, postprocessed, metric, and evaluation.

This method orchestrates the experiment by configuring the manager, obtaining explainer and metric instances, processing data, generating explanations, and evaluating metrics. It then caches the results in the manager and returns them to the user.

Note: The input parameters allow for flexibility in specifying subset of data, explainer, postprocessor and metric to process.

Source code in pnpxai/core/experiment/experiment.py
def run_batch(
    self,
    data_ids: Sequence[int],
    explainer_id: int,
    postprocessor_id: int,
    metric_id: int,
) -> dict:
    """
    Runs the experiment for a selected batch of data with a selected explainer, postprocessor, and metric.

    Args:
        data_ids (Sequence[int]): A sequence of data IDs to specify the subset of data to process.
        explainer_id (int): ID of explainer to use for the run.
        postprocessor_id (int): ID of postprocessor to use for the run.
        metric_id (int): ID of the metric to use for the run.

    Returns:
        The dictionary of inputs, labels, outputs, targets, explainer, explanation, postprocessor, postprocessed, metric, and evaluation.

    This method orchestrates the experiment by configuring the manager, obtaining explainer and metric instances,
    processing data, generating explanations, and evaluating metrics. It then caches the results in the manager and returns them to the user.

    Note: The input parameters allow for flexibility in specifying subset of data, explainer, postprocessor and metric to process.
    """

    self.predict_batch(data_ids)
    self.explain_batch(data_ids, explainer_id)
    self.evaluate_batch(
        data_ids, explainer_id, postprocessor_id, metric_id)
    data = self.manager.batch_data_by_ids(data_ids)
    return {
        'inputs': self.input_extractor(data),
        'labels': self.label_extractor(data),
        'outputs': self.manager.batch_outputs_by_ids(data_ids),
        'targets': self._get_targets(data_ids),
        'explainer': self.manager.get_explainer_by_id(explainer_id),
        'explanation': self.manager.batch_explanations_by_ids(data_ids, explainer_id),
        'postprocessor': self.manager.get_postprocessor_by_id(postprocessor_id),
        'postprocessed': self.postprocess_batch(data_ids, explainer_id, postprocessor_id),
        'metric': self.manager.get_metric_by_id(metric_id),
        'evaluation': self.manager.batch_evaluations_by_ids(data_ids, explainer_id, postprocessor_id, metric_id),
    }
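
Example (illustrative sketch): `experiment` and the IDs are placeholders. run_batch chains predict_batch, explain_batch, evaluate_batch, and postprocess_batch, and returns every intermediate artifact in one dictionary.

record = experiment.run_batch(
    data_ids=[0, 1],
    explainer_id=0,
    postprocessor_id=0,
    metric_id=0,
)
attributions = record['postprocessed']  # postprocessed explanations for the batch
evaluation = record['evaluation']       # metric scores for the batch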