# Source code for attribench.functional.metrics._max_sensitivity

from attribench.result import MaxSensitivityResult
from attribench.result._grouped_batch_result import GroupedBatchResult
from typing import Dict
from attribench.data import AttributionsDataset
from attribench.data.attributions_dataset._attributions_dataset import (
    GroupedAttributionsDataset,
)
import torch
from attribench._attribution_method import AttributionMethod
from torch.utils.data import DataLoader
import math


def _normalize_attrs(attrs):
    flattened = attrs.flatten(1)
    return flattened / torch.norm(flattened, dim=1, p=math.inf, keepdim=True)


def _max_sensitivity_batch(
    batch_x: torch.Tensor,
    batch_y: torch.Tensor,
    batch_attr: Dict[str, torch.Tensor],
    method_dict: Dict[str, AttributionMethod],
    num_perturbations: int,
    radius: float,
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """Compute Max-Sensitivity for a single batch, for every method.

    For each attribution method, the attributions of the original samples
    are compared against the attributions of ``num_perturbations`` noisy
    copies of those samples. The largest norm of the difference between the
    normalized attributions is reported per sample.

    Parameters
    ----------
    batch_x : torch.Tensor
        Batch of input samples.
    batch_y : torch.Tensor
        Batch of targets, passed to each attribution method.
    batch_attr : Dict[str, torch.Tensor]
        Precomputed attributions per method. Currently only the keys are
        used (to validate them against ``method_dict``); the reference
        attributions are recomputed from ``batch_x``.
    method_dict : Dict[str, AttributionMethod]
        Attribution methods to evaluate, keyed by name.
    num_perturbations : int
        Number of noisy copies to generate per sample. Must be at least 1.
    radius : float
        Noise is drawn uniformly from ``[-radius, radius)`` per element.
    device : torch.device
        Device on which the attribution methods are run.

    Returns
    -------
    Dict[str, torch.Tensor]
        For each method name, a CPU tensor of shape ``[batch_size]``.

    Raises
    ------
    ValueError
        If ``method_dict`` and ``batch_attr`` do not have the same keys, or
        if ``num_perturbations`` is smaller than 1.
    """
    if set(method_dict) != set(batch_attr):
        raise ValueError(
            "Method dictionary and batch attributions dictionary"
            " must have the same keys."
            f" Got method keys {list(method_dict)}"
            f" and attribution keys {list(batch_attr)}."
        )
    if num_perturbations < 1:
        # Without any perturbation there is nothing to stack or maximize.
        raise ValueError("num_perturbations must be at least 1.")

    result: Dict[str, torch.Tensor] = {}
    batch_x = batch_x.to(device)
    batch_y = batch_y.to(device)

    # Compute Max-Sensitivity for each method
    for method_name, method in method_dict.items():
        # TODO we are not taking into account the aggregation here
        # TODO also the AttributionsDataset is not being used
        # Reference attributions are kept on the CPU; every perturbed result
        # is moved there as well before it is compared.
        attrs = _normalize_attrs(method(batch_x, batch_y).detach()).cpu()
        diffs = []

        for _ in range(num_perturbations):
            # Add uniform noise with infinity norm <= radius
            # torch.rand generates noise between 0 and 1
            # => This generates noise between -radius and radius
            noise = (
                torch.rand(batch_x.shape, device=device) * 2 * radius - radius
            )
            noisy_samples = batch_x + noise
            # Get new attributions from noisy samples
            noisy_attrs = _normalize_attrs(
                method(noisy_samples, batch_y).detach()
            )
            # Norm of the attribution difference
            # [batch_size]
            diffs.append(torch.norm(noisy_attrs.cpu() - attrs, dim=1))

        # Stack to [batch_size, num_perturbations], then keep the largest
        # difference per sample: [batch_size]
        result[method_name] = torch.stack(diffs, 1).max(dim=1)[0]
    return result


def max_sensitivity(
    attributions_dataset: AttributionsDataset,
    batch_size: int,
    method_dict: Dict[str, AttributionMethod],
    num_perturbations: int,
    radius: float,
    device: torch.device = torch.device("cpu"),
) -> MaxSensitivityResult:
    """Computes the Max-Sensitivity metric for a given `AttributionsDataset`
    and attribution methods.

    Max-Sensitivity is computed by adding a small amount of uniform noise to
    the input samples and computing the norm of the difference in
    attributions between the original samples and the noisy samples. The
    maximum norm of difference is then taken as the Max-Sensitivity.

    The idea is that a small amount of noise should not change the
    attributions significantly, so the norm of the difference should be
    small. If the norm is large, then the attributions are not robust to
    small perturbations in the input.

    Parameters
    ----------
    attributions_dataset : AttributionsDataset
        The dataset for which the Max-Sensitivity should be computed.
    batch_size : int
        The batch size to use for computing the Max-Sensitivity.
    method_dict : Dict[str, AttributionMethod]
        Dictionary of attribution methods for which the Max-Sensitivity
        should be computed.
    num_perturbations : int
        The number of perturbations to use for computing the
        Max-Sensitivity.
    radius : float
        The radius of the uniform noise to add to the input samples.
    device : torch.device, optional
        Device to use, by default `torch.device("cpu")`.

    Returns
    -------
    MaxSensitivityResult
        Result object created for the names in ``method_dict`` and the
        number of samples in the grouped dataset, to which the result of
        every batch has been added.
    """
    grouped_dataset = GroupedAttributionsDataset(attributions_dataset)
    dataloader = DataLoader(
        grouped_dataset,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
    )

    method_names = list(method_dict.keys())
    result = MaxSensitivityResult(
        method_names,
        num_samples=len(grouped_dataset),
    )

    for batch_indices, batch_x, batch_y, batch_attr in dataloader:
        # No explicit transfer here: _max_sensitivity_batch moves batch_x
        # and batch_y to `device` itself.
        batch_result = _max_sensitivity_batch(
            batch_x,
            batch_y,
            batch_attr,
            method_dict,
            num_perturbations,
            radius,
            device,
        )
        result.add(GroupedBatchResult(batch_indices, batch_result))
    return result