# Source code for attribench.data.attributions_dataset._attributions_dataset

import torch
from torch.utils.data import TensorDataset
from attribench.data._index_dataset import IndexDataset
from .._typing import _check_is_dataset
from torch.utils.data import Dataset
import numpy as np
import h5py
from typing import List, Dict, Tuple


def _max_abs(arr: torch.Tensor, dim: int) -> torch.Tensor:
    return torch.max(torch.abs(arr), dim=dim, keepdim=True)


def _mean(arr: torch.Tensor, dim: int) -> torch.Tensor:
    return torch.mean(arr, dim=dim, keepdim=True)


def _check_is_dataset_or_tensor(obj) -> h5py.Dataset | torch.Tensor:
    """
    Return ``obj`` unchanged if it is an HDF5 dataset or a Tensor.

    Raises
    ------
    ValueError
        If ``obj`` is neither an ``h5py.Dataset`` nor a ``torch.Tensor``.
    """
    if not isinstance(obj, (h5py.Dataset, torch.Tensor)):
        raise ValueError(
            f"Expected obj to be a Dataset or Tensor, but got {type(obj)}"
        )
    return obj


def _get_attributions_shape(
    attributions: Dict[str, torch.Tensor] | h5py.File, method_names: List[str]
) -> Tuple[int, ...]:
    """
    Return the shape shared by the attributions of every listed method.

    Parameters
    ----------
    attributions: Dict[str, torch.Tensor] | h5py.File
        Either a mapping from method name to Tensor, or an open HDF5 file
        whose entries are validated with ``_check_is_dataset``.
    method_names: List[str]
        Methods whose attribution shapes are compared.

    Returns
    -------
    Tuple[int, ...]
        The common full shape, including the leading sample dimension.

    Raises
    ------
    ValueError
        If the methods' shapes differ, or if no shape was found (e.g.
        ``method_names`` is empty).
    """
    from_file = isinstance(attributions, h5py.File)
    reference = None
    for name in method_names:
        if from_file:
            current = _check_is_dataset(attributions[name]).shape
        else:
            current = attributions[name].shape

        if reference is None:
            # The first shape found becomes the reference for the rest.
            reference = current
        elif current != reference:
            raise ValueError(
                "Attributions must have the same shape for each method"
            )

    if reference is None:
        raise ValueError("Attributions must not be empty")
    return reference


def _parse_attributions_dict(
    attributions: Dict[str, torch.Tensor],
    methods: List[str] | None,
) -> Tuple[List[str], int, Tuple[int, ...]]:
    """
    Extract method names, sample count and per-sample shape from a dict.

    Parameters
    ----------
    attributions: Dict[str, torch.Tensor]
        Mapping from method name to attributions for all samples.
    methods: List[str] | None
        Methods to use. If None, every key of ``attributions`` is used.

    Returns
    -------
    Tuple[List[str], int, Tuple[int, ...]]
        The selected method names, the size of the first dimension of the
        attributions, and the remaining dimensions.

    Raises
    ------
    ValueError
        If a requested method is not a key of ``attributions``, or if the
        shapes of the selected methods disagree.
    """
    if methods is not None and any(
        name not in attributions.keys() for name in methods
    ):
        raise ValueError(f"Invalid methods: {methods}")
    selected = list(attributions.keys()) if methods is None else methods

    full_shape = _get_attributions_shape(attributions, selected)
    return selected, full_shape[0], full_shape[1:]


def _parse_attributions_file(
    path: str, methods: List[str] | None
) -> Tuple[List[str], int, Tuple[int, ...]]:
    """
    Read method names, sample count and per-sample shape from an HDF5 file.

    The file is opened read-only and closed again before returning.

    Parameters
    ----------
    path: str
        Path to the HDF5 attributions file.
    methods: List[str] | None
        Methods to use. If None, every top-level key of the file is used.

    Returns
    -------
    Tuple[List[str], int, Tuple[int, ...]]
        The selected method names, the ``num_samples`` file attribute as a
        Python int, and each method's dataset shape without its first
        dimension.

    Raises
    ------
    ValueError
        If a requested method is missing from the file, if ``num_samples``
        is not a NumPy integer scalar, or if the selected datasets differ in
        shape. A missing ``num_samples`` attribute surfaces as whatever
        h5py raises for an absent key (presumably KeyError - verify).
    """
    with h5py.File(path, "r") as h5_file:
        if methods is not None and any(
            name not in h5_file for name in methods
        ):
            raise ValueError(f"Invalid methods: {methods}")
        selected = list(h5_file.keys()) if methods is None else methods

        # Only NumPy integer scalars are accepted here; a plain Python int
        # would not pass this isinstance check.
        raw_count = h5_file.attrs["num_samples"]
        if not isinstance(raw_count, np.integer):
            raise ValueError(
                f"Expected num_samples to be an integer,"
                f" but got {type(raw_count)}"
            )
        num_samples = int(raw_count)

        # Also validates that every selected method has the same shape.
        attributions_shape = _get_attributions_shape(h5_file, selected)[1:]
    return selected, num_samples, attributions_shape


