Source code for attribench.functional.metrics.sensitivity_n._sensitivity_n

import numpy as np
from torch import nn
import torch
import numpy.typing as npt
from typing import Callable, List, Mapping, Dict, Tuple
from attribench.masking import Masker
from attribench.masking.image import ImageMasker
from torch.utils.data import DataLoader
from attribench.data import AttributionsDataset
from attribench._activation_fns import ACTIVATION_FNS
from ._dataset import SensitivityNDataset, SegSensNDataset
from attribench._segmentation import segment_attributions
from attribench._stat import rowwise_pearsonr
from attribench.result import SensitivityNResult
from attribench.result._grouped_batch_result import GroupedBatchResult
from attribench.data.attributions_dataset._attributions_dataset import (
    GroupedAttributionsDataset,
)


def _get_orig_output(
    samples: torch.Tensor, model: Callable, activation_fns: List[str]
):
    activated_orig_output = {}
    with torch.no_grad():
        orig_output = model(samples)
        for activation_fn in activation_fns:
            activated_orig_output[activation_fn] = ACTIVATION_FNS[
                activation_fn
            ](orig_output)
    return activated_orig_output


def _compute_out_diffs(
    model: Callable,
    ds: SensitivityNDataset | SegSensNDataset,
    activation_fns: List[str],
    orig_output: Dict[str, torch.Tensor],
    labels: torch.Tensor,
) -> Tuple[Dict[str, Dict[int, npt.NDArray]], Dict[int, npt.NDArray]]:
    """Measure how much the labelled output changes for every masked subset.

    The returned arrays do not depend on any attribution method, so they
    are computed once and re-used for all methods.

    Parameters
    ----------
    model : Callable
        Model evaluated on each perturbed batch (gradients disabled).
    ds : SensitivityNDataset | SegSensNDataset
        Pseudo-dataset of perturbed batches. Item ``i`` unpacks into
        ``(batch, indices, n, subset_idx)``.
    activation_fns : List[str]
        Names of the activation functions (keys of ``ACTIVATION_FNS``).
    orig_output : Dict[str, torch.Tensor]
        Activated model output on the unmasked samples, per activation
        function.
    labels : torch.Tensor
        Target index per sample, used to pick one column of the output.

    Returns
    -------
    output_diffs : Dict[str, Dict[int, npt.NDArray]]
        ``activation_fn -> n -> [batch_size, num_subsets]`` array holding
        ``original - perturbed`` output at the labelled index.
    removed_indices : Dict[int, npt.NDArray]
        ``n -> [batch_size, num_subsets, n]`` integer array holding the
        indices returned by the dataset for each subset.
    """
    subset_sizes = ds.n_range
    batch_size = ds.samples.shape[0]
    num_subsets = ds.num_subsets

    # activation_fn -> n -> [batch_size, num_subsets]
    output_diffs: Dict[str, Dict[int, npt.NDArray]] = {}
    for fn_name in activation_fns:
        output_diffs[fn_name] = {
            size: np.zeros((batch_size, num_subsets)) for size in subset_sizes
        }

    # n -> [batch_size, num_subsets, n]
    removed_indices: Dict[int, npt.NDArray] = {}
    for size in subset_sizes:
        removed_indices[size] = np.zeros(
            (batch_size, num_subsets, size), dtype=int
        )

    # Column selector for gather; identical for every perturbed batch.
    label_index = labels.unsqueeze(-1)

    # TODO why do we not use a dataloader here?
    for item_idx in range(len(ds)):
        masked_batch, masked_indices, subset_size, subset_idx = ds[item_idx]
        subset_size = subset_size.item()
        with torch.no_grad():
            masked_output = model(masked_batch)

        for fn_name in activation_fns:
            activated = ACTIVATION_FNS[fn_name](masked_output)
            drop = orig_output[fn_name] - activated
            # Keep only the labelled column, flattened to [batch_size].
            label_drop = drop.gather(dim=1, index=label_index).flatten()
            output_diffs[fn_name][subset_size][:, subset_idx] = (
                label_drop.detach().cpu().numpy()
            )

        # [batch_size, n]
        removed_indices[subset_size][:, subset_idx, :] = masked_indices

    return output_diffs, removed_indices


