attribench.functional.select_samples

attribench.functional.select_samples(model, dataset, num_samples, batch_size, writer=None, device=None)

Select correctly classified samples from a dataset and optionally write them to an HDF5 file. If writer is None, the selected samples and labels are returned. Otherwise, they are written to the HDF5 file through writer, and None is returned.

TODO: this function should only return the samples and labels. Writing them to a file should be delegated to the distributed class.
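The selection behavior described above can be illustrated with a minimal pure-Python sketch. It is not attribench's implementation: tensors are replaced by plain lists, the model by a hypothetical `predict` callable that maps a batch of samples to predicted labels, and `ListWriter` with its `write` method is a hypothetical in-memory stand-in, not the `HDF5DatasetWriter` API. The sketch also separates pure selection (`select_correct`) from writing, in the spirit of the TODO. What the real function does when the dataset contains fewer than `num_samples` correct samples is not documented; this sketch simply returns whatever it found.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Sample = List[float]


def select_correct(
    predict: Callable[[List[Sample]], List[int]],
    dataset: Sequence[Tuple[Sample, int]],
    num_samples: int,
    batch_size: int,
) -> Tuple[List[Sample], List[int]]:
    """Collect the first num_samples items that `predict` classifies correctly."""
    samples: List[Sample] = []
    labels: List[int] = []
    # Walk the dataset in batches, as a dataloader would.
    for start in range(0, len(dataset), batch_size):
        batch = dataset[start:start + batch_size]
        xs = [x for x, _ in batch]
        ys = [y for _, y in batch]
        preds = predict(xs)
        # Keep only samples whose prediction matches the true label.
        for x, y, p in zip(xs, ys, preds):
            if p == y:
                samples.append(x)
                labels.append(y)
                if len(samples) == num_samples:
                    return samples, labels
    # Assumption: if the dataset runs out, return the partial selection.
    return samples, labels


class ListWriter:
    """Hypothetical stand-in for HDF5DatasetWriter; stores data in memory."""

    def __init__(self) -> None:
        self.samples: List[Sample] = []
        self.labels: List[int] = []

    def write(self, samples: List[Sample], labels: List[int]) -> None:
        self.samples.extend(samples)
        self.labels.extend(labels)


def select_samples_sketch(
    predict, dataset, num_samples, batch_size, writer: Optional[ListWriter] = None
) -> Optional[Tuple[List[Sample], List[int]]]:
    samples, labels = select_correct(predict, dataset, num_samples, batch_size)
    if writer is None:
        return samples, labels  # no writer: return the selection
    writer.write(samples, labels)  # writer given: write and return None
    return None


# Toy "model": predicts class 1 if the first feature is positive, else 0.
def predict(batch: List[Sample]) -> List[int]:
    return [1 if x[0] > 0 else 0 for x in batch]


data = [([1.0], 1), ([-1.0], 1), ([2.0], 1), ([-3.0], 0), ([4.0], 0)]
print(select_samples_sketch(predict, data, num_samples=3, batch_size=2))
# prints ([[1.0], [2.0], [-3.0]], [1, 1, 0])
```

Passing a writer instead, for example `select_samples_sketch(predict, data, 3, 2, writer=ListWriter())`, returns None and leaves the selection in the writer, mirroring the Optional return type documented below. The sketch writes the whole selection at once; the real function may write batch by batch.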

Parameters:
model : nn.Module

Model to use for classification.

dataset : Dataset

Torch Dataset containing the samples and labels.

num_samples : int

Number of correctly classified samples to select.

batch_size : int

Batch size to use for the dataloader.

writer : Optional[HDF5DatasetWriter], optional

Writer to write the samples and labels to, by default None. If None, the samples and labels are returned instead.

device : Optional[torch.device], optional

Device to use, by default None.

Returns:
Tuple[torch.Tensor, torch.Tensor] | None

If writer is None, a tuple containing the correctly classified samples and their labels. Otherwise, None.

Return type:

Optional[Tuple[Tensor, Tensor]]