attribench.distributed.metrics.MinimalSubset
- class attribench.distributed.metrics.MinimalSubset(model_factory, attributions_dataset, batch_size, maskers, mode='deletion', num_steps=100, address='localhost', port='12355', devices=None)
Bases: Metric
Computes Minimal Subset Deletion or Insertion for a given AttributionsDataset and model using multiple processes.
Minimal Subset Deletion or Insertion is computed by iteratively masking (Deletion) or revealing (Insertion) the top features of the input samples and computing the prediction of the model on the masked samples.
Minimal Subset Deletion is the minimal number of features that must be masked to change the model’s prediction from its original prediction. Minimal Subset Insertion is the minimal number of features that must be revealed to get the model’s original prediction.
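The definition above can be illustrated with a small, single-process sketch in plain Python. It is not the library's implementation: the function name `minimal_subset_deletion`, the `baseline` masking value, the evenly spaced subset sizes, and the choice to return `None` when the prediction never changes are all assumptions made for illustration.

```python
from typing import Callable, List, Optional, Sequence


def minimal_subset_deletion(
    x: Sequence[float],
    attributions: Sequence[float],
    predict: Callable[[List[float]], int],
    num_steps: int = 100,
    baseline: float = 0.0,  # hypothetical masking value
) -> Optional[int]:
    """Smallest tested number of masked features that changes the prediction."""
    n = len(x)
    original = predict(list(x))
    # Rank feature indices from highest to lowest attribution.
    order = sorted(range(n), key=lambda i: attributions[i], reverse=True)
    # Test at most num_steps evenly spaced subset sizes between 1 and n.
    steps = min(num_steps, n)
    sizes = sorted({max(1, round(n * (s + 1) / steps)) for s in range(steps)})
    for k in sizes:
        masked = list(x)
        for i in order[:k]:
            masked[i] = baseline
        if predict(masked) != original:
            return k
    return None  # assumption: no tested subset changed the prediction


# Toy model: class 1 while the feature sum stays above 6.5.
toy_predict = lambda v: int(sum(v) > 6.5)
x = [1.0] * 10
attrs = [float(10 - i) for i in range(10)]

fine = minimal_subset_deletion(x, attrs, toy_predict, num_steps=10)   # 4
coarse = minimal_subset_deletion(x, attrs, toy_predict, num_steps=2)  # 5
```

With ten steps every subset size is tested and the flip is found at 4 features; with two steps only sizes 5 and 10 are tested, so the reported value is 5. Insertion mode would mirror this by starting from a fully masked sample and revealing features until the original prediction is recovered.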
The Minimal Subset metric is computed for each masker in maskers. The number of processes is determined by the number of devices. If devices is None, then all available devices are used. Samples are distributed evenly across the processes.
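One way to picture an even distribution of samples across processes is a round-robin split of sample indices. This is only a sketch: the helper `shard_indices` is hypothetical, and the partitioning scheme the library actually uses is not specified here.

```python
from typing import List


def shard_indices(num_samples: int, num_processes: int) -> List[List[int]]:
    """Assign sample indices to processes round-robin (illustrative only)."""
    return [
        list(range(rank, num_samples, num_processes))
        for rank in range(num_processes)
    ]


shards = shard_indices(10, 3)
# shards == [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

Shard sizes differ by at most one sample, so no process receives noticeably more work than another.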
- Parameters:
- model_factory : ModelFactory
ModelFactory instance or callable that returns a model. Used to create a model for each subprocess.
- attributions_dataset : AttributionsDataset
Dataset containing the samples and attributions to compute the Minimal Subset metric for.
- batch_size : int
Batch size per subprocess to use when computing the metric.
- maskers : Dict[str, Masker]
Dictionary mapping masker names to Masker objects.
- mode : str, optional
"deletion" or "insertion", by default "deletion".
- num_steps : int, optional
Number of steps to use when computing the Minimal Subset metric, by default 100. More steps will result in a more accurate metric, but will take longer to compute.
- address : str, optional
Address to use for multiprocessing, by default "localhost".
- port : str, optional
Port to use for multiprocessing, by default "12355".
- devices : Optional[Tuple], optional
Tuple of devices to use for multiprocessing, by default None. If None, all available devices are used.
- Raises:
- ValueError
If mode is not “deletion” or “insertion”.
Methods
- run: Runs the metric computation and optionally saves the result.
- save_result: Save the result to disk.
Attributes
result
- run(result_path=None, progress=True)
Runs the metric computation and optionally saves the result. If no result path is given, the result will not be saved to disk. It can still be accessed via the result property.
- Parameters:
- result_path : str, optional
Path to save the result to. If None, the result is not saved to disk.
- progress : bool, optional
Whether to show a progress bar. Defaults to True.
- save_result(path, format='hdf5')
Save the result to disk.
- Parameters:
- path : str
Path to save the result to.
- format : str, optional
Format to save the result in. If "hdf5", the result is saved as an HDF5 file. If "csv", the result is saved as a directory structure containing CSV files. Default: "hdf5".
- Raises:
- ValueError
If the result is None.
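A usage sketch that ties the constructor, run, and the result property together is given below. It uses only the signature documented on this page. The wrapper `run_minimal_subset` is a hypothetical helper, `batch_size=32` is an arbitrary choice, and the model factory, attributions dataset, and maskers are assumed to have been created elsewhere.

```python
def run_minimal_subset(model_factory, attributions_dataset, maskers, result_path=None):
    """Hypothetical helper: compute Minimal Subset Deletion and return the result."""
    # Imported lazily so that defining this helper does not require attribench.
    from attribench.distributed.metrics import MinimalSubset

    metric = MinimalSubset(
        model_factory,
        attributions_dataset,
        batch_size=32,  # arbitrary illustrative value
        maskers=maskers,
        mode="deletion",
        num_steps=100,
    )
    # With result_path=None nothing is written to disk; the result stays in memory.
    metric.run(result_path=result_path, progress=True)
    return metric.result
```

If the result should be written later, or in CSV form, save_result(path, format="csv") can be called on the metric object after run has completed; calling it before a result exists raises ValueError, as documented above.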