igc.base.AbstractAttributionMethod
- class igc.base.AbstractAttributionMethod(module, dataset, dtld_kwargs=None, forward_method_name=None, forward_method_kwargs=None, dtype=torch.float32, dtype_cat=torch.int32)
Bases: object
Abstract base class for attribution methods.
Subclasses are expected to implement a compute() method specific to each attribution method.
- Parameters:
module (torch.nn.Module) – PyTorch module defining the model under scrutiny.
dataset (torch.utils.data.Dataset) – PyTorch dataset providing inputs/outputs for any given index. See the PyTorch documentation for more information. In addition, inputs must be organized in a specific manner; see the warning below.
dtld_kwargs (dict) – Additional keyword arguments to the dataloader (torch.utils.data.DataLoader) constructed around the dataset. Any DataLoader keyword argument may be given except dataset, batch_size, shuffle, sampler, batch_sampler, and generator.
forward_method_name (str) – Name of the forward method of the module. If None, the default forward is used.
forward_method_kwargs (dict) – Additional keyword arguments to the forward method of the module.
dtype (torch.dtype) – Default data type of all intermediary tensors. It also defines the NumPy data type of the attribution results.
dtype_cat (torch.dtype) – Default data type of the categorical input tensors.
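To illustrate forward_method_name and forward_method_kwargs, here is a minimal sketch of a module exposing a non-default forward method. The class TwoInputModel, its predict method, the temperature keyword, and the helper call_forward are all hypothetical. The name-based getattr lookup only mimics what passing forward_method_name="predict" and forward_method_kwargs={"temperature": 2.0} presumably amounts to; it is an assumption, not the library's code.

```python
import torch
from torch import nn


class TwoInputModel(nn.Module):
    """Hypothetical model with two continuous inputs."""

    def __init__(self, n_1: int = 3, n_2: int = 2):
        super().__init__()
        self.linear = nn.Linear(n_1 + n_2, 1)

    def forward(self, x_1, x_2):
        # Default forward method, used when forward_method_name is None.
        return self.linear(torch.cat([x_1, x_2], dim=1))

    def predict(self, x_1, x_2, temperature: float = 1.0):
        # Alternative forward method selected by name; its extra keyword
        # argument plays the role of forward_method_kwargs.
        return self.forward(x_1, x_2) / temperature


def call_forward(module, inputs, forward_method_name=None, forward_method_kwargs=None):
    # Hypothetical helper (assumption): look the method up by name on the
    # module and call it with the unpacked inputs plus the extra kwargs.
    name = forward_method_name or "forward"
    kwargs = forward_method_kwargs or {}
    return getattr(module, name)(*inputs, **kwargs)


torch.manual_seed(0)
model = TwoInputModel()
x_1, x_2 = torch.randn(4, 3), torch.randn(4, 2)
out_default = call_forward(model, (x_1, x_2))
out_custom = call_forward(model, (x_1, x_2), "predict", {"temperature": 2.0})
```

Here out_custom equals out_default divided by 2, since predict only rescales the default forward output.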
Notes
Warning
When computing attributions for models with multiple inputs, e.g., x_1, x_2, and a categorical input x_cat, the dataset must return all inputs packed in a tuple, as in (x_1, x_2, x_cat), y. Categorical inputs must be placed at the end of the tuple.
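As a sketch of the packing convention in the warning, the hypothetical dataset below returns two continuous inputs and one categorical input as ((x_1, x_2, x_cat), y), with the categorical tensor last and stored as torch.int32 (the default dtype_cat). The class name and the random data are illustrative assumptions; only shapes, dtypes, and tuple order matter here.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class PackedInputsDataset(Dataset):
    """Hypothetical dataset returning ((x_1, x_2, x_cat), y) per index."""

    def __init__(self, n_samples: int = 8, n_categories: int = 5):
        g = torch.Generator().manual_seed(0)
        self.x_1 = torch.randn(n_samples, 3, generator=g)
        self.x_2 = torch.randn(n_samples, 2, generator=g)
        # Categorical input: integer codes, placed last in the input tuple.
        self.x_cat = torch.randint(
            0, n_categories, (n_samples,), generator=g, dtype=torch.int32
        )
        self.y = torch.randn(n_samples, 1, generator=g)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # All inputs packed in one tuple, categorical inputs at the end.
        return (self.x_1[idx], self.x_2[idx], self.x_cat[idx]), self.y[idx]


dataset = PackedInputsDataset()
(x_1, x_2, x_cat), y = dataset[0]

# The default collate function keeps the nested structure when batching.
loader = DataLoader(dataset, batch_size=4)
(b_1, b_2, b_cat), b_y = next(iter(loader))
```

Because the default collate function batches each tuple position separately, the categorical tensor keeps its integer dtype in every batch.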
- add_embedding_method(embedding_method_name, embedding_method_kwargs=None, embedding_n_cat=None)
Add an embedding method to preprocess categorical inputs.
Warning
The effect of this method must be excluded from the forward method defined by forward_method_name at initialization.
- Parameters:
embedding_method_name (str) – Name of the embedding method of the module.
embedding_method_kwargs (dict) – Additional keyword arguments to the embedding method of the module.
embedding_n_cat (int) – Number of categorical inputs. If None, this value is inferred from the input data types (torch.int16, torch.int32, torch.int64).
- Return type: self
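The sketch below shows a module with a separate embedding method for one categorical input. It assumes that the forward method receives the categorical input already embedded, so the embedding is not applied twice. The names EmbeddedModel and embed are hypothetical, and so is infer_n_cat, which counts trailing integer-typed inputs as one plausible way to infer embedding_n_cat from the dtypes listed above; the library's actual inference rule is not shown here.

```python
import torch
from torch import nn

INTEGER_DTYPES = (torch.int16, torch.int32, torch.int64)


class EmbeddedModel(nn.Module):
    """Hypothetical model: two continuous inputs plus one categorical input."""

    def __init__(self, n_categories: int = 5, emb_dim: int = 4):
        super().__init__()
        self.embedding = nn.Embedding(n_categories, emb_dim)
        self.linear = nn.Linear(3 + 2 + emb_dim, 1)

    def embed(self, x_cat):
        # Embedding method: maps integer codes to float vectors.
        # Cast to int64 indices before the lookup.
        return self.embedding(x_cat.long())

    def forward(self, x_1, x_2, x_cat_emb):
        # Assumption: the categorical input arrives already embedded,
        # so no embedding is applied here.
        return self.linear(torch.cat([x_1, x_2, x_cat_emb], dim=1))


def infer_n_cat(inputs):
    # Hypothetical helper: count integer-typed inputs at the end of the tuple.
    n_cat = 0
    for x in reversed(inputs):
        if x.dtype not in INTEGER_DTYPES:
            break
        n_cat += 1
    return n_cat


torch.manual_seed(0)
model = EmbeddedModel()
x_1, x_2 = torch.randn(4, 3), torch.randn(4, 2)
x_cat = torch.randint(0, 5, (4,), dtype=torch.int32)
n_cat = infer_n_cat((x_1, x_2, x_cat))
out = model(x_1, x_2, model.embed(x_cat))
```

Counting only trailing integer inputs relies on the convention, stated in the class warning, that categorical inputs sit at the end of the input tuple.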