attribench.distributed.TrainAdversarialPatches

class attribench.distributed.TrainAdversarialPatches(model_factory, dataset, num_patches, batch_size, path, labels=None, address='localhost', port='12355', devices=None)

Bases: DistributedComputation

Train adversarial patches for a given model and dataset, and save them to disk. Training runs in parallel across multiple processes, each of which trains a subset of the patches.
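This page does not say how the patches are divided among processes. As an illustration only, a simple round-robin split of patch indices over `world_size` processes could look like the sketch below; `assign_patches` and `world_size` are hypothetical names, not part of attribench.

```python
from typing import List


def assign_patches(num_patches: int, world_size: int) -> List[List[int]]:
    """Hypothetical round-robin split of patch indices over processes.

    Process ``rank`` receives indices rank, rank + world_size, rank + 2 * world_size, ...
    """
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    return [list(range(rank, num_patches, world_size)) for rank in range(world_size)]


# Example: 5 patches over 2 processes.
print(assign_patches(5, 2))  # [[0, 2, 4], [1, 3]]
```

With a round-robin split, the number of patches per process differs by at most one, which balances the load when every patch takes roughly the same time to train.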

Parameters:
model_factory : ModelFactory

ModelFactory instance or callable that returns a model. Used to create a model for each subprocess.

dataset : Dataset

Torch Dataset to use for training the patches.

num_patches : int

Number of patches to train.

batch_size : int

Batch size per subprocess to use for training.

path : str

Path to which the patches should be saved.

labels : Optional[Tuple[int]], optional

Tuple of labels to use for the patches. If None, the labels are assumed to be range(num_patches). By default None.

address : str, optional

Address for communication between subprocesses. By default “localhost”.

port : str, optional

Port for communication between subprocesses. By default “12355”.

devices : Optional[Tuple[int]], optional

Devices to use. If None, then all available devices are used. By default None.
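`model_factory` is used to create a separate model in each subprocess, so any callable that returns a model can be passed. The sketch below is hypothetical: `TinyModel` stands in for a real `torch.nn.Module`, `MyModelFactory` is not an attribench class, and calling the factory with no arguments is an assumption. If the subprocesses are started with the `spawn` method (the method used by `torch.multiprocessing.spawn`), the factory must also be picklable; an instance of a class defined at module level is, a lambda is not.

```python
import pickle


class TinyModel:
    """Stand-in for a torch.nn.Module (hypothetical, for illustration only)."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes


class MyModelFactory:
    """Hypothetical factory: each call builds a fresh, independent model."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes

    def __call__(self) -> TinyModel:
        return TinyModel(self.num_classes)


factory = MyModelFactory(num_classes=10)
# A module-level class pickles cleanly, so the factory can be shipped to spawned processes.
restored = pickle.loads(pickle.dumps(factory))
model_a, model_b = factory(), restored()
print(model_a is model_b, model_b.num_classes)  # False 10
```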
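The default for `labels` can be expressed as a small helper. This is a minimal sketch of the documented behaviour (None becomes `range(num_patches)`); the name `resolve_labels` is hypothetical, and the length check is an added assumption, since this page does not say whether a `labels` tuple of the wrong length is rejected.

```python
from typing import Optional, Tuple


def resolve_labels(num_patches: int, labels: Optional[Tuple[int, ...]] = None) -> Tuple[int, ...]:
    """Hypothetical helper mirroring the documented default for ``labels``."""
    if labels is None:
        # Documented default: labels 0, 1, ..., num_patches - 1.
        return tuple(range(num_patches))
    # Assumption: exactly one label per patch.
    if len(labels) != num_patches:
        raise ValueError("expected one label per patch")
    return tuple(labels)


print(resolve_labels(3))  # (0, 1, 2)
print(resolve_labels(2, (7, 9)))  # (7, 9)
```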
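`address` and `port` act as the rendezvous point through which the subprocesses find each other. With PyTorch's default `env://` initialization of `torch.distributed.init_process_group`, that role is filled by the `MASTER_ADDR` and `MASTER_PORT` environment variables. The sketch below assumes, without confirmation from this page, that the class forwards its arguments this way; `set_rendezvous` is a hypothetical name. Note that `port` is a string, matching the type of an environment variable.

```python
import os


def set_rendezvous(address: str = "localhost", port: str = "12355") -> None:
    """Hypothetical helper: expose address/port to torch.distributed's env:// init."""
    # init_process_group(init_method="env://") reads these two variables.
    os.environ["MASTER_ADDR"] = address
    os.environ["MASTER_PORT"] = port


set_rendezvous()
print(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])  # localhost 12355
```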
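This page does not show how `devices=None` is expanded. A plausible sketch, assuming the entries of `devices` are CUDA device indices and that "all available devices" means indices `0` to `torch.cuda.device_count() - 1`; `resolve_devices` is hypothetical, and the device count is passed in as a plain integer so the snippet runs without PyTorch.

```python
from typing import Optional, Tuple


def resolve_devices(devices: Optional[Tuple[int, ...]], device_count: int) -> Tuple[int, ...]:
    """Hypothetical expansion of the ``devices`` argument.

    ``device_count`` stands in for ``torch.cuda.device_count()``.
    """
    if devices is None:
        # Default: every visible device.
        return tuple(range(device_count))
    return tuple(devices)


print(resolve_devices(None, 4))  # (0, 1, 2, 3)
print(resolve_devices((1, 3), 4))  # (1, 3)
```

Under this reading, one subprocess would typically be launched per listed device, so `devices` also determines how many processes share the patches.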

Methods

run