def _compute_correlations(
    method_names: List[str],
    batch_attr: Dict[str, torch.Tensor],
    ds: SensitivityNDataset | SegSensNDataset,
    segmented: bool,
    removed_indices: Dict[int, npt.NDArray],
    output_diffs: Dict[str, Dict[int, npt.NDArray]],
    activation_fns: List[str],
) -> Dict[str, Dict[str, torch.Tensor]]:
    """Correlate summed attributions with output differences per method.

    For every attribution method and every subset size ``n``, the
    attribution values at the removed indices are summed per subset and
    passed, together with the matching output differences, to
    ``rowwise_pearsonr``.

    Parameters
    ----------
    method_names : List[str]
        Attribution methods to evaluate; keys of ``batch_attr``.
    batch_attr : Dict[str, torch.Tensor]
        Attributions for the current batch, per method.
    ds : SensitivityNDataset | SegSensNDataset
        Dataset that produced ``removed_indices``; provides ``n_range``,
        ``samples`` and, when ``segmented``, ``segmented_images``.
    segmented : bool
        If True, attributions are first passed through
        ``segment_attributions`` using ``ds.segmented_images``.
    removed_indices : Dict[int, npt.NDArray]
        ``n -> [batch_size, num_subsets, n]`` indices of removed features.
    output_diffs : Dict[str, Dict[int, npt.NDArray]]
        ``activation_fn -> n -> [batch_size, num_subsets]`` output
        differences.
    activation_fns : List[str]
        Activation function names; keys of ``output_diffs``.

    Returns
    -------
    Dict[str, Dict[str, torch.Tensor]]
        ``activation_fn -> method_name -> [batch_size, len(n_range)]``
        tensor of correlations.
    """
    num_samples = ds.samples.shape[0]
    # activation_fn -> method_name -> [batch_size, len(n_range)]
    result = {
        activation_fn: {
            method_name: torch.zeros((num_samples, len(ds.n_range)))
            for method_name in method_names
        }
        for activation_fn in activation_fns
    }
    # Compute correlations for all methods
    # TODO can we use joblib here to compute this in parallel?
    for method_name in method_names:
        attrs = batch_attr[method_name].cpu().numpy()
        if segmented:
            # Narrow the type: only SegSensNDataset has segmented_images.
            assert isinstance(ds, SegSensNDataset)
            attrs = segment_attributions(
                ds.segmented_images.cpu().numpy(), attrs
            )
        # [batch_size, 1, -1]: the singleton axis broadcasts over subsets.
        attrs = attrs.reshape((attrs.shape[0], 1, -1))
        for n_idx, n in enumerate(ds.n_range):
            # [batch_size, num_subsets, n]
            n_mask_attrs = np.take_along_axis(
                attrs, axis=-1, indices=removed_indices[n]
            )
            # Sum of attributions per subset: [batch_size, num_subsets].
            # Independent of the activation function, so it is computed
            # once here instead of once per activation function.
            n_sum_of_attrs = n_mask_attrs.sum(axis=-1)
            for activation_fn in activation_fns:
                n_output_diffs = output_diffs[activation_fn][n]
                # Correlation between output difference and sum of
                # attribution values.
                result[activation_fn][method_name][:, n_idx] = torch.tensor(
                    rowwise_pearsonr(n_sum_of_attrs, n_output_diffs)
                )
    return result


def _sens_n_batch(
    samples: torch.Tensor,
    labels: torch.Tensor,
    model: Callable,
    attrs: Dict[str, torch.Tensor],
    maskers: Mapping[str, Masker],
    activation_fns: List[str],
    n_range: npt.NDArray,
    num_subsets: int,
    segmented: bool,
) -> Dict[str, Dict[str, Dict[str, torch.Tensor]]]:
    """Compute (Seg-)Sensitivity-n for one batch, for every masker.

    Parameters
    ----------
    samples : torch.Tensor
        Input batch.
    labels : torch.Tensor
        Target index for each sample.
    model : Callable
        Model under evaluation.
    attrs : Dict[str, torch.Tensor]
        Attributions for the batch, keyed by method name.
    maskers : Mapping[str, Masker]
        Maskers to evaluate, keyed by name. When ``segmented`` is True
        each masker must be an ``ImageMasker``.
    activation_fns : List[str]
        Names of the activation functions to apply to the model output.
    n_range : npt.NDArray
        Subset sizes to evaluate.
    num_subsets : int
        Number of random subsets per subset size.
    segmented : bool
        Selects ``SegSensNDataset`` instead of ``SensitivityNDataset``.

    Returns
    -------
    Dict[str, Dict[str, Dict[str, torch.Tensor]]]
        ``masker_name -> activation_fn -> method_name -> tensor`` of
        correlations, as produced by ``_compute_correlations``.
    """
    names = list(attrs.keys())
    # The unmasked output is the same for every masker: compute it once.
    baseline_output = _get_orig_output(samples, model, activation_fns)

    # masker_name -> activation_fn -> method_name -> [batch_size, num_steps]
    per_masker: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = {}
    for masker_name, masker in maskers.items():
        # Pseudo-dataset that generates the perturbed samples.
        if not segmented:
            perturbed_ds = SensitivityNDataset(
                n_range, num_subsets, samples, masker
            )
        else:
            perturbed_ds = SegSensNDataset(n_range, num_subsets, samples)
            assert isinstance(masker, ImageMasker)
            perturbed_ds.set_masker(masker)

        diffs, indices = _compute_out_diffs(
            model, perturbed_ds, activation_fns, baseline_output, labels
        )
        per_masker[masker_name] = _compute_correlations(
            method_names=names,
            batch_attr=attrs,
            ds=perturbed_ds,
            segmented=segmented,
            removed_indices=indices,
            output_diffs=diffs,
            activation_fns=activation_fns,
        )
    return per_masker


