Source code for attribench.masking.image._image_masker

from attribench.masking import Masker
from attribench._segmentation import segment_attributions
from typing import List, Union, Optional, Tuple
import numpy as np
import torch
from abc import abstractmethod


class ImageMasker(Masker):
    """Abstract base class for all image maskers.

    Image maskers are specific maskers for image data. They can be used to
    mask features on the feature level or the pixel level. If the masker is
    used on the feature level, the attributions must have the same shape as
    the samples. If the masker is used on the pixel level, the attributions
    must have the same shape as the samples, except the channel dimension
    must be 1.

    If the masker is used on the feature level, masking a feature means
    masking a specific input feature for all samples (i.e. one color value
    for one pixel). If the masker is used on the pixel level, masking a
    feature means masking a specific pixel for all samples (i.e. all color
    values for one pixel).

    If segmented samples are provided, the masker will use the segments to
    mask features. This means that masking a feature will mask all pixels
    belonging to the same segment. The attribution value of a segment is
    defined as the average attribution value of all features in that
    segment.
    """

    def __init__(self, masking_level: str):
        """
        Parameters
        ----------
        masking_level : str
            Either ``"feature"`` or ``"pixel"``. If ``"feature"``, the masker
            will mask features. If ``"pixel"``, the masker will mask pixels.

        Raises
        ------
        ValueError
            If ``masking_level`` is not ``"feature"`` or ``"pixel"``.
        """
        if masking_level not in ("feature", "pixel"):
            raise ValueError(
                f"masking_level must be 'feature' or 'pixel'."
                f" Found {masking_level}."
            )
        self.masking_level = masking_level
        super().__init__()
        # will be set after set_batch:
        self.segmented_samples: Optional[torch.Tensor] = None
        self.segmented_attributions: Optional[np.ndarray] = None
        self.segment_indices: Optional[List[np.ndarray]] = None
        self.use_segments: bool = False
        self.sorted_indices: torch.Tensor | List[torch.Tensor] | None = None

    def set_batch(
        self,
        samples: torch.Tensor,
        attributions: torch.Tensor | None = None,
        segmented_samples: torch.Tensor | None = None,
    ):
        """Set the batch of samples and attributions to use for masking.
        Optionally also set the segmented samples.

        Parameters
        ----------
        samples : torch.Tensor
            Samples to use for masking.
        attributions : torch.Tensor, optional
            Attributions to use for masking, by default None
            If None, the :meth:`mask_top` and :meth:`mask_bot` methods
            will not be available.
        segmented_samples : torch.Tensor, optional
            Segmented samples to use for masking, by default None

        Raises
        ------
        ValueError
            If the attribution shape does not match the sample shape for the
            configured masking level, if the segmented samples do not share
            the batch size and the last two dimensions of the samples, or if
            samples and segmented samples live on different devices.
        """
        # Check if attributions are compatible with samples
        if attributions is not None and not self._check_attribution_shape(
            samples, attributions
        ):
            raise ValueError(
                f"samples and attribution shape not compatible for masking "
                f"level {self.masking_level}. "
                f"Found shapes {samples.shape} and {attributions.shape}"
            )

        # Check if segmented samples are compatible with samples
        if segmented_samples is not None:
            if not (
                samples.shape[0] == segmented_samples.shape[0]
                and samples.shape[-2:] == segmented_samples.shape[-2:]
            ):
                raise ValueError(
                    f"Incompatible shapes: {samples.shape}, {segmented_samples.shape}"
                )
            if samples.device != segmented_samples.device:
                raise ValueError(
                    "Device for samples and segmented_samples must be equal. "
                    f"Got {samples.device} for samples,"
                    f" {segmented_samples.device} for segmented_samples."
                )

        # Set attributes
        self.samples = samples
        self.attributions = attributions
        self.segmented_samples = segmented_samples
        self.use_segments = segmented_samples is not None

        # Drop everything derived from the previous batch so that a call
        # without attributions/segments cannot silently reuse stale state.
        self.segmented_attributions = None
        self.segment_indices = None
        self.sorted_indices = None

        # Convert the segmentation once; it is needed by both steps below.
        segmented_np = (
            segmented_samples.cpu().numpy()
            if segmented_samples is not None
            else None
        )

        if segmented_np is not None and attributions is not None:
            # If segmented samples and attributions are given,
            # segment the attributions accordingly
            # and sort the segments
            self.segmented_attributions = segment_attributions(
                segmented_np,
                attributions.cpu().numpy(),
            )
            sorted_indices = torch.tensor(
                self.segmented_attributions.argsort()
            )

            # Filter out the -np.inf values from the sorted indices.
            # argsort is ascending, so the -inf entries occupy the first
            # positions of each row and can be sliced off.
            filtered_sorted_indices = []
            for i in range(segmented_np.shape[0]):
                num_infs = np.count_nonzero(
                    self.segmented_attributions[i, ...] == -np.inf
                )
                filtered_sorted_indices.append(sorted_indices[i, num_infs:])
            self.sorted_indices = filtered_sorted_indices
        elif attributions is not None:
            # If only attributions are given, sort them
            self.sorted_indices = torch.tensor(
                attributions.cpu()
                .numpy()
                .reshape(attributions.shape[0], -1)
                .argsort()
            )

        if segmented_np is not None:
            # Get the indices of the segments for each image
            self.segment_indices = [
                np.unique(segmented_np[i, ...])
                for i in range(samples.shape[0])
            ]

        self._initialize_baselines(self.samples)

    def get_num_features(self):
        """Return the number of maskable features per sample.

        For the ``"feature"`` level this is the size of each sample with all
        non-batch dimensions flattened; for the ``"pixel"`` level it is the
        size with all dimensions after the channel dimension flattened.

        Raises
        ------
        ValueError
            If segments are in use, because the number of segments differs
            from image to image.
        """
        assert self.samples is not None
        if self.use_segments:
            raise ValueError(
                "When using segments, total number of features varies per image."
            )
        if self.masking_level == "feature":
            return self.samples.flatten(1).shape[-1]
        if self.masking_level == "pixel":
            return self.samples.flatten(2).shape[-1]

    def mask_top(self, k: int | float):
        """Mask the ``k`` highest-attributed features (or segments).

        Without segments the call is delegated to the parent class. With
        segments, ``k`` is interpreted as a fraction (between 0 and 1) of the
        segments of each image, rounded down.

        Parameters
        ----------
        k : int or float
            Number of features to mask, or fraction of segments to mask when
            segments are in use.
        """
        assert self.samples is not None
        assert self.sorted_indices is not None
        if not self.use_segments:
            return super().mask_top(k)
        if k == 0:
            return self.samples

        # When using segments, k is relative (between 0 and 1)
        indices = []
        for i in range(self.samples.shape[0]):
            image_indices = self.sorted_indices[i]
            num_segments = len(image_indices)
            num_to_mask = int(num_segments * k)
            # Slice from an explicit non-negative start: ``[-num_to_mask:]``
            # would select *every* segment when num_to_mask rounds down to 0.
            start = max(num_segments - num_to_mask, 0)
            indices.append(image_indices[start:])
        return self._mask_segments(indices)

    def mask_bot(self, k: int | float):
        """Mask the ``k`` lowest-attributed features (or segments).

        Without segments the call is delegated to the parent class. With
        segments, ``k`` is interpreted as a fraction (between 0 and 1) of the
        segments of each image, rounded down.

        Parameters
        ----------
        k : int or float
            Number of features to mask, or fraction of segments to mask when
            segments are in use.
        """
        assert self.samples is not None
        assert self.sorted_indices is not None
        if not self.use_segments:
            return super().mask_bot(k)
        if k == 0:
            return self.samples

        # When using segments, k is relative (between 0 and 1)
        indices = []
        for i in range(self.samples.shape[0]):
            image_indices = self.sorted_indices[i]
            num_to_mask = int(len(image_indices) * k)
            indices.append(image_indices[:num_to_mask])
        return self._mask_segments(indices)

    def mask_rand(
        self, k: int, return_indices=False
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
        """Mask ``k`` randomly chosen features (or segments) per sample.

        Without segments the call is delegated to the parent class. With
        segments, ``k`` distinct segment ids are drawn for every image from
        the ids that actually occur in that image.

        Parameters
        ----------
        k : int
            Number of features or segments to mask.
        return_indices : bool, optional
            If True, also return the selected indices, by default False.
            NOTE(review): for ``k == 0`` only the samples are returned, even
            when ``return_indices`` is True.

        Raises
        ------
        ValueError
            Raised by ``rng.choice`` when segments are in use and ``k``
            exceeds the number of segments of an image.
        """
        assert self.samples is not None
        if k == 0:
            return self.samples
        if not self.use_segments:
            return super().mask_rand(k, return_indices)

        # this is done a little different from super(),
        # no shuffle here: for each image, only select segments that exist
        # in that image
        # k can be large -> rng.choice raises exception if k > number of
        # segments
        assert self.segment_indices is not None
        assert self.segmented_samples is not None
        rng = np.random.default_rng()
        indices = torch.stack(
            [
                torch.tensor(
                    rng.choice(
                        self.segment_indices[i], size=k, replace=False
                    )
                )
                for i in range(self.segmented_samples.shape[0])
            ]
        )
        masked_samples = self._mask_segments(indices)
        if return_indices:
            return masked_samples, indices
        return masked_samples

    def _check_attribution_shape(
        self, samples: torch.Tensor, attributions: torch.Tensor
    ):
        """Return whether ``attributions`` fits ``samples`` for this level."""
        if self.masking_level == "feature":
            # Attributions should be same shape as samples
            return list(samples.shape) == list(attributions.shape)
        elif self.masking_level == "pixel":
            # attributions should have the same shape as samples,
            # except the channel dimension must be 1
            aggregated_shape = list(samples.shape)
            aggregated_shape[1] = 1
            return aggregated_shape == list(attributions.shape)

    def _mask(self, indices: torch.Tensor) -> torch.Tensor:
        """Mask the given indices in the current batch.

        With segments, ``indices`` are segment ids per image. Otherwise they
        index the flattened samples (feature level) or the flattened spatial
        positions (pixel level, applied to every channel).

        Raises
        ------
        ValueError
            If no baseline has been set, or if a pixel-level index is out of
            bounds.
        """
        if self.baseline is None:
            raise ValueError("Masker was not initialized.")
        if self.use_segments:
            return self._mask_segments(indices)

        assert self.samples is not None
        batch_size, _, _, _ = self.samples.shape
        num_indices = indices.shape[1]
        # Row i of batch_dim is filled with i, pairing every index in
        # ``indices[i]`` with its own sample.
        batch_dim = np.tile(
            range(batch_size), (num_indices, 1)
        ).transpose()

        to_mask = torch.zeros(
            self.samples.shape, device=self.samples.device
        ).flatten(1 if self.masking_level == "feature" else 2)
        if self.masking_level == "feature":
            to_mask[batch_dim, indices] = 1.0
        else:
            try:
                to_mask[batch_dim, :, indices] = 1.0
            except IndexError as err:
                raise ValueError(
                    "Masking index was out of bounds. "
                    "Make sure the masking policy is compatible with method output."
                ) from err
        return self._mask_boolean(to_mask.view(self.samples.shape).bool())

    def _mask_segments(
        self, segments: Union[torch.Tensor, List[torch.Tensor]]
    ) -> torch.Tensor:
        """Mask, for every image, all pixels belonging to the given segments.

        Parameters
        ----------
        segments : torch.Tensor or list of torch.Tensor
            One collection of segment ids per image in the batch.

        Raises
        ------
        ValueError
            If the number of id collections differs from the batch size.
        """
        assert self.segmented_samples is not None
        assert self.samples is not None
        if not self.segmented_samples.shape[0] == len(segments):
            raise ValueError(
                f"Number of segment lists doesn't match number of images:"
                f" {self.segmented_samples.shape[0]}"
                f" were expected, {len(segments)} were given."
            )
        bool_masks = []
        for i in range(self.samples.shape[0]):
            seg_img = self.segmented_samples[i, ...]
            segs = segments[i].to(seg_img.device)
            bool_masks.append(_isin(seg_img, segs))
        bool_masks = torch.stack(bool_masks, dim=0)
        # Replicate the mask over the colour channels of 3-channel samples.
        if self.samples.shape[1] == 3:
            bool_masks = bool_masks.repeat(1, 3, 1, 1)
        return self._mask_boolean(bool_masks)

    def _mask_boolean(self, bool_mask: torch.Tensor) -> torch.Tensor:
        """Replace the entries selected by ``bool_mask`` with the baseline."""
        return (
            self.samples
            - (bool_mask * self.samples)
            + (bool_mask * self.baseline)
        )

    @abstractmethod
    def _initialize_baselines(self, samples: torch.Tensor):
        """Set up the baseline values used for masking ``samples``."""
        raise NotImplementedError
def _isin(a: torch.Tensor, b: torch.Tensor):
    """Return a boolean tensor telling which elements of ``a`` occur in ``b``.

    A trailing axis is appended to ``a`` so that every element is compared
    against ``b`` by broadcasting; the comparison is then reduced with
    ``any`` over that trailing axis.
    """
    # https://stackoverflow.com/questions/60918304/get-indices-of-elements-in-tensor-a-that-are-present-in-tensor-b
    candidates = a.unsqueeze(-1)
    matches = candidates == b
    return matches.any(dim=-1)