attribench.functional.select_samples
- attribench.functional.select_samples(model, dataset, num_samples, batch_size, writer=None, device=None)
Select correctly classified samples from a dataset and optionally write them to an HDF5 file. If writer is None, the samples and labels are returned. Otherwise, they are written to the HDF5 file through the writer, and None is returned.
TODO: This function should only return the samples and labels. Writing them to a file should be handled by the distributed class.
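The selection step can be sketched in plain PyTorch as below. This is a hypothetical re-implementation, not attribench's code: the helper name `select_correct_samples` is made up, it omits the writer branch entirely, it assumes the model returns logits of shape `(batch, num_classes)` and that the dataset yields `(sample, label)` pairs with integer labels, and raising `ValueError` when too few samples are classified correctly is an assumed edge-case behavior.

```python
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset


def select_correct_samples(
    model: nn.Module,
    dataset: Dataset,
    num_samples: int,
    batch_size: int,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: return the first `num_samples` samples (in dataset
    order) that `model` classifies correctly, together with their labels."""
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")
    device = device if device is not None else torch.device("cpu")
    model = model.to(device).eval()  # note: switches the caller's model to eval mode
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    kept_samples: List[torch.Tensor] = []
    kept_labels: List[torch.Tensor] = []
    count = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            # Assumption: logits of shape (batch, num_classes).
            correct = model(x).argmax(dim=1) == y
            remaining = num_samples - count
            # Keep at most `remaining` correctly classified samples from this batch.
            kept_samples.append(x[correct][:remaining].cpu())
            kept_labels.append(y[correct][:remaining].cpu())
            count += kept_labels[-1].shape[0]
            if count >= num_samples:
                break

    # Assumption: fail loudly if the dataset has too few correct samples.
    if count < num_samples:
        raise ValueError(
            f"only {count} correctly classified samples found, {num_samples} requested"
        )
    return torch.cat(kept_samples), torch.cat(kept_labels)


# Usage: a linear model with identity weights predicts argmax(x).
model = nn.Linear(4, 4, bias=False)
with torch.no_grad():
    model.weight.copy_(torch.eye(4))
x = torch.eye(4).repeat(2, 1)               # predictions: 0, 1, 2, 3, 0, 1, 2, 3
y = torch.tensor([0, 0, 2, 0, 0, 1, 1, 3])  # correct at indices 0, 2, 4, 5, 7
dataset = TensorDataset(x, y)
samples, labels = select_correct_samples(model, dataset, num_samples=3, batch_size=3)
print(labels.tolist())  # [0, 2, 0]
```

Stopping at the first `num_samples` hits keeps the selection deterministic for a fixed dataset order, which is why the sketch uses `shuffle=False`.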
- Parameters:
- model : nn.Module
Model to use for classification.
- dataset : Dataset
Torch Dataset containing the samples and labels.
- num_samples : int
Number of correctly classified samples to select.
- batch_size : int
Batch size to use for the DataLoader.
- writer : Optional[HDF5DatasetWriter], optional
Writer to which the samples and labels are written, by default None.
- device : Optional[torch.device], optional
Device to use, by default None.
- Returns:
- Tuple[torch.Tensor, torch.Tensor] | None
If writer is None, a tuple containing the correctly classified samples and their labels. Otherwise, None.
- Return type:
Optional[Tuple[Tensor, Tensor]]