attribench.distributed.metrics.Infidelity
- class attribench.distributed.metrics.Infidelity(model_factory, attributions_dataset, batch_size, activation_fns, perturbation_generators, num_perturbations, address='localhost', port='12355', devices=None)
Bases: Metric
Computes the Infidelity metric for a given AttributionsDataset and model using multiple processes.
Infidelity is computed by generating perturbations for each sample in the dataset and computing the difference between the model's output on the original sample and its output on the perturbed sample. For each attribution method, this difference is then compared to the dot product of the perturbation vector and the attribution map. The Infidelity metric is the mean squared error between these two values.
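The computation described above can be sketched for a single sample and a single attribution map in NumPy. This is a minimal illustration, not attribench's implementation: it assumes a scalar model output and perturbations that are subtracted from the input (the perturbed sample is `x - I`). The function name `infidelity`, the toy `tanh` model and all values are hypothetical.

```python
import numpy as np


def infidelity(model, x, attribution, perturbations):
    """Hypothetical sketch of Infidelity for one sample and one attribution map.

    model:         callable mapping a (batch, d) array to a (batch,) output
    x:             input sample, shape (d,)
    attribution:   attribution map for x, shape (d,)
    perturbations: shape (num_perturbations, d)
    """
    # Dot product of each perturbation vector with the attribution map
    dots = perturbations @ attribution
    # Difference between the output on the original and on the perturbed sample
    deltas = model(x[None, :]) - model(x[None, :] - perturbations)
    # Mean squared error over all perturbations
    return float(np.mean((dots - deltas) ** 2))


rng = np.random.default_rng(0)
w = np.array([1.0, -2.0, 0.5])
model = lambda batch: np.tanh(batch @ w)       # toy non-linear model (hypothetical)
x = np.array([0.3, 0.1, -0.4])
attribution = w * (1 - np.tanh(x @ w) ** 2)    # gradient of the toy model at x
perturbations = rng.normal(scale=0.1, size=(32, 3))  # num_perturbations = 32

score = infidelity(model, x, attribution, perturbations)
print(f"Infidelity: {score:.6f}")
```

Using the gradient as the attribution map gives a much lower score here than an all-zero map would, because for small perturbations the gradient predicts the output change to first order.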
The idea is that if the dot product is large, the perturbation vector is aligned with the attribution map, so the model's output should change significantly when the perturbation is applied. If the dot product is small, the perturbation vector is not aligned with the attribution map, and the model's output should not change significantly. A lower Infidelity therefore means that the attribution map better predicts how the model's output responds to perturbations.
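The alignment argument can be checked on a linear toy model, where the output change caused by a perturbation `I` is exactly `w · I`. This is an illustrative sketch with hypothetical names and values: an attribution map equal to the weights predicts every output change exactly, while the same values assigned to the wrong features do not.

```python
import numpy as np

rng = np.random.default_rng(1)
w = np.array([2.0, -1.0, 0.0, 0.5])
model = lambda batch: batch @ w          # linear toy model (hypothetical)
x = rng.normal(size=4)
perturbations = rng.normal(size=(64, 4))


def infidelity(attribution):
    # MSE between <I, attribution> and f(x) - f(x - I) over all perturbations
    dots = perturbations @ attribution
    deltas = model(x[None, :]) - model(x[None, :] - perturbations)
    return float(np.mean((dots - deltas) ** 2))


aligned = w                    # attribution equals the model's true sensitivities
misaligned = w[::-1].copy()    # same values, assigned to the wrong features

print(infidelity(aligned))     # close to 0.0, up to floating-point error
print(infidelity(misaligned))  # clearly positive
```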
For each sample, the mean squared error is computed over num_perturbations perturbations. The perturbation_generators argument is a dictionary mapping perturbation generator names to PerturbationGenerator objects; different generators implement different versions of Infidelity.
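The role of the perturbation_generators dictionary can be illustrated with plain Python callables. This is a hypothetical stand-in: the real PerturbationGenerator interface in attribench may differ. The call signature `(sample, num_perturbations, rng)`, the keys `"noisy_baseline"` and `"square_removal"`, and both generator functions are assumptions made for illustration. Swapping the generator changes which kind of perturbation the score measures.

```python
from typing import Callable, Dict

import numpy as np

# Hypothetical generator signature: (sample, num_perturbations, rng) -> perturbations
GeneratorFn = Callable[[np.ndarray, int, np.random.Generator], np.ndarray]


def gaussian_noise(sample, num_perturbations, rng, std=0.1):
    # Small random noise on every feature
    return rng.normal(scale=std, size=(num_perturbations, *sample.shape))


def square_removal(sample, num_perturbations, rng, size=2):
    # Perturbation equals the sample inside a random square and 0 elsewhere,
    # so that the perturbed input x - I has that square set to zero
    h, w = sample.shape
    perts = np.zeros((num_perturbations, h, w))
    for k in range(num_perturbations):
        i = rng.integers(0, h - size + 1)
        j = rng.integers(0, w - size + 1)
        perts[k, i:i + size, j:j + size] = sample[i:i + size, j:j + size]
    return perts


# Hypothetical names mapped to generators, mirroring the dictionary argument
perturbation_generators: Dict[str, GeneratorFn] = {
    "noisy_baseline": gaussian_noise,
    "square_removal": square_removal,
}

rng = np.random.default_rng(0)
image = rng.random((8, 8))
for name, gen in perturbation_generators.items():
    perts = gen(image, 4, rng)
    print(name, perts.shape)   # (4, 8, 8) for both generators
```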
The Infidelity metric is computed for each perturbation generator in perturbation_generators and each activation function in activation_fns. The number of processes is determined by the number of devices. If devices is None, then all available devices are used. Samples are distributed evenly across the processes.
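One way to distribute samples evenly across processes, and to enumerate the (perturbation generator, activation function) combinations the metric is computed for, is sketched below. The contiguous split shown here is an assumption for illustration; attribench may partition samples differently (for example, strided rather than contiguous). The generator and activation-function names are hypothetical.

```python
from itertools import product
from typing import List


def split_evenly(num_samples: int, num_processes: int) -> List[range]:
    """Assign contiguous index ranges whose sizes differ by at most one."""
    base, extra = divmod(num_samples, num_processes)
    ranges, start = [], 0
    for rank in range(num_processes):
        size = base + (1 if rank < extra else 0)
        ranges.append(range(start, start + size))
        start += size
    return ranges


# e.g. 10 samples on 4 devices, assuming one process per device
shards = split_evenly(10, 4)
print([len(r) for r in shards])   # [3, 3, 2, 2]

# One Infidelity result per (perturbation generator, activation function) pair
generators = ("noisy_baseline", "square_removal")   # hypothetical names
activation_fns = ("linear", "softmax")              # hypothetical names
pairs = list(product(generators, activation_fns))
for gen_name, act in pairs:
    print(gen_name, act)
```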
- Parameters:
- model_factory : ModelFactory
ModelFactory instance or callable that returns a model. Used to create a model for each subprocess.
- attributions_dataset : AttributionsDataset
Dataset containing the samples and attributions to compute Infidelity on.
- batch_size : int
Batch size to use when computing Infidelity.
- activation_fns : Tuple[str]
Tuple of activation functions to use when computing Infidelity.
- perturbation_generators : Dict[str, PerturbationGenerator]
Dictionary of perturbation generators to use for generating perturbations.
- num_perturbations : int
Number of perturbations to generate for each sample.
- address : str, optional
Address to use for the multiprocessing connection. By default “localhost”.
- port : str, optional
Port to use for the multiprocessing connection. By default “12355”.
- devices : Optional[Tuple], optional
Devices to use. If None, all available devices are used. By default None.
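The activation_fns parameter determines which model output enters the output difference. The sketch below assumes, for illustration only, that an activation function is applied to the logits before the output for the target class is read; the names `"linear"` and `"softmax"` and the toy classifier are hypothetical. The attribution used here is exact for the raw logit, so the score is near zero with `"linear"` but not after softmax, which shows that the same attribution map can score differently depending on the activation function.

```python
import numpy as np


def softmax(logits):
    # Numerically stable softmax over the last axis
    z = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)


# Hypothetical mapping from activation-function names to callables
ACTIVATIONS = {"linear": lambda logits: logits, "softmax": softmax}

rng = np.random.default_rng(0)
W = rng.normal(size=(5, 3))            # toy classifier: 5 features -> 3 classes
x = rng.normal(size=5)
target = 0                             # class whose output is explained
attribution = W[:, target]             # exact attribution for the target logit
perturbations = rng.normal(scale=0.5, size=(32, 5))

scores = {}
for name, act in ACTIVATIONS.items():
    out_orig = act(x @ W)[target]
    out_pert = act((x - perturbations) @ W)[:, target]
    dots = perturbations @ attribution
    scores[name] = float(np.mean((dots - (out_orig - out_pert)) ** 2))
    print(name, round(scores[name], 6))
```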
Methods
- run: Runs the metric computation and optionally saves the result.
- save_result: Save the result to disk.
Attributes
- result
- run(result_path=None, progress=True)
Runs the metric computation and optionally saves the result. If no result path is given, the result is not saved to disk, but it can still be accessed via the result property.
- Parameters:
- result_path : str, optional
Path to save the result to. If None, the result is not saved to disk.
- progress : bool, optional
Whether to show a progress bar. Defaults to True.
- save_result(path, format='hdf5')
Save the result to disk.
- Parameters:
- path : str
Path to save the result to.
- format : str, optional
Format to save the result in. If "hdf5", the result is saved as an HDF5 file. If "csv", the result is saved as a directory structure containing CSV files. Default: "hdf5".
- Raises:
- ValueError
If the result is None.