igc.base.AbstractAttributionMethod

class igc.base.AbstractAttributionMethod(module, dataset, dtld_kwargs=None, forward_method_name=None, forward_method_kwargs=None, dtype=torch.float32, dtype_cat=torch.int32)

Bases: object

Base class defining the common interface of attribution methods.

Subclasses must implement the compute() method, whose behavior is specific to each attribution method.

Parameters:
  • module (torch.nn.Module) – PyTorch module defining the model under scrutiny.

  • dataset (torch.utils.data.Dataset) – PyTorch dataset returning the inputs and outputs for any given index. See the PyTorch documentation for details. Inputs must additionally follow a specific layout; see the warning below.

  • dtld_kwargs (dict) – Additional keyword arguments passed to the dataloader (torch.utils.data.DataLoader) built around the dataset. Any DataLoader argument is accepted except dataset, batch_size, shuffle, sampler, batch_sampler, and generator.

  • forward_method_name (str) – Name of the forward method of the module. If None, the default forward is used.

  • forward_method_kwargs (dict) – Additional keyword arguments to the forward method of the module.

  • dtype (torch.dtype) – Default data type of all intermediate tensors. It also determines the NumPy data type of the attribution results.

  • dtype_cat (torch.dtype) – Default data type of the categorical input tensors.
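The forward_method_name and forward_method_kwargs parameters let the caller select a module method other than forward. The library's internal dispatch is not documented on this page, so the snippet below is only a sketch of the idea, assuming the method is looked up by name on the module. TwoHeadModel, forward_scaled, and call_forward are hypothetical names used for illustration.

```python
import torch
from torch import nn


class TwoHeadModel(nn.Module):
    """Hypothetical model with a default forward and an alternative named method."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    def forward_scaled(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # Alternative entry point, selectable with forward_method_name="forward_scaled".
        return scale * self.linear(x)


def call_forward(module, inputs, forward_method_name=None, forward_method_kwargs=None):
    """Hypothetical helper: dispatch to the forward method selected by name."""
    kwargs = forward_method_kwargs or {}
    if forward_method_name is None:
        # Default: nn.Module.__call__, which runs forward() (plus any registered hooks).
        return module(*inputs, **kwargs)
    # Otherwise look the method up by name and pass the extra keyword arguments.
    return getattr(module, forward_method_name)(*inputs, **kwargs)


model = TwoHeadModel()
x = torch.randn(4, 3, dtype=torch.float32)
y_default = call_forward(model, (x,))
y_scaled = call_forward(model, (x,), "forward_scaled", {"scale": 2.0})
```

In this sketch, forward_method_kwargs plays the role of the {"scale": 2.0} dictionary, forwarded unchanged on every call.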

Notes

Warning

When computing attributions for models with multiple inputs, e.g., x_1, x_2, and x_cat, where x_cat is a categorical input, the dataset must return all inputs packed in a tuple, followed by the target: ((x_1, x_2, x_cat), y). Categorical inputs must be placed at the end of the tuple.
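A minimal dataset following this layout might look like the sketch below. MixedInputDataset and its random tensors are hypothetical; the only points taken from the warning are that each item is ((x_1, x_2, x_cat), y) and that the categorical input comes last. The categorical codes use torch.int32, matching the default dtype_cat.

```python
import torch
from torch.utils.data import Dataset


class MixedInputDataset(Dataset):
    """Hypothetical dataset returning ((x_1, x_2, x_cat), y) for each index."""

    def __init__(self, n: int = 8) -> None:
        g = torch.Generator().manual_seed(0)
        self.x_1 = torch.randn(n, 3, generator=g)  # continuous input
        self.x_2 = torch.randn(n, 5, generator=g)  # continuous input
        # Categorical codes in [0, 4), stored with an integer dtype.
        self.x_cat = torch.randint(0, 4, (n, 2), generator=g, dtype=torch.int32)
        self.y = torch.randn(n, 1, generator=g)    # target

    def __len__(self) -> int:
        return self.y.shape[0]

    def __getitem__(self, idx: int):
        # All inputs packed in one tuple, categorical input last, then the target.
        return (self.x_1[idx], self.x_2[idx], self.x_cat[idx]), self.y[idx]


ds = MixedInputDataset()
(x_1, x_2, x_cat), y = ds[0]
```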

add_embedding_method(embedding_method_name, embedding_method_kwargs=None, embedding_n_cat=None)

Add an embedding method to preprocess categorical inputs.

Warning

The effect of this method must be excluded from the forward method defined by forward_method_name at initialization.

Parameters:
  • embedding_method_name (str) – Name of the embedding method of the module.

  • embedding_method_kwargs (dict) – Additional keyword arguments to the embedding method of the module.

  • embedding_n_cat (int) – Number of categorical inputs. If None, this value is inferred from the input data types: inputs of type torch.int16, torch.int32, or torch.int64 are treated as categorical.
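One way such an inference could work is sketched below. This is not the library's code: infer_n_cat is a hypothetical helper that counts the trailing inputs whose dtype is one of the integer types listed above. It relies on the documented convention that categorical inputs come last in the tuple.

```python
import torch

# Integer dtypes that the documentation lists as marking categorical inputs.
CATEGORICAL_DTYPES = (torch.int16, torch.int32, torch.int64)


def infer_n_cat(inputs):
    """Hypothetical sketch: count trailing inputs with an integer dtype."""
    n_cat = 0
    # Walk from the end, since categorical inputs are placed last.
    for x in reversed(inputs):
        if x.dtype in CATEGORICAL_DTYPES:
            n_cat += 1
        else:
            break
    return n_cat


inputs = (
    torch.randn(2, 3),                                  # float32, continuous
    torch.randn(2, 5),                                  # float32, continuous
    torch.tensor([[1, 0], [3, 2]], dtype=torch.int32),  # categorical codes
)
print(infer_n_cat(inputs))  # 1
```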

Return type:

self
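The warning above implies that the embedding step and the forward method used for attribution must be kept separate. The sketch below shows one way to structure such a module. EmbeddedModel, embed, and forward_embedded are hypothetical names. The sketch assumes that the library applies the embedding method to the categorical inputs before it calls the selected forward method, so forward_embedded receives already-embedded values. Under that assumption, the module might be passed with forward_method_name="forward_embedded" and then registered with add_embedding_method("embed"). The exact argument passing between the two steps is not documented here.

```python
import torch
from torch import nn


class EmbeddedModel(nn.Module):
    """Hypothetical model whose embedding step is exposed as a separate method."""

    def __init__(self, n_levels: int = 4, emb_dim: int = 3) -> None:
        super().__init__()
        self.emb = nn.Embedding(n_levels, emb_dim)
        self.head = nn.Linear(3 + emb_dim, 1)

    def embed(self, x_cat: torch.Tensor) -> torch.Tensor:
        # Candidate embedding method: maps integer codes to dense float vectors.
        return self.emb(x_cat.long())

    def forward_embedded(self, x_1: torch.Tensor, e_cat: torch.Tensor) -> torch.Tensor:
        # Candidate forward method: expects e_cat to be embedded already,
        # so it does NOT call self.embed again.
        return self.head(torch.cat([x_1, e_cat], dim=-1))

    def forward(self, x_1: torch.Tensor, x_cat: torch.Tensor) -> torch.Tensor:
        # End-to-end forward for ordinary inference: embed, then predict.
        return self.forward_embedded(x_1, self.embed(x_cat))


model = EmbeddedModel()
x_1 = torch.randn(2, 3)
x_cat = torch.tensor([1, 3], dtype=torch.int32)
out = model(x_1, x_cat)
```

Keeping the default forward as a thin composition of the two methods preserves normal inference while leaving a forward path free of the embedding.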

abstractmethod compute()

Abstract method computing attributions.