attribench.distributed.SelectSamples
- class attribench.distributed.SelectSamples(model_factory, dataset, num_samples, batch_size, address='localhost', port='12355', devices=None)
Bases: DistributedComputation
Select correctly classified samples from a dataset and write them to an HDF5 file. This is done in a distributed fashion: each subprocess selects a part of the samples and writes them to the HDF5 file. The number of subprocesses is determined by the number of devices.
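This page does not specify how the work is divided among subprocesses. As a rough illustration only, the sketch below assumes one subprocess per device, an even split of `num_samples` across ranks, and a strided shard of dataset indices per rank so that no two ranks inspect the same sample. `plan_work` is a hypothetical helper, not part of attribench.

```python
from typing import List, Tuple


def plan_work(dataset_size: int, num_samples: int, num_devices: int) -> List[Tuple[int, range]]:
    """Hypothetical helper: build a (quota, dataset indices) plan per subprocess.

    Assumes one subprocess per device. Each rank gets an even share of
    ``num_samples`` (the first ``num_samples % num_devices`` ranks get one
    extra) and a strided shard of the dataset indices, so the shards are
    disjoint and together cover the whole dataset.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    base, extra = divmod(num_samples, num_devices)
    plan = []
    for rank in range(num_devices):
        quota = base + (1 if rank < extra else 0)
        # Indices rank, rank + n, rank + 2n, ... for n = num_devices.
        shard = range(rank, dataset_size, num_devices)
        plan.append((quota, shard))
    return plan


if __name__ == "__main__":
    for rank, (quota, shard) in enumerate(plan_work(dataset_size=10, num_samples=5, num_devices=2)):
        print(rank, quota, list(shard))
```

A strided shard keeps each rank's share of the dataset the same size (within one sample) without any coordination between processes; a contiguous split would work equally well under this assumption.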
If you want to select correctly classified samples and return them, rather than storing them in an HDF5 file, use attribench.functional.select_samples() instead.
Parameters:
- model_factory : ModelFactory
ModelFactory instance or callable that returns a model. Used to instantiate a model for each subprocess.
- dataset : Dataset
PyTorch Dataset containing the samples and labels.
- writer : HDF5DatasetWriter
Writer used to write the selected samples and labels.
- num_samples : int
Number of correctly classified samples to select.
- batch_size : int
Batch size per subprocess for the dataloader.
- address : str, optional
Address to use for the multiprocessing connection, by default “localhost”.
- port : str, optional
Port to use for the multiprocessing connection, by default “12355”.
- devices : Tuple, optional
Devices to use. If None, all available devices are used. By default None.
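To make the roles of `num_samples` and `batch_size` concrete, here is a minimal plain-Python sketch of what one subprocess might do with its data. `select_correct`, the `predict` callable, and the `(features, label)` tuples are hypothetical stand-ins for the model built by `model_factory` and the PyTorch Dataset. The real class writes its results to an HDF5 file instead of returning them, and it may handle a dataset that runs out before `num_samples` is reached differently from this sketch, which simply returns what it found.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-in for a Dataset item: (features, label).
Sample = Tuple[List[float], int]


def select_correct(
    predict: Callable[[List[List[float]]], List[int]],
    dataset: Sequence[Sample],
    num_samples: int,
    batch_size: int,
) -> List[Sample]:
    """Hypothetical sketch: keep samples whose predicted label equals the true label.

    Walks ``dataset`` in batches of ``batch_size`` and stops as soon as
    ``num_samples`` correct samples are found; returns fewer if the
    dataset runs out first.
    """
    if num_samples <= 0:
        return []
    selected: List[Sample] = []
    for start in range(0, len(dataset), batch_size):
        batch = dataset[start:start + batch_size]
        preds = predict([x for x, _ in batch])  # one "forward pass" per batch
        for (x, y), pred in zip(batch, preds):
            if pred == y:
                selected.append((x, y))
                if len(selected) == num_samples:
                    return selected
    return selected


if __name__ == "__main__":
    # Toy "model": predicts class 1 when the single feature exceeds 0.5.
    def threshold_model(xs: List[List[float]]) -> List[int]:
        return [1 if x[0] > 0.5 else 0 for x in xs]

    data = [([0.1], 0), ([0.9], 0), ([0.7], 1), ([0.2], 0), ([0.8], 1)]
    print(select_correct(threshold_model, data, num_samples=3, batch_size=2))
```

Stopping early once the quota is met avoids running the model on the rest of the dataset, which matters when only a small `num_samples` is requested from a large dataset.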
Methods
Run the sample selection.