attribench.distributed.ComputeAttributions

class attribench.distributed.ComputeAttributions(model_factory, method_factory, dataset, batch_size, address='localhost', port='12355', devices=None)

Bases: DistributedComputation

Compute attributions for a dataset using multiple processes. The attributions are written to an HDF5 file. The number of processes is determined by the number of devices. If no devices are specified, all available devices are used. Samples are distributed evenly across the processes.
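To illustrate what an even distribution of samples can look like, here is a minimal sketch that splits sample indices into contiguous chunks whose sizes differ by at most one. The helper name `split_evenly` is hypothetical and not part of attribench; the library may partition samples differently (for example by interleaving indices).

```python
def split_evenly(num_samples: int, num_workers: int) -> list[range]:
    """Split sample indices into contiguous chunks, one per worker.

    Chunk sizes differ by at most one; the first chunks get the extra samples.
    """
    if num_workers <= 0:
        raise ValueError("num_workers must be positive")
    base, extra = divmod(num_samples, num_workers)
    chunks = []
    start = 0
    for rank in range(num_workers):
        # The first `extra` workers each take one additional sample.
        size = base + (1 if rank < extra else 0)
        chunks.append(range(start, start + size))
        start += size
    return chunks


# Example: 10 samples over 4 workers.
chunks = split_evenly(10, 4)
print([len(c) for c in chunks])  # [3, 3, 2, 2]
```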

To compute attributions and return them directly instead of storing them in a file, use the compute_attributions() function.

Parameters:
model_factory : ModelFactory

ModelFactory instance or callable that returns a model. Used to create a model for each subprocess.

method_factory : MethodFactory

MethodFactory instance or callable that returns a dictionary of attribution methods, given a model.

dataset : Dataset

Torch Dataset to use for computing the attributions.

batch_size : int

The batch size to use for computing the attributions.

writer : AttributionsDatasetWriter

AttributionsDatasetWriter to write the attributions to.

address : str, optional

Address to use for the multiprocessing connection. By default “localhost”.

port : str, optional

Port to use for the multiprocessing connection. By default “12355”.

devices : Optional[Tuple], optional

Devices to use. If None, then all available devices are used. By default None.
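Because the port is passed as a string and must be available on the host, it can help to check it before starting a run. The following is a minimal sketch using only the standard library; the helper `port_is_free` is hypothetical and not part of attribench. The check is advisory only, since another process could claim the port between the check and the actual run.

```python
import socket


def port_is_free(address: str, port: str) -> bool:
    """Return True if a TCP socket can currently bind to (address, port)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            # The port is given as a string (as in the constructor), so convert it.
            sock.bind((address, int(port)))
        except OSError:
            return False
    return True
```

If the check fails for the default port, a different value can be passed through the port argument.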

Methods

run

Run the computation.

run(path)

Run the computation.

Parameters:
path : str

Path to the HDF5 file to write the attributions to.
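Putting the constructor and run() together, a call might look like the sketch below. It follows the constructor signature shown at the top of this page. The wrapper name `run_attributions`, the batch size of 32, and the output file name are illustrative assumptions, and the factories and dataset must be supplied by the caller.

```python
def run_attributions(dataset, model_factory, method_factory, path="attributions.h5"):
    """Compute attributions for `dataset` and write them to the HDF5 file at `path`.

    `model_factory` is a callable returning a model; `method_factory` is a
    callable that takes a model and returns a dictionary of attribution methods.
    """
    # Imported lazily so this sketch can be loaded without attribench installed.
    from attribench.distributed import ComputeAttributions

    computation = ComputeAttributions(
        model_factory=model_factory,
        method_factory=method_factory,
        dataset=dataset,
        batch_size=32,        # illustrative value
        address="localhost",  # default from the signature
        port="12355",         # default from the signature; note that it is a string
        devices=None,         # None: use all available devices
    )
    computation.run(path)
```

Since the computation starts multiple processes, it is usually safest to call a function like this from under an `if __name__ == "__main__":` guard in a script.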