# Source code for attribench._model_factory

import copy
from abc import ABC, abstractmethod

from torch import nn


class ModelFactory(ABC):
    """
    Basic interface for a callable that returns a model.

    This is necessary for multi-GPU mode, as each subprocess needs its own
    copy of the model.

    The class derives from :class:`abc.ABC` so that the ``@abstractmethod``
    marker on :meth:`__call__` is actually enforced: neither this class nor a
    subclass that fails to override :meth:`__call__` can be instantiated.
    """

    @abstractmethod
    def __call__(self) -> nn.Module:
        """Return a model.

        Returns
        -------
        nn.Module
            The model.

        Raises
        ------
        NotImplementedError
            If a subclass explicitly delegates to this base implementation.
        """
        raise NotImplementedError
class BasicModelFactory(ModelFactory):
    """
    Straightforward :class:`ModelFactory` that hands out deep copies of a
    single stored model.

    Suitable for simple models. The model has to be picklable, so e.g. lambda
    layers will not work; such models need a dedicated :class:`ModelFactory`
    implementation instead.
    """

    def __init__(self, base_model: nn.Module) -> None:
        """
        Parameters
        ----------
        base_model : nn.Module
            Template model that is deep-copied on every call to the factory.
        """
        self.base_model = base_model

    def __call__(self) -> nn.Module:
        """Produce an independent copy of the stored model.

        Returns
        -------
        nn.Module
            Deep copy of ``base_model``.
        """
        model_copy = copy.deepcopy(self.base_model)
        return model_copy