attribench.distributed.metrics.MinimalSubset

class attribench.distributed.metrics.MinimalSubset(model_factory, attributions_dataset, batch_size, maskers, mode='deletion', num_steps=100, address='localhost', port='12355', devices=None)

Bases: Metric

Computes Minimal Subset Deletion or Insertion for a given AttributionsDataset and model using multiple processes.

Minimal Subset Deletion or Insertion is computed by iteratively masking (Deletion) or revealing (Insertion) the top features of each input sample, as ranked by its attributions. After each step, the model's prediction is computed on the resulting masked sample.

Minimal Subset Deletion is the minimal number of features that must be masked for the model's prediction to change from its original prediction. Minimal Subset Insertion is the minimal number of features that must be revealed for the model to recover its original prediction.
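The following is a minimal sketch of the per-sample computation described above. It is not attribench's implementation. It assumes three things: "top features" means the features with the highest attribution values, masking replaces a feature with a fixed baseline value, and one feature is added per step. The model type, `minimal_subset`, and `baseline` are hypothetical stand-ins; the real metric works on batches through the configured Masker objects.

```python
from typing import Callable, Optional, Sequence

# Hypothetical model type: maps a flat feature vector to a predicted class.
Model = Callable[[Sequence[float]], int]


def minimal_subset(
    model: Model,
    sample: Sequence[float],
    attributions: Sequence[float],
    baseline: Sequence[float],
    mode: str = "deletion",
) -> Optional[int]:
    """Sketch of Minimal Subset Deletion/Insertion for a single sample.

    Assumptions (not taken from attribench): "masking" a feature replaces it
    with the matching baseline value, and features are added one at a time.
    Returns None if no number of features satisfies the criterion.
    """
    if mode not in ("deletion", "insertion"):
        raise ValueError(f"mode must be 'deletion' or 'insertion', got {mode!r}")
    original = model(sample)
    n = len(sample)
    # Rank feature indices by attribution value, most important first.
    order = sorted(range(n), key=lambda i: attributions[i], reverse=True)
    for k in range(n + 1):
        top = set(order[:k])
        if mode == "deletion":
            # Mask the k top features and keep the rest of the sample.
            x = [baseline[i] if i in top else sample[i] for i in range(n)]
            if model(x) != original:
                return k
        else:
            # Reveal only the k top features; everything else stays masked.
            x = [sample[i] if i in top else baseline[i] for i in range(n)]
            if model(x) == original:
                return k
    return None
```

As a toy example, take the model `lambda x: 1 if sum(x) > 2.5 else 0`, the sample `[1.0, 1.0, 1.0, 0.5]`, attributions `[0.9, 0.1, 0.5, 0.0]`, and an all-zero baseline. Masking only the top feature already flips the prediction, so deletion gives 1. Three features must be revealed to recover it, so insertion gives 3.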

The Minimal Subset metric is computed for each masker in maskers. The number of processes is determined by the number of devices. If devices is None, then all available devices are used. Samples are distributed evenly across the processes.
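The sketch below illustrates one way to split samples evenly. It partitions the sample indices into contiguous chunks, one per process, whose sizes differ by at most one. The helper `split_evenly` is hypothetical, and attribench may assign samples to processes differently (for example, by striding over the indices).

```python
from typing import List


def split_evenly(num_samples: int, num_processes: int) -> List[range]:
    """Hypothetical helper: split indices 0..num_samples-1 into
    num_processes contiguous chunks whose sizes differ by at most one."""
    if num_processes < 1:
        raise ValueError("num_processes must be at least 1")
    base, extra = divmod(num_samples, num_processes)
    chunks = []
    start = 0
    for rank in range(num_processes):
        # The first `extra` ranks each take one additional sample.
        size = base + (1 if rank < extra else 0)
        chunks.append(range(start, start + size))
        start += size
    return chunks
```

For example, 10 samples across 3 processes yield chunks of sizes 4, 3 and 3.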

Parameters:
model_factory : ModelFactory

ModelFactory instance or callable that returns a model. Used to create a model for each subprocess.

attributions_dataset : AttributionsDataset

Dataset containing the samples and attributions to compute the Minimal Subset metric for.

batch_size : int

Batch size per subprocess to use when computing the metric.

maskers : Dict[str, Masker]

Dictionary mapping masker names to Masker objects.

mode : str, optional

Whether to compute Minimal Subset Deletion (“deletion”) or Minimal Subset Insertion (“insertion”), by default “deletion”.

num_steps : int, optional

Number of steps to use when computing the Minimal Subset metric, by default 100. More steps give a finer-grained, more accurate result but take longer to compute, because the model is evaluated on the masked samples at every step.

address : str, optional

Address to use for multiprocessing, by default “localhost”.

port : str, optional

Port to use for multiprocessing, by default “12355”.

devices : Optional[Tuple], optional

Tuple of devices to use for multiprocessing, by default None. If None, all available devices are used.
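The sketch below illustrates the trade-off controlled by num_steps. It assumes that features are masked (or revealed) in num_steps roughly equal increments. Under that assumption, the minimal subset can only be resolved to a multiple of the increment, so more steps give finer resolution at the cost of more model evaluations. The helper `feature_counts` is hypothetical, and the exact rounding attribench uses is not documented here.

```python
import math
from typing import List


def feature_counts(num_features: int, num_steps: int) -> List[int]:
    """Hypothetical helper: cumulative number of features masked/revealed
    after each step, assuming num_features are split into at most
    num_steps equal increments."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    # Round the increment up so that num_steps steps cover every feature.
    step_size = max(1, math.ceil(num_features / num_steps))
    counts = list(range(step_size, num_features, step_size))
    # The final step always covers all features.
    counts.append(num_features)
    return counts
```

Consider an input with 50,176 features (e.g., a 224 × 224 single-channel image) and the default num_steps=100. Under this assumption, each step adds 502 features, so the minimal subset is only known to within 502 features.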

Raises:
ValueError

If mode is not “deletion” or “insertion”.

Methods

run

Runs the metric computation and optionally saves the result.

save_result

Save the result to disk.

Attributes

result

run(result_path=None, progress=True)

Runs the metric computation and optionally saves the result. If result_path is None, the result is not saved to disk, but it can still be accessed via the result property.

Parameters:
result_path : str, optional

Path to save the result to. If None, the result is not saved to disk.

progress : bool, optional

Whether to show a progress bar. Defaults to True.

save_result(path, format='hdf5')

Save the result to disk.

Parameters:
path : str

Path to save the result to.

format : str, optional

Format to save the result in. If "hdf5", the result is saved as an HDF5 file. If "csv", the result is saved as a directory structure containing CSV files. Default: "hdf5".

Raises:
ValueError

If the result is None (for example, because run has not been called yet).
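The sketch below shows what a directory of CSV files could look like for a result that stores one value per sample for each masker. It also raises the ValueError for a missing result. The layout and the function `save_per_masker_csv` are hypothetical and do not describe attribench's actual on-disk format. The layout is one `<masker>.csv` file per masker, with `sample` and `minimal_subset` columns.

```python
import csv
import os
from typing import Dict, List, Optional


def save_per_masker_csv(result: Optional[Dict[str, List[int]]], path: str) -> None:
    """Hypothetical layout: one CSV file per masker inside `path`,
    with one row per sample."""
    if result is None:
        # Mirrors the documented behaviour: saving without a result is an error.
        raise ValueError("Result is None; run the metric first.")
    os.makedirs(path, exist_ok=True)
    for masker_name, values in result.items():
        with open(os.path.join(path, f"{masker_name}.csv"), "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["sample", "minimal_subset"])
            for i, value in enumerate(values):
                writer.writerow([i, value])
```

An HDF5 file could store the same per-masker arrays as datasets in a single file instead. That is a convenient choice when results are large or need to be loaded selectively.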