attribench.distributed.metrics.ImpactCoverage
- class attribench.distributed.metrics.ImpactCoverage(model_factory, samples_dataset, batch_size, method_factory, patch_folder, address='localhost', port='12355', devices=None)
Bases: Metric
Computes the Impact Coverage metric for a given dataset, model, and set of attribution methods, using multiple processes.
Impact Coverage is computed by applying an adversarial patch to the input, which causes the model to change its prediction. The metric is the intersection over union (IoU) of the patch region with the top n attributed features of the input, where n is the number of features masked by the patch. The rationale is that, because the patch changes the model's prediction, the region it covers should be highly relevant to that prediction, so a good attribution method should rank that region highly.
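The IoU described above can be sketched for a single sample as follows. This is a minimal NumPy illustration, not attribench's implementation: the helper name impact_coverage_iou is hypothetical, and it assumes the attribution map and a boolean patch mask share one shape and that features are ranked by their raw attribution value (whether the library takes absolute values or aggregates over channels is not stated here).

```python
import numpy as np


def impact_coverage_iou(attributions: np.ndarray, patch_mask: np.ndarray) -> float:
    """IoU between the patch region and the n highest-attributed features.

    attributions: attribution map for one sample.
    patch_mask:   boolean array of the same shape, True where the patch was applied.
    n is the number of features covered by the patch.
    """
    if attributions.shape != patch_mask.shape:
        raise ValueError("attributions and patch_mask must have the same shape")
    flat_attr = attributions.ravel()
    flat_patch = patch_mask.ravel().astype(bool)
    n = int(flat_patch.sum())

    # Mark the n features with the largest attribution values (ties broken by sort order).
    top_n = np.zeros_like(flat_patch)
    top_n[np.argsort(flat_attr)[::-1][:n]] = True

    intersection = np.logical_and(top_n, flat_patch).sum()
    union = np.logical_or(top_n, flat_patch).sum()
    return float(intersection / union) if union > 0 else 0.0
```

For a 4×4 map given by np.arange(16) and a patch over the bottom-right 2×2 block, the top 4 features are indices 12 to 15 while the patch covers 10, 11, 14 and 15, so the IoU is 2/6 = 1/3.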
Impact Coverage requires a folder containing adversarial patches. The patches should be named as follows: patch_<target>.pt, where <target> is the target class of the patch. The target class is the class that the model will predict when the patch is applied to the input.
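A folder following this naming convention could be indexed with a small helper such as the hypothetical find_patches below, which is not part of attribench. It assumes that <target> is an integer class index and skips any file that does not match patch_<target>.pt; each returned path could then be loaded with torch.load.

```python
import os
import re
from typing import Dict

# Matches the documented naming convention patch_<target>.pt (integer target assumed).
PATCH_NAME = re.compile(r"^patch_(\d+)\.pt$")


def find_patches(patch_folder: str) -> Dict[int, str]:
    """Map each target class to the path of its patch file, ignoring other files."""
    patches = {}
    for name in sorted(os.listdir(patch_folder)):
        match = PATCH_NAME.match(name)
        if match:
            patches[int(match.group(1))] = os.path.join(patch_folder, name)
    return patches
```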
The number of processes is determined by the number of devices. If devices is None, then all available devices are used. Samples are distributed evenly across the processes.
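One way to distribute samples evenly is to give each process a contiguous block of indices whose size differs from the other blocks by at most one. The sketch below assumes that scheme purely for illustration; the documentation does not say whether attribench uses contiguous blocks or another strategy (such as interleaved indices), and split_evenly is a hypothetical name.

```python
from typing import List


def split_evenly(num_samples: int, num_processes: int) -> List[range]:
    """Split indices 0..num_samples-1 into num_processes contiguous chunks
    whose sizes differ by at most one."""
    if num_processes < 1:
        raise ValueError("num_processes must be at least 1")
    base, extra = divmod(num_samples, num_processes)
    chunks = []
    start = 0
    for rank in range(num_processes):
        # The first `extra` processes each take one additional sample.
        size = base + (1 if rank < extra else 0)
        chunks.append(range(start, start + size))
        start += size
    return chunks
```

With 10 samples on 3 devices, the chunks are range(0, 4), range(4, 7) and range(7, 10).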
To generate adversarial patches, the train_adversarial_patches() function or the TrainAdversarialPatches class can be used.
- Parameters:
- model_factory : ModelFactory
ModelFactory instance or callable that returns a model. Used to create a model for each subprocess.
- samples_dataset : Dataset
Dataset to compute Impact Coverage for.
- batch_size : int
Batch size to use when computing Impact Coverage.
- method_factory : MethodFactory
MethodFactory instance or callable that returns a dictionary mapping method names to attribution methods, given a model.
- patch_folder : str
Path to folder containing adversarial patches.
- address : str, optional
Address to use for the multiprocessing connection. By default “localhost”.
- port : str, optional
Port to use for the multiprocessing connection. By default “12355”.
- devices : Optional[Tuple], optional
Devices to use. If None, then all available devices are used. By default None.
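Both factory parameters are plain callables. The sketch below shows the shapes described above under stated assumptions: ToyModel is a hypothetical stand-in for a real (presumably PyTorch) model, and the (batch, targets) call signature of the attribution methods is illustrative, not attribench's actual interface. Factories are passed instead of instances so that each subprocess can build its own model; defining them at module level keeps them picklable, which spawn-based multiprocessing generally requires.

```python
from typing import Callable, Dict, List

Batch = List[List[float]]


class ToyModel:
    """Hypothetical stand-in for a real model: two logits per sample."""

    def __call__(self, batch: Batch) -> Batch:
        return [[sum(x), -sum(x)] for x in batch]


def model_factory() -> ToyModel:
    # Called in each subprocess, so every process builds its own model instance.
    return ToyModel()


def method_factory(model: ToyModel) -> Dict[str, Callable[[Batch, List[int]], Batch]]:
    # Returns a dict mapping method names to attribution methods for this model.
    # The (batch, targets) call signature is an assumption for illustration.

    def occlusion(batch: Batch, targets: List[int]) -> Batch:
        # Attribution of feature i = drop in the target logit when feature i is zeroed.
        attributions = []
        for x, t in zip(batch, targets):
            base = model([x])[0][t]
            row = []
            for i in range(len(x)):
                occluded = list(x)
                occluded[i] = 0.0
                row.append(base - model([occluded])[0][t])
            attributions.append(row)
        return attributions

    def uniform(batch: Batch, targets: List[int]) -> Batch:
        # Baseline that assigns the same attribution to every feature.
        return [[1.0] * len(x) for x in batch]

    return {"Occlusion": occlusion, "Uniform": uniform}
```

Calling method_factory(model_factory())["Occlusion"]([[1.0, 2.0, 3.0]], [0]) returns [[1.0, 2.0, 3.0]], since zeroing each feature lowers the first logit by exactly that feature's value.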
Methods
run([result_path, progress])
Runs the metric computation and optionally saves the result.
save_result(path[, format])
Save the result to disk.
Attributes
result
- run(result_path=None, progress=True)
Runs the metric computation and optionally saves the result. If no result path is given, the result will not be saved to disk. It can still be accessed via the result property.
- Parameters:
- result_path : str, optional
Path to save the result to. If None, the result is not saved to disk.
- progress : bool, optional
Whether to show a progress bar. Defaults to True.
- save_result(path, format='hdf5')
Save the result to disk.
- Parameters:
- path : str
Path to save the result to.
- format : str, optional
Format to save the result in. If "hdf5", the result is saved as an HDF5 file. If "csv", the result is saved as a directory structure containing CSV files. Default: "hdf5".
- Raises:
- ValueError
If the result is None.