Source code for attribench.distributed.metrics.max_sensitivity._max_sensitivity

from attribench.data.attributions_dataset._attributions_dataset import (
    AttributionsDataset,
    GroupedAttributionsDataset
)
from .._metric_worker import MetricWorker
from ..._worker import WorkerConfig
from .._metric import Metric
from typing import Optional, Tuple
from torch.utils.data import Dataset
from attribench import MethodFactory
from attribench.result._max_sensitivity_result import MaxSensitivityResult
from ._max_sensitivity_worker import MaxSensitivityWorker
from attribench._model_factory import ModelFactory


[docs]class MaxSensitivity(Metric): """Compute the Max-Sensitivity metric for a given `Dataset` and attribution methods using multiple processes. Max-Sensitivity is computed by adding a small amount of uniform noise to the input samples and computing the norm of the difference in attributions between the original samples and the noisy samples. The maximum norm of difference is then taken as the Max-Sensitivity. The idea is that a small amount of noise should not change the attributions significantly, so the norm of the difference should be small. If the norm is large, then the attributions are not robust to small perturbations in the input. The number of processes is determined by the number of devices. If `devices` is None, then all available devices are used. Samples are distributed evenly across the processes. Each subprocess computes the Max-Sensitivity for a all attribution methods on a subset of the samples. """ def __init__( self, model_factory: ModelFactory, attributions_dataset: AttributionsDataset, batch_size: int, method_factory: MethodFactory, num_perturbations: int, radius: float, address="localhost", port="12355", devices: Optional[Tuple] = None, ): """ Parameters ---------- model_factory : ModelFactory ModelFactory instance or callable that returns a model. Used to create a model for each subprocess. attributions_dataset : Dataset Torch Dataset to use for computing the Max-Sensitivity. batch_size : int The batch size per subprocess to use for computing the Max-Sensitivity. method_factory : MethodFactory MethodFactory instance or callable that returns a dictionary of attribution methods, given a model. num_perturbations : int The number of perturbations to use for computing the Max-Sensitivity. radius : float The radius of the uniform noise to add to the input samples. address : str, optional Address to use for the multiprocessing connection. Default: "localhost" port : str, optional Port to use for the multiprocessing connection. Default: "12355" devices : Optional[Tuple], optional Tuple of devices to use for multiprocessing. If `None`, all available devices are used. """ super().__init__( model_factory, attributions_dataset, batch_size, address, port, devices, ) self.method_factory = method_factory self.num_perturbations = num_perturbations self.radius = radius self.dataset = GroupedAttributionsDataset(attributions_dataset) self._result = MaxSensitivityResult( method_factory.get_method_names(), num_samples=attributions_dataset.num_samples ) def _create_worker(self, worker_config: WorkerConfig) -> MetricWorker: return MaxSensitivityWorker( worker_config, self.model_factory, self.dataset, self.batch_size, self.method_factory, self.num_perturbations, self.radius, )