attribench.distributed.ComputeAttributions
class attribench.distributed.ComputeAttributions(model_factory, method_factory, dataset, batch_size, address='localhost', port='12355', devices=None)
Bases: DistributedComputation
Compute attributions for a dataset using multiple processes. The attributions are written to an HDF5 file. The number of processes is determined by the number of devices: if no devices are specified, all available devices are used. Samples are distributed evenly across the processes.
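The device and sample distribution described above can be pictured with a minimal, framework-free sketch. It assumes that devices are identified by integer indices, that one process is started per device, and that samples are assigned to processes in a strided fashion. `resolve_devices` and `shard_indices` are hypothetical helpers, not part of attribench, and the library's actual partitioning scheme may differ.

```python
from typing import List, Optional, Tuple


def resolve_devices(devices: Optional[Tuple[int, ...]], available_count: int) -> Tuple[int, ...]:
    # Hypothetical helper: if no devices are given, fall back to every available device.
    if devices is None:
        return tuple(range(available_count))
    return tuple(devices)


def shard_indices(num_samples: int, world_size: int, rank: int) -> List[int]:
    # Strided split: rank r takes samples r, r + world_size, r + 2 * world_size, ...
    # Shard sizes differ by at most one, so the work is spread evenly.
    return list(range(rank, num_samples, world_size))


devices = resolve_devices(None, available_count=3)  # no devices given -> use all 3
world_size = len(devices)                            # one process per device
shards = [shard_indices(10, world_size, rank) for rank in range(world_size)]
print(devices, [len(s) for s in shards])             # (0, 1, 2) [4, 3, 3]
```

Every sample index lands in exactly one shard, which is what allows each process to compute and write its own portion of the attributions independently.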
If you want to compute the attributions and simply return them, rather than storing them in a file, use the compute_attributions() function instead.
Parameters:
- model_factory : ModelFactory
ModelFactory instance or callable that returns a model. Used to create a model for each subprocess.
- method_factory : MethodFactory
MethodFactory instance or callable that returns a dictionary of attribution methods, given a model.
- dataset : Dataset
Torch Dataset to use for computing the attributions.
- batch_size : int
The batch size to use for computing the attributions.
- writer : AttributionsDatasetWriter
AttributionsDatasetWriter to write the attributions to.
- address : str, optional
Address to use for the multiprocessing connection. By default “localhost”.
- port : str, optional
Port to use for the multiprocessing connection. By default “12355”.
- devices : Optional[Tuple], optional
Devices to use. If None, then all available devices are used. By default None.
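To illustrate the two factory parameters, here is a minimal sketch that avoids any dependency on PyTorch or attribench. `ToyLinearModel` is a hypothetical stand-in for a real model, and the `(inputs, targets)` call signature of the returned attribution method is an assumption made for illustration, not attribench's documented method interface. What it shows is the division of labour: `model_factory` takes no arguments and builds a fresh model, while `method_factory` receives that model and returns a dictionary mapping method names to attribution methods.

```python
from typing import Callable, Dict, List

# Hypothetical attribution-method signature: (inputs, targets) -> attributions.
AttributionMethod = Callable[[List[List[float]], List[int]], List[List[float]]]


class ToyLinearModel:
    """Hypothetical stand-in for a real model: a single linear output f(x) = w . x."""

    def __init__(self, weights: List[float]):
        self.weights = weights

    def __call__(self, x: List[float]) -> float:
        return sum(w * xi for w, xi in zip(self.weights, x))


def model_factory() -> ToyLinearModel:
    # Called once per subprocess, so each process builds its own model.
    return ToyLinearModel([0.5, -2.0, 1.0])


def method_factory(model: ToyLinearModel) -> Dict[str, AttributionMethod]:
    # For a linear model the gradient w.r.t. x is the weight vector,
    # so Input x Gradient reduces to the element-wise product w * x.
    def input_x_gradient(inputs: List[List[float]], targets: List[int]) -> List[List[float]]:
        # targets is unused here because the toy model has a single output.
        return [[w * xi for w, xi in zip(model.weights, x)] for x in inputs]

    return {"InputXGradient": input_x_gradient}


model = model_factory()
methods = method_factory(model)
print(methods["InputXGradient"]([[2.0, 1.0, 3.0]], [0]))  # [[1.0, -2.0, 3.0]]
```

Passing factories rather than ready-made objects lets every subprocess construct its own model and attribution methods, instead of sharing a single instance across processes.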
Methods
Run the computation.