attribench.distributed.SelectSamples

class attribench.distributed.SelectSamples(model_factory, dataset, num_samples, batch_size, address='localhost', port='12355', devices=None)

Bases: DistributedComputation

Select correctly classified samples from a dataset and write them to an HDF5 file. The selection is distributed: each subprocess selects a part of the samples and writes them to the HDF5 file. One subprocess is started per device, so the number of processes equals the number of devices.

If you want to select correctly classified samples and return them, rather than writing them to an HDF5 file, use attribench.functional.select_samples() instead.
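As a rough illustration of what "selecting correctly classified samples" means, the sketch below walks over batches of (input, label) pairs and keeps those for which a classifier's prediction matches the label, stopping once num_samples have been collected. It is plain Python and is not attribench's implementation: select_correct, predict, and the Sample type are hypothetical names introduced only for this example.

```python
from typing import Callable, Iterable, List, Sequence, Tuple

# Hypothetical sample type for this sketch: (input value, integer label).
Sample = Tuple[float, int]


def select_correct(
    batches: Iterable[Sequence[Sample]],
    predict: Callable[[float], int],
    num_samples: int,
) -> List[Sample]:
    """Collect up to num_samples samples whose prediction equals the label."""
    selected: List[Sample] = []
    if num_samples <= 0:
        return selected
    for batch in batches:
        for x, label in batch:
            # Keep the sample only if the classifier gets it right.
            if predict(x) == label:
                selected.append((x, label))
                if len(selected) >= num_samples:
                    # Enough samples collected; stop early.
                    return selected
    # The data ran out before num_samples correct samples were found.
    return selected
```

With a toy classifier such as `lambda x: int(x > 0)`, misclassified pairs are skipped and iteration stops as soon as the requested number of samples is reached.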

Parameters:
model_factory : ModelFactory

ModelFactory instance or callable that returns a model. Used to instantiate a model for each subprocess.

dataset : Dataset

Torch Dataset containing the samples and labels.

writer : HDF5DatasetWriter

Writer to write the samples and labels to.

num_samples : int

Number of correctly classified samples to select.

batch_size : int

Batch size per subprocess to use for the dataloader.

address : str, optional

Address to use for the multiprocessing connection, by default “localhost”.

port : str, optional

Port to use for the multiprocessing connection, by default “12355”.

devices : Tuple, optional

Devices to use. If None, then all available devices are used. By default None.
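The relationship between devices and subprocesses can be pictured with the following sketch. It assumes one subprocess (rank) per device and a simple strided split of dataset indices across ranks. Both helpers, resolve_devices and shard_indices, are hypothetical and written only for illustration; attribench may partition the data differently.

```python
from typing import List, Optional, Sequence


def resolve_devices(devices: Optional[Sequence[int]], available: int) -> List[int]:
    """Return the device ids to use; None means every available device."""
    if devices is None:
        return list(range(available))
    return list(devices)


def shard_indices(num_items: int, rank: int, world_size: int) -> List[int]:
    """Indices handled by one subprocess under a strided split (an assumption)."""
    return list(range(rank, num_items, world_size))


# One subprocess per device: with 3 devices, 7 items are split into 3 shards.
device_ids = resolve_devices(None, available=3)
shards = [shard_indices(7, rank, len(device_ids)) for rank in range(len(device_ids))]
```

Under this scheme every index is assigned to exactly one subprocess, so the shards together cover the dataset without overlap.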

Methods

run

Run the sample selection.

run(path)

Run the sample selection.

Parameters:
path : str

Path to the HDF5 file to write the samples to.
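A minimal usage sketch follows, assuming the constructor signature and the run(path) method documented above. The wrapper write_correct_samples and its default values are hypothetical; model_factory and dataset stand for objects you provide (a ModelFactory or model-returning callable, and a Torch Dataset).

```python
def write_correct_samples(model_factory, dataset, path,
                          num_samples=1000, batch_size=64, devices=None):
    """Select correctly classified samples and write them to an HDF5 file.

    Hypothetical convenience wrapper around SelectSamples; the defaults
    for num_samples and batch_size are arbitrary illustration values.
    """
    # Imported lazily so this module can be imported without attribench.
    from attribench.distributed import SelectSamples

    selector = SelectSamples(
        model_factory,
        dataset,
        num_samples=num_samples,
        batch_size=batch_size,
        devices=devices,  # None: use all available devices
    )
    # Starts the distributed selection and writes the result to `path`.
    selector.run(path)
```

Because the computation starts subprocesses, it is generally safest to call such a function from under an `if __name__ == "__main__":` guard in the launching script.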