class AttributionsDataset(IndexDataset):
    """
    Represents a dataset containing attributions for a set of samples and
    attribution methods.

    The samples and labels can be given in two ways. Either a PyTorch
    ``Dataset`` is passed to the ``samples`` argument containing both the
    samples and the labels, or a Tensor is passed to the ``samples`` argument
    and a Tensor is passed to the ``labels`` argument.

    An AttributionsDataset can be constructed from a dictionary of
    attributions or from an HDF5 file containing the attributions. If
    attributions are given using a dictionary, the dictionary must map method
    names to Tensors containing the attributions for each sample. The
    attributions must have the same shape for each method. The shape of the
    attributions must be ``[num_samples, *sample_shape]``.

    If attributions are given using an HDF5 file, the file must contain a
    dataset for each attribution method. The dataset must have the same shape
    for each method. The shape of the dataset must be
    ``[num_samples, *sample_shape]``. The file must also contain an attribute
    ``num_samples`` specifying the number of samples in the dataset. The
    constructor only reads metadata from the file; the file is opened for
    reading lazily on the first item access.

    A list of method names can be given using the ``methods`` argument. If
    ``methods`` is None, all methods in the attributions dictionary or file
    are used. Otherwise, only the methods in the ``methods`` list are used.

    Attributions can be aggregated over some dimension by specifying the
    ``aggregate_dim`` and ``aggregate_method`` arguments. The
    ``aggregate_dim`` argument specifies the dimension over which to
    aggregate. The ``aggregate_method`` argument specifies the method to use
    for aggregation and must be one of ``"mean"`` or ``"max_abs"``. Note that
    ``aggregate_dim`` is specified in terms of the shape of the attributions
    of a single sample, i.e. excluding the ``num_samples`` dimension;
    negative values count from the end. For example, if the attributions
    have shape ``[num_samples, 3, 32, 32]``, then the attributions can be
    aggregated over the channel dimension by setting ``aggregate_dim=0``.
    The ``attributions_shape`` attribute is then ``(32, 32)``.

    Item ``i`` is the tuple ``(sample_idx, sample, label, attrs,
    method_name)`` for method ``method_names[i // num_samples]`` and sample
    ``i % num_samples``.
    """

    def __init__(
        self,
        samples: Dataset | torch.Tensor,
        labels: torch.Tensor | None = None,
        path: str | None = None,
        attributions: Dict[str, torch.Tensor] | None = None,
        methods: List[str] | None = None,
        aggregate_dim: int = 0,
        aggregate_method: str | None = None,
    ):
        """
        Parameters
        ----------
        samples: Dataset | torch.Tensor
            A Dataset containing samples and labels, or a Tensor containing
            the samples for which attributions are given.
        labels: torch.Tensor | None
            A Tensor containing the labels for the samples. Only used if
            samples is a Tensor.
        path: str | None
            Path to an HDF5 file containing the attributions. Only used if
            ``attributions`` is None.
        attributions: Dict[str, torch.Tensor] | None
            A dictionary mapping attribution method names to Tensors
            containing the attributions for each sample. If None, a path to
            an HDF5 file must be given.
        methods: List[str] | None
            A list of method names to use. If None, all methods in the
            attributions dictionary or file are used.
        aggregate_dim: int
            Dimension of a single sample's attributions (i.e. excluding the
            ``num_samples`` dimension) to aggregate over. Negative values
            count from the end. Only used if ``aggregate_method`` is not
            None.
        aggregate_method: str | None
            If not None, aggregate the attributions using the given method.
            Must be one of "mean" or "max_abs" or None.

        Raises
        ------
        ValueError
            If attributions is None and path is None, if labels is None and
            samples is a Tensor, if aggregate_method is not a supported
            method, or if aggregate_dim is out of range for the attributions.
        """
        self.path = path

        # If samples and labels are given as Tensors, wrap them in a Dataset
        self.samples_dataset: Dataset
        if isinstance(samples, torch.Tensor):
            if labels is None:
                raise ValueError(
                    "Labels must be given if samples are given as a Tensor"
                )
            self.samples_dataset = TensorDataset(samples, labels)
        else:
            self.samples_dataset = samples

        # Handle attributions dict or file. For the file case the handle is
        # left as None here and opened lazily in __getitem__.
        self.attributions: Dict[str, torch.Tensor] | h5py.File | None = None
        self.method_names: List[str]
        orig_attributions_shape: Tuple[int, ...]
        if attributions is not None:
            # Attributions given as a dict: parse the dict for metadata.
            (
                self.method_names,
                self.num_samples,
                orig_attributions_shape,
            ) = _parse_attributions_dict(attributions, methods)
            self.attributions = attributions
        else:
            # Otherwise, a path must be given. Load metadata from HDF5 file.
            if path is None:
                raise ValueError("Either attributions or path must be given")
            (
                self.method_names,
                self.num_samples,
                orig_attributions_shape,
            ) = _parse_attributions_file(path, methods)

        # Handle aggregation if necessary
        self.aggregate_fn = None
        self.aggregate_dim = aggregate_dim
        self.attributions_shape: Tuple[int, ...]
        if aggregate_method is not None:
            agg_fns = {"mean": _mean, "max_abs": _max_abs}
            if aggregate_method not in agg_fns:
                raise ValueError(
                    f"Invalid aggregate_method: {aggregate_method}."
                    f" Must be one of {list(agg_fns)} or None."
                )
            self.aggregate_fn = agg_fns[aggregate_method]

            ndim = len(orig_attributions_shape)
            if not -ndim <= aggregate_dim < ndim:
                raise ValueError(
                    f"aggregate_dim {aggregate_dim} is out of range for"
                    f" attributions of shape {tuple(orig_attributions_shape)}"
                )
            # Normalize negative dims so the slicing below drops the right
            # axis (for dim=-1, slicing with dim + 1 == 0 would keep every
            # axis instead of dropping the last one).
            dim = aggregate_dim % ndim
            # NOTE(review): _mean and _max_abs reduce with keepdim=True, so
            # tensors returned by __getitem__ keep a size-1 axis at
            # aggregate_dim while attributions_shape drops it. Confirm which
            # form downstream consumers expect.
            self.attributions_shape = (
                orig_attributions_shape[:dim]
                + orig_attributions_shape[dim + 1 :]
            )
        else:
            self.attributions_shape = orig_attributions_shape

    def _open_attributions_file(self):
        """
        Open the HDF5 file at ``self.path`` read-only and store the handle in
        ``self.attributions``. Nothing in this class closes the handle.
        """
        self.attributions = h5py.File(self.path, "r")

    def __getitem__(
        self, index: int
    ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor, str]:
        """
        Return the attributions of one method for one sample.

        Parameters
        ----------
        index: int
            Flat index. The method is ``method_names[index // num_samples]``
            and the sample is ``index % num_samples``.

        Returns
        -------
        Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor, str]
            ``(sample_idx, sample, label, attrs, method_name)``, where
            ``attrs`` is aggregated if an aggregation method was configured.
        """
        if self.attributions is None:
            self._open_attributions_file()
        assert self.attributions is not None
        method_idx = index // self.num_samples
        method_name = self.method_names[method_idx]
        sample_idx = index % self.num_samples
        sample, label = self.samples_dataset[sample_idx]
        dataset = _check_is_dataset_or_tensor(self.attributions[method_name])
        attrs = dataset[sample_idx]
        # HDF5 datasets yield non-Tensor values; convert them to Tensors.
        if not isinstance(attrs, torch.Tensor):
            attrs = torch.tensor(attrs)
        if self.aggregate_fn is not None:
            attrs = self.aggregate_fn(attrs, dim=self.aggregate_dim)
        return sample_idx, sample, label, attrs, method_name

    def __len__(self):
        """Return ``num_samples * len(method_names)``."""
        return self.num_samples * len(self.method_names)
