Source code for attribench.functional._compute_attributions

from tqdm import tqdm
from typing import Dict, Optional, List
from torch import nn
from torch.utils.data import Dataset, DataLoader
from attribench.data import AttributionsDatasetWriter, IndexDataset
from attribench import AttributionMethod
import torch


def compute_attributions(
    model: nn.Module,
    method_dict: Dict[str, AttributionMethod],
    dataset: Dataset,
    batch_size: int,
    writer: Optional[AttributionsDatasetWriter] = None,
    device: Optional[torch.device] = None,
    num_workers: int = 4,
) -> Optional[Dict[str, torch.Tensor]]:
    """Compute attributions for a model and dataset using several methods.

    Every attribution method in ``method_dict`` is applied to every sample
    of ``dataset``. If ``writer`` is ``None``, the attributions are kept in
    memory and returned as a dictionary. Otherwise they are written batch by
    batch through ``writer`` (e.g. to an HDF5 file) and ``None`` is returned.

    Parameters
    ----------
    model : nn.Module
        The model for which the attributions should be computed. It is moved
        to ``device`` and switched to evaluation mode.
    method_dict : Dict[str, AttributionMethod]
        Dictionary mapping method names to attribution methods. Each method
        is called as ``method(batch_x, batch_y)`` inside ``torch.no_grad()``.
    dataset : Dataset
        Torch Dataset to use for computing the attributions. It is wrapped in
        an ``IndexDataset``, whose batches are unpacked as
        ``(indices, x, y)``. Presumably each dataset item is therefore an
        ``(x, y)`` pair (TODO confirm against ``IndexDataset``).
    batch_size : int
        The batch size to use for computing the attributions.
    writer : Optional[AttributionsDatasetWriter], optional
        Writer that receives ``(indices, attributions, method_name)`` for
        every batch, by default ``None``. If ``None``, the attributions are
        returned in a dictionary instead.
    device : Optional[torch.device], optional
        Device to use, by default ``None``. If ``None``, the CPU is used.
    num_workers : int, optional
        Number of DataLoader worker processes, by default 4.

    Returns
    -------
    Optional[Dict[str, torch.Tensor]]
        If ``writer`` is ``None``, a dictionary mapping each method name to a
        tensor whose first dimension indexes the samples of ``dataset`` in
        dataset order. If the dataset is empty, each value is an empty
        tensor. If ``writer`` is given, ``None``.
    """
    # TODO: don't write to file, just return the dict
    # Normalise so ``device.type`` is available even if a string was passed.
    device = torch.device("cpu") if device is None else torch.device(device)
    model.to(device)
    model.eval()

    index_dataset = IndexDataset(dataset)
    num_samples = len(index_dataset)
    dataloader = DataLoader(
        index_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only speeds up host-to-GPU transfers; skip it
        # otherwise to avoid the extra allocation (and warnings without CUDA).
        pin_memory=device.type == "cuda",
    )

    # One slot per dataset sample and method, filled by dataset index so the
    # output order matches the dataset order regardless of batching.
    result_dict: Dict[str, List[Optional[torch.Tensor]]] = {
        method_name: [None] * num_samples for method_name in method_dict
    }

    for batch_indices, batch_x, batch_y in tqdm(dataloader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        for method_name, method in method_dict.items():
            with torch.no_grad():
                attrs = method(batch_x, batch_y)
            # Detach so ``.numpy()`` works even if the method returned a
            # tensor that still tracks gradients; move the batch once.
            attrs = attrs.detach().cpu()
            if writer is None:
                # ``batch_indices`` holds dataset indices, while ``attrs`` is
                # indexed by position within the batch: pair them explicitly.
                for pos, idx in enumerate(batch_indices.tolist()):
                    result_dict[method_name][idx] = attrs[pos]
            else:
                writer.write(
                    batch_indices.cpu().numpy(),
                    attrs.numpy(),
                    method_name,
                )

    if writer is not None:
        return None
    if num_samples == 0:
        # torch.stack rejects an empty list; there is no sample shape to use.
        return {method_name: torch.empty(0) for method_name in method_dict}
    # ``attrs[pos]`` drops the batch dimension; stacking restores it as the
    # leading sample dimension.
    return {
        method_name: torch.stack(attrs_list)
        for method_name, attrs_list in result_dict.items()
    }