Source code for attribench.functional.metrics.minimal_subset._minimal_subset

from torch.utils.data import DataLoader
from ._dataset import (
    MinimalSubsetDeletionDataset,
    MinimalSubsetInsertionDataset,
)
from attribench.masking import Masker
from typing import Callable, Dict
import torch
from torch import nn
from attribench.data import AttributionsDataset
from attribench.result import MinimalSubsetResult
from attribench.result._batch_result import BatchResult


def minimal_subset_batch(
    samples: torch.Tensor,
    model: Callable,
    attrs: torch.Tensor,
    num_steps: float,
    maskers: Dict[str, Masker],
    mode: str,
) -> Dict[str, torch.Tensor]:
    batch_result: Dict[str, torch.Tensor] = {}
    for masker_name, masker in maskers.items():
        if mode == "deletion":
            ds = MinimalSubsetDeletionDataset(
                num_steps, samples, attrs, masker
            )
            criterion_fn = lambda pred, orig: pred != orig
        elif mode == "insertion":
            ds = MinimalSubsetInsertionDataset(
                num_steps, samples, attrs, masker
            )
            criterion_fn = lambda pred, orig: pred == orig
        else:
            raise ValueError("Mode must be deletion or insertion. Got:", mode)

        # Initialize datastructures
        masker_result = torch.tensor(
            [-1 for _ in range(ds.samples.shape[0])]
        ).int()
        flipped = torch.tensor(
            [False for _ in range(ds.samples.shape[0])]
        ).bool()
        orig_predictions = torch.argmax(model(ds.samples), dim=1)

        # The MinimalSubsetDataset is an iterator that returns batches of
        # masked samples and the number of features that were masked.
        it = iter(ds)
        batch = next(it)
        while not torch.all(flipped) and batch is not None:
            masked_samples, mask_size = batch
            # Get output of model on masked samples
            with torch.no_grad():
                masked_output = model(masked_samples)
            predictions = torch.argmax(masked_output, dim=1)

            # Check which samples were flipped to either a different class
            # (deletion) or the original class (insertion)
            criterion = criterion_fn(predictions, orig_predictions)
            new_flipped = torch.logical_or(flipped, criterion.cpu())
            # Record which samples were flipped this iteration
            flipped_this_iteration = new_flipped != flipped
            masker_result[flipped_this_iteration] = mask_size
            flipped = new_flipped
            # Get next batch
            try:
                batch = next(it)
            except StopIteration:
                break

        # Set maximum value for samples that were never flipped
        num_inputs = attrs.reshape(attrs.shape[0], -1).shape[1]
        masker_result[masker_result == -1] = num_inputs
        batch_result[masker_name] = masker_result.reshape(-1, 1)
    return batch_result


[docs]def minimal_subset( model: nn.Module, attributions_dataset: AttributionsDataset, batch_size: int, maskers: Dict[str, Masker], mode: str = "deletion", num_steps: int = 100, device: torch.device = torch.device("cpu"), ) -> MinimalSubsetResult: """Computes Minimal Subset Deletion or Insertion for a given :class:`~attribench.data.AttributionsDataset` and model. Minimal Subset Deletion or Insertion is computed by iteratively masking (Deletion) or revealing (Insertion) the top features of the input samples and computing the prediction of the model on the masked samples. Minimal Subset Deletion is the minimal number of features that must be masked to change the model's prediction from its original prediction. Minimal Subset Insertion is the minimal number of features that must be revealed to get the model's original prediction. The Minimal Subset metric is computed for each masker in `maskers`. Parameters ---------- model : nn.Module Model to compute the Minimal Subset metric for. attributions_dataset : AttributionsDataset Dataset containing the samples and attributions to compute the Minimal Subset metric for. batch_size : int Batch size to use when computing the Minimal Subset metric. maskers : Dict[str, Masker] Dictionary mapping masker names to `Masker` objects. mode : str, optional "deletion" or "insertion", by default "deletion" num_steps : int, optional Number of steps to use when computing the Minimal Subset metric, by default 100. More steps will result in a more accurate metric, but will take longer to compute. device : torch.device, optional Device to use when computing the Minimal Subset metric, by default torch.device("cpu") Returns ------- MinimalSubsetResult """ model.to(device) model.eval() dataloader = DataLoader( attributions_dataset, batch_size=batch_size, num_workers=4, pin_memory=True ) result = MinimalSubsetResult( attributions_dataset.method_names, list(maskers.keys()), mode, num_samples=len(attributions_dataset), ) for ( batch_indices, batch_x, _, batch_attr, method_names, ) in dataloader: batch_x = batch_x.to(device) batch_result = minimal_subset_batch( batch_x, model, batch_attr, num_steps, maskers, mode ) result.add(BatchResult(batch_indices, batch_result, method_names)) return result