def sensitivity_n(
    model: nn.Module,
    attributions_dataset: AttributionsDataset,
    batch_size: int,
    maskers: Mapping[str, Masker],
    activation_fns: str | List[str],
    min_subset_size: float,
    max_subset_size: float,
    num_steps: int,
    num_subsets: int,
    segmented: bool,
    device: torch.device = torch.device("cpu"),
) -> SensitivityNResult:
    """Computes the Sensitivity-n metric for a given
    :class:`~attribench.data.AttributionsDataset` and model.

    Sensitivity-n is computed by iteratively masking a random subset of `n`
    features of the input samples and computing the output of the model on
    the masked samples. For each random subset of masked features, the sum
    of the attributions is also computed. This results in two series of
    values: the change in model output and the sum of the attributions.
    The Sensitivity-n metric is the correlation between these two series.

    This is repeated for different values of `n` between `min_subset_size`
    and `max_subset_size` in `num_steps` steps. `min_subset_size` and
    `max_subset_size` are fractions that are scaled to obtain `n` (see the
    parameter descriptions). For each value of `n`, `num_subsets` random
    subsets are generated.

    If segmented is True, then the Seg-Sensitivity-n metric is computed.
    This metric is analogous to Sensitivity-n, but instead of using random
    subsets of features, the images are first segmented into superpixels
    and then random subsets of superpixels are masked. This improves the
    signal-to-noise ratio of the metric for high-resolution images.

    The Sensitivity-n metric is computed for each masker in `maskers` and
    for each activation function in `activation_fns`.

    Parameters
    ----------
    model : nn.Module
        Model to compute Sensitivity-n for. It is moved to `device` and
        put in evaluation mode.
    attributions_dataset : AttributionsDataset
        Dataset containing the attributions to compute Sensitivity-n on.
    batch_size : int
        Batch size to use when computing model output on masked samples.
    maskers : Mapping[str, Masker]
        Mapping of maskers to use. Keys are the names of the maskers.
    activation_fns : str | List[str]
        Activation functions to use. If a single string is passed, then
        it is converted to a single-element list.
    min_subset_size : float
        Minimum fraction of features to mask. Multiplied by the total
        number of features (or by 100 when `segmented` is True) and
        truncated to an integer to obtain `n`.
    max_subset_size : float
        Maximum fraction of features to mask; scaled like
        `min_subset_size`.
    num_steps : int
        Number of steps between `min_subset_size` and `max_subset_size`.
    num_subsets : int
        Number of random subsets to generate for each value of `n`.
    segmented : bool
        If True, then the Seg-Sensitivity-n metric is computed.
    device : torch.device, optional
        Device to use, by default torch.device("cpu")

    Returns
    -------
    SensitivityNResult
    """
    if isinstance(activation_fns, str):
        activation_fns = [activation_fns]

    model.to(device)
    model.eval()

    grouped_dataset = GroupedAttributionsDataset(attributions_dataset)
    dataloader = DataLoader(
        grouped_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
    )

    result = SensitivityNResult(
        attributions_dataset.method_names,
        list(maskers.keys()),
        list(activation_fns),
        num_samples=attributions_dataset.num_samples,
        num_steps=num_steps,
    )

    # Compute range of subset sizes
    n_range = np.linspace(min_subset_size, max_subset_size, num_steps)
    if segmented:
        # NOTE(review): 100 is presumably the number of superpixels each
        # image is segmented into -- confirm against SegSensNDataset.
        n_range = n_range * 100
    else:
        total_num_features = np.prod(attributions_dataset.attributions_shape)
        n_range = n_range * total_num_features
    # Truncate towards zero to get integer subset sizes.
    n_range = n_range.astype(int)

    for (
        batch_indices,
        batch_x,
        batch_y,
        batch_attr,
    ) in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_result = _sens_n_batch(
            batch_x,
            batch_y,
            model,
            batch_attr,
            maskers,
            activation_fns,
            n_range,
            num_subsets,
            segmented,
        )
        result.add(GroupedBatchResult(batch_indices, batch_result))
    return result