# Source code for attribench.functional._select_samples

from torch import nn
from torch.utils.data import Dataset, DataLoader
from attribench.data import HDF5DatasetWriter
import torch
from typing import Optional, Tuple


def _select_samples_batch(
    batch_x: torch.Tensor,
    batch_y: torch.Tensor,
    model: nn.Module,
    device: torch.device,
):
    """Filter one batch down to the samples the model classifies correctly.

    Parameters
    ----------
    batch_x : torch.Tensor
        Input samples, indexed along the first dimension together with
        ``batch_y``.
    batch_y : torch.Tensor
        Ground-truth labels, compared element-wise with the argmax of the
        model output over dimension 1.
    model : nn.Module
        Classifier applied to ``batch_x``; evaluated without gradient
        tracking.
    device : torch.device
        Device both tensors are moved to before the forward pass.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The correctly classified samples and their labels, both on
        ``device``.
    """
    inputs = batch_x.to(device)
    targets = batch_y.to(device)
    with torch.no_grad():
        predictions = model(inputs).argmax(dim=1)
    # Boolean mask of the rows whose predicted class matches the label.
    is_correct = predictions == targets
    return inputs[is_correct, ...], targets[is_correct]


def select_samples(
    model: nn.Module,
    dataset: Dataset,
    num_samples: int,
    batch_size: int,
    writer: Optional[HDF5DatasetWriter] = None,
    device: Optional[torch.device] = None,
    num_workers: int = 4,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Select correctly classified samples from a dataset and optionally
    write them to a HDF5 file.

    If `writer` is `None`, the samples and labels are returned. Otherwise,
    they are written to the HDF5 file and `None` is returned.

    At most `num_samples` samples are selected: the final batch is truncated
    so the total never exceeds `num_samples`. If the dataset does not contain
    enough correctly classified samples, fewer are selected.

    Note that `model` is moved to `device` and switched to evaluation mode
    in place.

    TODO this function should just return the samples and labels. Use the
    distributed class to write the samples and labels to a file.

    Parameters
    ----------
    model : nn.Module
        Model to use for classification.
    dataset : Dataset
        Torch Dataset containing the samples and labels.
    num_samples : int
        Maximum number of correctly classified samples to select.
    batch_size : int
        Batch size to use for the dataloader.
    writer : Optional[HDF5DatasetWriter], optional
        Writer to write the samples and labels to, by default None.
    device : Optional[torch.device], optional
        Device to use, by default None (CPU).
    num_workers : int, optional
        Number of dataloader worker processes, by default 4.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor] | None
        If `writer` is `None`, a tuple containing the correctly classified
        samples and their labels (possibly empty). Otherwise, `None`.

    Raises
    ------
    ValueError
        If `writer` is `None` and the dataset yields no batches at all.
    """
    if device is None:
        device = torch.device("cpu")
    model.to(device)
    model.eval()
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-accelerator copies.
        pin_memory=device.type != "cpu",
    )

    samples_count = 0
    all_correct_samples, all_correct_labels = [], []
    for batch_x, batch_y in dataloader:
        correct_samples, correct_labels = _select_samples_batch(
            batch_x, batch_y, model, device
        )
        # Drop the surplus of the final batch so that no more than
        # num_samples samples are selected in total.
        remaining = max(num_samples - samples_count, 0)
        correct_samples = correct_samples[:remaining]
        correct_labels = correct_labels[:remaining]

        if writer is None:
            # Empty slices are kept on purpose: they carry the sample shape
            # and dtype, so torch.cat still works if nothing was selected.
            all_correct_samples.append(correct_samples)
            all_correct_labels.append(correct_labels)
        elif len(correct_samples) > 0:
            writer.write(
                correct_samples.cpu().numpy(), correct_labels.cpu().numpy()
            )

        samples_count += len(correct_samples)
        if samples_count >= num_samples:
            break

    if writer is None:
        if not all_correct_samples:
            raise ValueError("Cannot select samples from an empty dataset.")
        return torch.cat(all_correct_samples), torch.cat(all_correct_labels)
    return None