import torch
from torch import nn
from typing import List, Union, Mapping, Dict
from attribench.masking.image import ImageMasker
from attribench.functional.metrics.deletion._dataset import IrofDataset
from attribench.functional.metrics.deletion._get_predictions import (
get_predictions,
)
from attribench.data import AttributionsDataset
from torch.utils.data import DataLoader
from attribench.result._deletion_result import DeletionResult
from attribench.result._batch_result import BatchResult
def _irof_batch(
samples: torch.Tensor,
labels: torch.Tensor,
model: nn.Module,
attrs: torch.Tensor,
maskers: Mapping[str, ImageMasker],
activation_fns: List[str],
mode: str,
start: float,
stop: float,
num_steps: int,
) -> Dict:
result_dict = {}
for masker_name, masker in maskers.items():
masking_dataset = IrofDataset(
mode, start, stop, num_steps, samples, masker
)
masking_dataset.set_attrs(attrs)
result_dict[masker_name] = get_predictions(
masking_dataset, labels, model, activation_fns
)
return result_dict
[docs]def irof(
model: nn.Module,
attributions_dataset: AttributionsDataset,
batch_size: int,
maskers: Mapping[str, ImageMasker],
activation_fns: Union[List[str], str] = "linear",
mode: str = "morf",
start: float = 0.0,
stop: float = 1.0,
num_steps: int = 100,
device: torch.device = torch.device("cpu"),
):
"""Computes the IROF metric for a given :class:`~attribench.data.AttributionsDataset` and model.
IROF starts segmenting the input image using SLIC. Then, it iteratively
masks out the top (Most Relevant First, or MoRF) or bottom (Least Relevant
First, or LeRF) segments and computes the confidence of the model on the
masked samples. The relevance of a segment is computed as the average
relevance of the features in the segment.
This results in a curve of confidence vs. number of segments masked. The
area under (or equivalently over) this curve is the IROF metric.
`start`, `stop`, and `num_steps` are used to determine the range of segments
to mask. The range is determined by `start` and `stop` as a percentage of
the total number of segments. `num_steps` is the number of steps to take
between `start` and `stop`.
The IROF metric is computed for each masker in `maskers` and for each
activation function in `activation_fns`.
Parameters
----------
model : nn.Module
Model to compute IROF on.
attributions_dataset : AttributionsDataset
Dataset of attributions to compute IROF on.
batch_size : int
Batch size to use when computing model predictions on masked samples.
maskers : Mapping[str, ImageMasker]
Dictionary of maskers to use for masking samples.
activation_fns : Union[List[str], str], optional
List of activation functions to use when computing model predictions on
masked samples. If a single string is passed, it is converted to a
single-element list.
Default: "linear"
mode : str, optional
Mode to use when masking samples. Must be "morf" or "lerf".
Default: "morf"
start : float, optional
Relative start of the range of segments to mask. Must be between 0 and 1.
Default: 0.0
stop : float, optional
Relative stop of the range of segments to mask. Must be between 0 and 1.
Default: 1.0
num_steps : int, optional
Number of steps to take between `start` and `stop`.
Default: 100
"""
if isinstance(activation_fns, str):
activation_fns = [activation_fns]
model.to(device)
model.eval()
dataloader = DataLoader(
attributions_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
)
result = DeletionResult(
attributions_dataset.method_names,
list(maskers.keys()),
activation_fns,
mode,
num_samples=attributions_dataset.num_samples,
num_steps=num_steps,
)
for (
batch_indices,
batch_x,
batch_y,
batch_attr,
method_names,
) in dataloader:
batch_x = batch_x.to(device)
batch_y = batch_y.to(device)
batch_result = _irof_batch(
batch_x,
batch_y,
model,
batch_attr,
maskers,
activation_fns,
mode,
start,
stop,
num_steps,
)
result.add(BatchResult(batch_indices, batch_result, method_names))