attribench.distributed.TrainAdversarialPatches
- class attribench.distributed.TrainAdversarialPatches(model_factory, dataset, num_patches, batch_size, path, labels=None, address='localhost', port='12355', devices=None)
Bases: DistributedComputation
Train adversarial patches for a given model and dataset and save them to disk. The patches are trained in parallel on multiple processes, and each process trains a subset of the patches.
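Each process trains only a subset of the patches. The exact split used by TrainAdversarialPatches is not documented here, so the following is a minimal sketch that assumes a simple round-robin assignment of patch indices to processes. split_patches is a hypothetical helper and is not part of attribench.

```python
from typing import List


def split_patches(num_patches: int, world_size: int, rank: int) -> List[int]:
    """Return the patch indices that process `rank` would train.

    Hypothetical round-robin split (an assumption, not attribench's code):
    patch i goes to process i % world_size, so every patch is assigned
    to exactly one process.
    """
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    # Start at this process's rank and step by the number of processes.
    return list(range(rank, num_patches, world_size))


if __name__ == "__main__":
    # Spread 10 patches over 3 processes and show each process's share.
    for rank in range(3):
        print(rank, split_patches(10, 3, rank))
```

With 10 patches and 3 processes, rank 0 trains patches 0, 3, 6 and 9, rank 1 trains 1, 4 and 7, and rank 2 trains 2, 5 and 8. Round-robin keeps the per-process workload within one patch of each other, which matters when every patch takes roughly the same time to train.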
- Parameters:
- model_factory : ModelFactory
ModelFactory instance or callable that returns a model. Used to create a model for each subprocess.
- dataset : Dataset
Torch Dataset to use for training the patches.
- num_patches : int
Number of patches to train.
- batch_size : int
Batch size per subprocess to use for training.
- path : str
Path to which the patches should be saved.
- labels : Optional[Tuple[int]], optional
Tuple of labels to use for the patches. If None, the labels are assumed to be range(num_patches). By default None.
- address : str, optional
Address for communication between subprocesses. By default “localhost”.
- port : str, optional
Port for communication between subprocesses. By default “12355”.
- devices : Optional[Tuple[int]], optional
Devices to use. If None, all available devices are used. By default None.
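The documented defaults for labels and devices can be illustrated with a short sketch. resolve_labels and resolve_devices are hypothetical helpers, not part of attribench. The device count is passed in explicitly so the sketch runs without a GPU; in a PyTorch setup it would typically come from torch.cuda.device_count(). The one-label-per-patch check is an assumption.

```python
from typing import Optional, Sequence, Tuple


def resolve_labels(labels: Optional[Sequence[int]], num_patches: int) -> Tuple[int, ...]:
    """Mirror the documented default: None means range(num_patches)."""
    if labels is None:
        return tuple(range(num_patches))
    labels = tuple(labels)
    # Assumption (not stated in the docs): exactly one label per patch.
    if len(labels) != num_patches:
        raise ValueError("expected one label per patch")
    return labels


def resolve_devices(devices: Optional[Sequence[int]], device_count: int) -> Tuple[int, ...]:
    """Mirror the documented default: None means all available devices."""
    if devices is None:
        # Hypothetical: device indices 0 .. device_count - 1.
        return tuple(range(device_count))
    return tuple(devices)


if __name__ == "__main__":
    print(resolve_labels(None, 4))
    print(resolve_devices(None, 2))
```

For example, resolve_labels(None, 4) yields (0, 1, 2, 3), so patch i targets class i unless explicit labels are given.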
Methods
run