# Source code for attribench.data._index_dataset

from collections.abc import Sized
from torch.utils.data import Dataset


class IndexDataset(Dataset):
    """Wraps a dataset to return the index of the sample as well.

    Used internally to keep track of the index of a sample in a dataset.
    Each item of the wrapped dataset is expected to unpack into exactly two
    values ``(data, target)``; the wrapper yields ``(index, data, target)``.
    """

    def __init__(self, dataset: Dataset):
        """
        Parameters
        ----------
        dataset : Dataset
            The PyTorch Dataset to wrap.
        """
        super().__init__()
        self.dataset: Dataset = dataset

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset.

        Raises
        ------
        ValueError
            If the wrapped dataset does not implement ``__len__``.
        """
        # The wrapped object may not define __len__, so check explicitly
        # instead of letting len() fail with a less descriptive error.
        if not isinstance(self.dataset, Sized):
            raise ValueError("Base dataset has no __len__")
        return len(self.dataset)

    def __getitem__(self, item):
        """Return the sample at ``item`` together with its index.

        Parameters
        ----------
        item
            Index (or key) forwarded unchanged to the wrapped dataset.

        Returns
        -------
        tuple
            ``(item, data, target)``, where ``data`` and ``target`` are the
            two values produced by ``self.dataset[item]``.
        """
        data, target = self.dataset[item]
        return item, data, target