attribench.masking.Masker

class attribench.masking.Masker

Bases: object

Base class for all maskers. Maskers “remove” features from a sample by overwriting them with a replacement value. This can be a fixed baseline value, a random value, or some other value.

Note that a Masker object is not usable immediately after creation: you must first call set_batch() to set the samples and attributions. This design lets the same Masker object be reused for multiple batches.
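The set_batch()-first lifecycle can be sketched as follows. LifecycleMaskerSketch is a hypothetical class, not part of attribench; it assumes that every element of a sample is one feature and only illustrates two points: the object cannot be used before set_batch() is called, and the same object can then serve several batches.

```python
import torch


class LifecycleMaskerSketch:
    """Hypothetical masker (not part of attribench) illustrating the usage
    pattern: construct once, call set_batch() before each batch."""

    def __init__(self):
        # Nothing is stored yet, so the masker cannot be used.
        self.samples = None
        self.attributions = None

    def set_batch(self, samples, attributions=None):
        # Overwrite the stored batch; nothing from the previous batch is kept.
        self.samples = samples
        self.attributions = attributions

    def get_num_features(self):
        if self.samples is None:
            raise RuntimeError("call set_batch() before using the masker")
        # Assumption: every element of one sample counts as one feature.
        return self.samples[0].numel()


masker = LifecycleMaskerSketch()
try:
    masker.get_num_features()
except RuntimeError as err:
    print("not ready:", err)

# The same object is reused for batches of different shapes.
for batch in (torch.rand(8, 3, 32, 32), torch.rand(2, 10)):
    masker.set_batch(batch)
    print(masker.get_num_features())  # 3072, then 10
```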

Methods

get_num_features

Return the number of features in the samples.

mask_bot

Mask the k least important features, according to the attributions.

mask_rand

Mask k random features.

mask_top

Mask the k most important features, according to the attributions.

set_batch

Set the samples and attributions for the next batch.

get_num_features()

Return the number of features in the samples.

Returns:
int

Number of features in the samples.

Return type:

int
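This page does not define what a single “feature” is, so the count depends on the concrete masker. The sketch below contrasts two plausible definitions: one feature per element, or, for images, one feature per pixel position shared by all channels. num_features_sketch and its pixel_level flag are hypothetical and not attribench API.

```python
import math

import torch


def num_features_sketch(samples: torch.Tensor, pixel_level: bool = False) -> int:
    """Hypothetical helper, not part of attribench.

    samples has shape [num_samples, *sample_shape].
    pixel_level=False: every element of a sample is one feature.
    pixel_level=True: samples are images [N, C, H, W], and the C channel
    values at one (h, w) position together form a single feature.
    """
    if pixel_level:
        if samples.dim() != 4:
            raise ValueError("pixel_level=True expects [N, C, H, W] samples")
        return samples.shape[2] * samples.shape[3]
    # Product of sample_shape, i.e. the number of elements in one sample.
    return math.prod(samples.shape[1:])


images = torch.zeros(8, 3, 32, 32)
print(num_features_sketch(images))                    # 3072
print(num_features_sketch(images, pixel_level=True))  # 1024
```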

mask_bot(k)

Mask the k least important features, according to the attributions.

Parameters:
k : int

Number of features to mask.

Returns:
torch.Tensor

Samples with the bottom k features masked.

Return type:

Tensor
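A minimal sketch of bottom-k masking, assuming one feature per element, a constant replacement value, and ranking by the signed attribution value (whether attribench ranks by raw or absolute attributions is not stated here). mask_bot_sketch is a hypothetical helper: torch.topk with largest=False selects the k lowest-attributed positions of each sample, and scatter_ overwrites them in a flattened copy, leaving the input untouched.

```python
import torch


def mask_bot_sketch(samples: torch.Tensor, attributions: torch.Tensor,
                    k: int, baseline: float = 0.0) -> torch.Tensor:
    """Hypothetical helper: set the k lowest-attributed elements of each
    sample to `baseline`. Ranks by signed attribution value (an assumption)."""
    n = samples.shape[0]
    flat_attrs = attributions.reshape(n, -1)
    # largest=False makes topk return the k smallest values per sample.
    idx = torch.topk(flat_attrs, k, dim=1, largest=False).indices
    # Work on a flattened copy so the caller's samples are not modified.
    masked = samples.reshape(n, -1).clone()
    masked.scatter_(1, idx, baseline)
    return masked.reshape(samples.shape)


x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
attrs = torch.tensor([[0.9, -0.5, 0.1, 0.3]])
print(mask_bot_sketch(x, attrs, k=2))  # tensor([[1., 0., 0., 4.]])
```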

mask_rand(k, return_indices=False)

Mask k random features.

Parameters:
k : int

Number of features to mask.

return_indices : bool, optional

Whether to also return the indices of the masked features. By default False.

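Returns:
torch.Tensor or tuple of torch.Tensor

Samples with k random features masked. If return_indices is True, a tuple containing the masked samples and the indices of the masked features is returned instead.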

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]
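One way to draw k distinct random features per sample is to argsort i.i.d. uniform noise, which yields an independent random permutation for each sample. mask_rand_sketch is hypothetical: the constant baseline, the optional generator argument, and the format of the returned indices (flat positions of shape [num_samples, k]) are assumptions, not documented attribench behavior.

```python
import torch


def mask_rand_sketch(samples, k, baseline=0.0, return_indices=False, generator=None):
    """Hypothetical helper: set k distinct, uniformly chosen elements of each
    sample to `baseline`. No attributions are needed."""
    n = samples.shape[0]
    num_features = samples[0].numel()  # assumption: one feature per element
    # argsort of i.i.d. uniform noise is an independent random permutation
    # per sample; its first k entries are k distinct random positions.
    noise = torch.rand(n, num_features, generator=generator, device=samples.device)
    idx = noise.argsort(dim=1)[:, :k]
    masked = samples.reshape(n, -1).clone()
    masked.scatter_(1, idx, baseline)
    masked = masked.reshape(samples.shape)
    if return_indices:
        return masked, idx
    return masked


gen = torch.Generator().manual_seed(0)  # seeded for reproducible draws
x = torch.arange(1.0, 13.0).reshape(2, 6)
masked, idx = mask_rand_sketch(x, k=2, return_indices=True, generator=gen)
print(idx.shape)  # torch.Size([2, 2])
```

Because the positions are random, the same call without a seeded generator masks different features on each run.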

mask_top(k)

Mask the k most important features, according to the attributions.

Parameters:
k : int

Number of features to mask.

Returns:
torch.Tensor

Samples with the top k features masked.

Return type:

Tensor
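A sketch of top-k masking that also covers the replacement options from the class description: the baseline is any tensor that broadcasts against the samples, for example a constant tensor or random noise. mask_top_sketch is a hypothetical helper that builds a boolean mask of the k highest-attributed positions per sample and combines samples and baseline with torch.where; it assumes one feature per element and ranks attributions by signed value.

```python
import torch


def mask_top_sketch(samples: torch.Tensor, attributions: torch.Tensor,
                    k: int, baseline: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: replace the k highest-attributed elements of each
    sample with the corresponding elements of `baseline`."""
    n = samples.shape[0]
    flat_attrs = attributions.reshape(n, -1)
    # Indices of the k largest attribution values per sample.
    idx = torch.topk(flat_attrs, k, dim=1).indices
    # Boolean mask that is True exactly at the selected positions.
    mask = torch.zeros_like(flat_attrs, dtype=torch.bool)
    mask.scatter_(1, idx, torch.ones_like(idx, dtype=torch.bool))
    mask = mask.reshape(samples.shape)
    # Take baseline values where masked, original values elsewhere.
    return torch.where(mask, baseline, samples)


x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
attrs = torch.tensor([[0.9, -0.5, 0.1, 0.3]])
# Constant baseline: positions 0 and 3 (the two largest attributions) become -1.
print(mask_top_sketch(x, attrs, 2, torch.full_like(x, -1.0)))
# Random baseline: the same positions are filled with uniform noise.
print(mask_top_sketch(x, attrs, 2, torch.rand_like(x)))
```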

abstract set_batch(samples, attributions=None)

Set the samples and attributions for the next batch.

Parameters:
samples : torch.Tensor

Samples of shape [num_samples, *sample_shape].

attributions : torch.Tensor, optional

Attributions of shape [num_samples, *sample_shape]. By default None. If None, the mask_top() and mask_bot() methods are not available.
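A sketch of how a subclass might implement the abstract set_batch(), assuming that attributions must have exactly the samples' shape and that mask_top() raises an error when no attributions were given (the page only says these methods are unavailable, not how that is reported). MaskerBaseSketch and CheckedMaskerSketch are hypothetical names, not attribench classes.

```python
import abc
from typing import Optional

import torch


class MaskerBaseSketch(abc.ABC):
    """Hypothetical stand-in for the Masker base class (set_batch is abstract)."""

    @abc.abstractmethod
    def set_batch(self, samples: torch.Tensor,
                  attributions: Optional[torch.Tensor] = None) -> None:
        ...


class CheckedMaskerSketch(MaskerBaseSketch):
    """Hypothetical subclass: one feature per element, masked values set to 0."""

    def set_batch(self, samples, attributions=None):
        # Attributions, if given, must share the samples' shape
        # [num_samples, *sample_shape].
        if attributions is not None and attributions.shape != samples.shape:
            raise ValueError(
                f"attributions shape {tuple(attributions.shape)} does not match "
                f"samples shape {tuple(samples.shape)}"
            )
        self.samples = samples
        self.attributions = attributions

    def mask_top(self, k):
        # Attribution-ranked masking is unavailable without attributions.
        if getattr(self, "attributions", None) is None:
            raise RuntimeError("mask_top() needs attributions passed to set_batch()")
        n = self.samples.shape[0]
        idx = torch.topk(self.attributions.reshape(n, -1), k, dim=1).indices
        masked = self.samples.reshape(n, -1).clone()
        masked.scatter_(1, idx, 0.0)
        return masked.reshape(self.samples.shape)


masker = CheckedMaskerSketch()
masker.set_batch(torch.rand(4, 3, 8, 8))  # samples only: attribution-based masking is disabled
try:
    masker.mask_top(5)
except RuntimeError as err:
    print("unavailable:", err)
```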