Source code for attribench.distributed._select_samples

from ._distributed_sampler import DistributedSampler
from ._distributed_computation import DistributedComputation
from ._message import PartialResultMessage
from ._worker import Worker, WorkerConfig
from attribench import ModelFactory
from tqdm import tqdm
from attribench.data import HDF5DatasetWriter
from attribench.functional._select_samples import _select_samples_batch
from typing import Callable, Tuple, Optional
from torch.utils.data import Dataset, DataLoader
import torch
from numpy import typing as npt
from torch import nn
from multiprocessing.synchronize import Event


class SamplesResult:
    def __init__(self, samples: npt.NDArray, labels: npt.NDArray):
        self.samples = samples
        self.labels = labels


class SampleSelectionWorker(Worker):
    def __init__(
        self,
        worker_config: WorkerConfig,
        sufficient_samples: Event,
        batch_size: int,
        dataset: Dataset,
        model_factory: Callable[[], nn.Module],
    ):
        super().__init__(worker_config)
        self.sufficient_samples = sufficient_samples
        self.model_factory = model_factory
        self.dataset = dataset
        self.batch_size = batch_size

    def work(self):
        sampler = DistributedSampler(
            self.dataset,
            self.worker_config.world_size,
            self.worker_config.rank,
        )
        dataloader = DataLoader(
            self.dataset,
            sampler=sampler,
            batch_size=self.batch_size,
            num_workers=4,
            pin_memory=True,
        )
        device = torch.device(self.worker_config.rank)
        model = self.model_factory()
        model.to(device)
        it = iter(dataloader)

        for batch_x, batch_y in it:
            correct_samples, correct_labels = _select_samples_batch(
                batch_x, batch_y, model, device
            )
            result = SamplesResult(
                correct_samples.cpu().numpy(), correct_labels.cpu().numpy()
            )
            self.worker_config.send_result(
                PartialResultMessage(self.worker_config.rank, result)
            )
            if self.sufficient_samples.is_set():
                break


[docs]class SelectSamples(DistributedComputation): """Select correctly classified samples from a dataset and write them to a HDF5 file. This is done in a distributed fashion, i.e. each subprocess selects a part of the samples and writes them to the HDF5 file. The number of processes is determined by the number of devices. If you want to select correctly classified samples and return them, rather than storing them to a HDF5 file, use :func:`attribench.functional.select_samples` instead. """ def __init__( self, model_factory: ModelFactory, dataset: Dataset, num_samples: int, batch_size: int, address: str = "localhost", port: str = "12355", devices: Optional[Tuple] = None, ): """ Parameters ---------- model_factory : ModelFactory ModelFactory instance or callable that returns a model. Used to instantiate a model for each subprocess. dataset : Dataset Torch Dataset containing the samples and labels. writer : HDF5DatasetWriter Writer to write the samples and labels to. num_samples : int Number of correctly classified samples to select. batch_size : int Batch size per subprocess to use for the dataloader. address : str, optional Address to use for the multiprocessing connection, by default "localhost" port : str, optional Port to use for the multiprocessing connection, by default "12355" devices : Tuple, optional Devices to use. If None, then all available devices are used. By default None. """ super().__init__(address, port, devices) self.model_factory = model_factory self.dataset = dataset self.num_samples = num_samples self.batch_size = batch_size self.sufficient_samples = self.ctx.Event() self.count = 0 self.prog: tqdm | None = None self.writer: HDF5DatasetWriter | None = None def _create_worker( self, worker_config: WorkerConfig ): return SampleSelectionWorker( worker_config, self.sufficient_samples, self.batch_size, self.dataset, self.model_factory, )
[docs] def run(self, path: str): """Run the sample selection. Parameters ---------- path : str Path to the HDF5 file to write the samples to. """ self.writer = HDF5DatasetWriter(path, self.num_samples) self.prog = tqdm(total=self.num_samples) super().run()
def _handle_result(self, result: PartialResultMessage[SamplesResult]): assert self.writer is not None # If too many samples, truncate samples = result.data.samples labels = result.data.labels if self.count + samples.shape[0] > self.num_samples: samples = samples[: self.num_samples - self.count] labels = labels[: self.num_samples - self.count] self.sufficient_samples.set() # Write to disk self.writer.write(samples, labels) # Update progress bar self.count += samples.shape[0] if self.prog is not None: self.prog.update(samples.shape[0])