class GroupedAttributionsDataset(IndexDataset):
    """
    View of an ``AttributionsDataset`` indexed by sample rather than by
    (method, sample) pair.

    Item ``i`` is ``(i, sample, label, attrs)``, where ``attrs`` maps every
    method name of the wrapped dataset to the attributions of sample ``i``
    for that method (aggregated if the wrapped dataset aggregates).
    """

    def __init__(self, dataset: AttributionsDataset):
        super().__init__(dataset)
        self.dataset: AttributionsDataset = dataset
        # Mirror the wrapped dataset's metadata for convenient access.
        self.method_names = dataset.method_names
        self.num_samples = dataset.num_samples
        self.attributions_shape = dataset.attributions_shape

    def __getitem__(self, index):
        base = self.dataset
        # The wrapped dataset opens its HDF5 file lazily; do the same here.
        if base.attributions is None:
            base._open_attributions_file()
        assert base.attributions is not None

        sample, label = base.samples_dataset[index]

        # First load every method's attributions as Tensors...
        attrs: Dict[str, torch.Tensor] = {}
        for name in base.method_names:
            source = _check_is_dataset_or_tensor(base.attributions[name])
            value = source[index]
            attrs[name] = (
                value if isinstance(value, torch.Tensor) else torch.tensor(value)
            )

        # ...then aggregate them, if the wrapped dataset is configured to.
        if base.aggregate_fn is not None:
            attrs = {
                name: base.aggregate_fn(attrs[name], dim=base.aggregate_dim)
                for name in base.method_names
            }
        return index, sample, label, attrs

    def __len__(self):
        return self.dataset.num_samples