# Source code for attribench.distributed.metrics.minimal_subset._minimal_subset

from typing import Tuple, Dict, Optional
from ..._worker import WorkerConfig
from attribench.masking import Masker
from ._minimal_subset_worker import MinimalSubsetWorker
from attribench.result import MinimalSubsetResult
from .._metric import Metric
from attribench.data import AttributionsDataset
from attribench._model_factory import ModelFactory


class MinimalSubset(Metric):
    """Compute Minimal Subset Deletion or Insertion using multiple processes.

    The metric is computed for a given
    :class:`~attribench.data.AttributionsDataset` and model.

    Minimal Subset Deletion or Insertion is computed by iteratively masking
    (Deletion) or revealing (Insertion) the top features of the input samples
    and computing the prediction of the model on the masked samples.

    Minimal Subset Deletion is the minimal number of features that must be
    masked to change the model's prediction from its original prediction.
    Minimal Subset Insertion is the minimal number of features that must be
    revealed to get the model's original prediction.

    The Minimal Subset metric is computed for each masker in `maskers`.

    The number of processes is determined by the number of devices. If
    `devices` is None, then all available devices are used. Samples are
    distributed evenly across the processes.
    """

    def __init__(
        self,
        model_factory: ModelFactory,
        attributions_dataset: AttributionsDataset,
        batch_size: int,
        maskers: Dict[str, Masker],
        mode: str = "deletion",
        num_steps: int = 100,
        address: str = "localhost",
        port: str = "12355",
        devices: Optional[Tuple] = None,
    ) -> None:
        """
        Parameters
        ----------
        model_factory : ModelFactory
            ModelFactory instance or callable that returns a model.
            Used to create a model for each subprocess.
        attributions_dataset : AttributionsDataset
            Dataset containing the samples and attributions to compute
            the Minimal Subset metric for.
        batch_size : int
            Batch size per subprocess to use when computing the metric.
        maskers : Dict[str, Masker]
            Dictionary mapping masker names to `Masker` objects.
        mode : str, optional
            "deletion" or "insertion", by default "deletion"
        num_steps : int, optional
            Number of steps to use when computing the Minimal Subset metric,
            by default 100. More steps will result in a more accurate metric,
            but will take longer to compute.
        address : str, optional
            Address to use for multiprocessing, by default "localhost"
        port : str, optional
            Port to use for multiprocessing, by default "12355"
        devices : Optional[Tuple], optional
            Tuple of devices to use for multiprocessing, by default None.
            If None, all available devices are used.

        Raises
        ------
        ValueError
            If `mode` is not "deletion" or "insertion".
        """
        # Validate first so an invalid mode fails before any base-class
        # setup is performed. The message is a single formatted string:
        # passing `mode` as a second positional argument to ValueError would
        # render the message as a tuple repr.
        if mode not in ("deletion", "insertion"):
            raise ValueError(
                f"Mode must be deletion or insertion. Got: {mode}"
            )

        super().__init__(
            model_factory,
            attributions_dataset,
            batch_size,
            address,
            port,
            devices,
        )
        self.dataset = attributions_dataset
        self.num_steps = num_steps
        self.mode = mode
        self.maskers = maskers
        # Result container, keyed by the dataset's attribution method names
        # and the masker names, sized for every sample in the dataset.
        self._result = MinimalSubsetResult(
            attributions_dataset.method_names,
            list(maskers.keys()),
            mode,
            num_samples=attributions_dataset.num_samples,
        )

    def _create_worker(
        self, worker_config: WorkerConfig
    ) -> MinimalSubsetWorker:
        """Create the `MinimalSubsetWorker` for a given worker configuration.

        Parameters
        ----------
        worker_config : WorkerConfig
            Configuration passed through unchanged to the worker.

        Returns
        -------
        MinimalSubsetWorker
            Worker configured with this metric's model factory, dataset,
            batch size, maskers, mode and number of steps.
        """
        # NOTE(review): `self.model_factory` and `self.batch_size` are not
        # assigned in this class; presumably set by `Metric.__init__` --
        # confirm against the base class.
        return MinimalSubsetWorker(
            worker_config,
            self.model_factory,
            self.dataset,
            self.batch_size,
            self.maskers,
            self.mode,
            self.num_steps,
        )