Source code for attribench.distributed._compute_attributions

from ._message import PartialResultMessage
from ._distributed_computation import DistributedComputation
from ._distributed_sampler import DistributedSampler
from ._worker import Worker, WorkerConfig
from attribench._model_factory import ModelFactory

from attribench._method_factory import MethodFactory
from attribench.data import AttributionsDatasetWriter, IndexDataset
from torch.utils.data import Dataset, DataLoader
from typing import Tuple, Optional
import torch
from numpy import typing as npt
from tqdm import tqdm


class AttributionResult:
    """Attributions produced by one method for one batch of samples.

    Attributes
    ----------
    indices : npt.NDArray
        Indices of the samples the attributions belong to.
    attributions : npt.NDArray
        The attributions computed for those samples.
    method_name : str
        Name of the attribution method that produced the attributions.
    """

    def __init__(
        self, indices: npt.NDArray, attributions: npt.NDArray, method_name: str
    ):
        self.method_name = method_name
        self.indices, self.attributions = indices, attributions


class AttributionsWorker(Worker):
    """Worker that computes attributions for its share of the dataset.

    Each worker builds its own model and attribution methods, iterates over
    the samples selected for its rank by a ``DistributedSampler``, and sends
    one ``PartialResultMessage`` per (batch, method) pair through its
    ``WorkerConfig``.
    """

    def __init__(
        self,
        worker_config: WorkerConfig,
        model_factory: ModelFactory,
        method_factory: MethodFactory,
        dataset: IndexDataset,
        batch_size: int,
    ):
        """
        Parameters
        ----------
        worker_config : WorkerConfig
            Configuration of this worker. Its ``rank``, ``world_size`` and
            ``send_result`` are used by :meth:`work`.
        model_factory : ModelFactory
            Called without arguments to create the model.
        method_factory : MethodFactory
            Called with the model; must return a dictionary mapping method
            names to attribution methods.
        dataset : IndexDataset
            Dataset yielding ``(index, x, y)`` items.
        batch_size : int
            Batch size used by the data loader.
        """
        super().__init__(worker_config)
        self.batch_size = batch_size
        self.dataset = dataset
        self.method_factory = method_factory
        self.model_factory = model_factory

    def work(self):
        """Compute attributions for this worker's samples and send them.

        For every batch and every attribution method, an
        :class:`AttributionResult` holding the batch indices, the attributions
        (as NumPy arrays) and the method name is sent via
        ``worker_config.send_result``.
        """
        sampler = DistributedSampler(
            self.dataset,
            self.worker_config.world_size,
            self.worker_config.rank,
            shuffle=False,
        )
        dataloader = DataLoader(
            self.dataset,
            sampler=sampler,
            batch_size=self.batch_size,
            num_workers=4,
            pin_memory=True,
        )
        # The worker's rank doubles as the index of the device it runs on.
        device = torch.device(self.worker_config.rank)
        model = self.model_factory()
        model.to(device)
        method_dict = self.method_factory(model)

        for batch_indices, batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            for method_name, method in method_dict.items():
                # NOTE(review): methods that need gradients presumably
                # re-enable autograd internally -- confirm.
                with torch.no_grad():
                    attrs = method(batch_x, batch_y)
                # detach() first: Tensor.numpy() raises a RuntimeError on a
                # tensor that requires grad, which can happen when a method
                # re-enables autograd despite the surrounding no_grad().
                result = AttributionResult(
                    batch_indices.cpu().numpy(),
                    attrs.detach().cpu().numpy(),
                    method_name,
                )
                self.worker_config.send_result(
                    PartialResultMessage(self.worker_config.rank, result)
                )


class ComputeAttributions(DistributedComputation):
    """Compute attributions for a dataset using multiple processes.

    The attributions are written to a HDF5 file. The number of processes is
    determined by the number of devices. If no devices are specified, then
    all available devices are used. Samples are distributed evenly across
    the processes.

    If you want to compute attributions and simply return them, rather than
    storing them in a file, then use the
    :func:`~attribench.functional.compute_attributions` function instead.
    """

    def __init__(
        self,
        model_factory: ModelFactory,
        method_factory: MethodFactory,
        dataset: Dataset,
        batch_size: int,
        address="localhost",
        port="12355",
        devices: Optional[Tuple] = None,
    ):
        """
        Parameters
        ----------
        model_factory : ModelFactory
            ModelFactory instance or callable that returns a model.
            Used to create a model for each subprocess.
        method_factory : MethodFactory
            MethodFactory instance or callable that returns a dictionary of
            attribution methods, given a model.
        dataset : Dataset
            Torch Dataset to use for computing the attributions.
        batch_size : int
            The batch size to use for computing the attributions.
        address : str, optional
            Address to use for the multiprocessing connection.
            By default "localhost".
        port : str, optional
            Port to use for the multiprocessing connection.
            By default "12355".
        devices : Optional[Tuple], optional
            Devices to use. If None, then all available devices are used.
            By default None.

        Notes
        -----
        The ``AttributionsDatasetWriter`` is not a constructor argument; it
        is created by :meth:`run` from the given path.
        """
        super().__init__(address, port, devices)
        self.model_factory = model_factory
        self.method_factory = method_factory
        # Wrapped so that every item also carries its sample index.
        self.dataset = IndexDataset(dataset)
        self.batch_size = batch_size
        # Both are created lazily in run().
        self.prog: tqdm | None = None
        self.writer: AttributionsDatasetWriter | None = None

    def run(self, path: str):
        """Run the computation.

        Parameters
        ----------
        path : str
            Path to the HDF5 file to write the attributions to.
        """
        self.writer = AttributionsDatasetWriter(
            path,
            num_samples=len(self.dataset),
        )
        # A plain callable used as method factory has no __len__; in that
        # case the progress bar runs without a known total instead of
        # failing with a TypeError before any work has started.
        try:
            num_methods = len(self.method_factory)
        except TypeError:
            num_methods = None
        total = None if num_methods is None else len(self.dataset) * num_methods
        self.prog = tqdm(total=total)
        super().run()

    def _cleanup(self):
        """Close the progress bar, if one was created."""
        if self.prog is not None:
            self.prog.close()

    def _create_worker(self, worker_config: WorkerConfig) -> Worker:
        """Create the worker that computes attributions for one process."""
        return AttributionsWorker(
            worker_config,
            self.model_factory,
            self.method_factory,
            self.dataset,
            self.batch_size,
        )

    def _handle_result(
        self, result_message: PartialResultMessage[AttributionResult]
    ):
        """Write one partial result to the file and advance the progress bar.

        Parameters
        ----------
        result_message : PartialResultMessage[AttributionResult]
            Message whose ``data`` holds the indices, attributions and method
            name of one batch.
        """
        # The writer only exists once run() has been called.
        assert self.writer is not None
        result = result_message.data
        indices = result.indices
        self.writer.write(indices, result.attributions, result.method_name)
        if self.prog is not None:
            self.prog.update(len(indices))