igc.bsc.BaselineShapley
- class igc.bsc.BaselineShapley(module, dataset, dtld_kwargs=None, forward_method_name=None, forward_method_kwargs=None, n_embedding_categories=None, dtype=torch.float32, dtype_cat=torch.int32)
Bases: AbstractAttributionMethod
Baseline Shapley (BS).
See the original paper [SN20] for more information.
- Parameters:
module (torch.nn.Module) – PyTorch module defining the model under scrutiny.
dataset (torch.utils.data.Dataset) – PyTorch dataset providing inputs/outputs for any given index. See the PyTorch documentation for more information. In addition, inputs must be organized in a specific manner; see the warning below.
dtld_kwargs (dict) – Additional keyword arguments to the dataloaders (torch.utils.data.DataLoader) constructed around the dataset, except: dataset, batch_size, shuffle, sampler, batch_sampler, and generator.
forward_method_name (str) – Name of the forward method of the module. If None, the default forward is used.
forward_method_kwargs (dict) – Additional keyword arguments to the forward method of the module.
n_embedding_categories (None | int) – Enable the computation of attributions for categorical inputs associated with torch.nn.Embedding layers, by providing the number of embedding categories.
dtype (torch.dtype) – Default data type of all intermediary tensors. It also defines the NumPy data type of the attribution results.
dtype_cat (torch.dtype) – Default data type of the categorical input tensors.
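The handling of dtld_kwargs and forward_method_name described above can be illustrated with a minimal sketch. Everything in it is hypothetical and not part of igc: the helpers check_dtld_kwargs and resolve_forward, the ToyModule stand-in, and the choice to raise an error on reserved keys are assumptions made for illustration only.

```python
# Hypothetical helpers; not part of igc. They only illustrate the
# parameter descriptions above.

# Dataloader arguments listed as excluded from dtld_kwargs.
RESERVED_DTLD_KEYS = {
    "dataset", "batch_size", "shuffle", "sampler", "batch_sampler", "generator",
}


def check_dtld_kwargs(dtld_kwargs=None):
    """Return a copy of dtld_kwargs, rejecting the reserved keys.

    Assumption: raising an error is one possible policy; the library may
    handle reserved keys differently.
    """
    kwargs = dict(dtld_kwargs or {})
    clashes = sorted(RESERVED_DTLD_KEYS.intersection(kwargs))
    if clashes:
        raise ValueError(f"reserved dataloader arguments: {clashes}")
    return kwargs


def resolve_forward(module, forward_method_name=None, forward_method_kwargs=None):
    """Return a callable running the selected forward method with extra kwargs."""
    # Fall back to the default "forward" when no name is given.
    method = getattr(module, forward_method_name or "forward")
    extra = dict(forward_method_kwargs or {})
    return lambda x: method(x, **extra)


class ToyModule:
    """Stand-in for a torch.nn.Module with two forward-like methods."""

    def forward(self, x):
        return 2 * x

    def forward_scaled(self, x, scale=1):
        return scale * x


default_call = resolve_forward(ToyModule())
custom_call = resolve_forward(ToyModule(), "forward_scaled", {"scale": 3})
loader_kwargs = check_dtld_kwargs({"num_workers": 0, "pin_memory": False})
```

In this sketch, default_call uses forward because no method name is given, while custom_call routes through forward_scaled with the extra keyword argument.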
Notes
Note
Using categorical inputs with torch.nn.Embedding layers modifies the output shape of attributions associated with this categorical input. The number of embedding categories is added at the end of the original shape.
- compute(x, x_0=None, y_idx=None, n_iter=8, x_0_batch_size=1, x_seed=None, x_0_seed=100, check_error=True)
Compute Baseline Shapley (BS).
Warning
Baseline Shapley (BS) does not support multiple inputs.
- Parameters:
x (None | int | ArrayLike) –
None : x_dtld iterates over the whole dataset.
int : Number of x inputs sampled from the dataset.
ArrayLike : Set new x used by x_dtld.
x_0 (None | int | float | ArrayLike) –
None : Zero baseline x_0.
int : Number of x_0 baselines sampled from the dataset.
float : Constant baseline x_0.
ArrayLike : Set x_0 baselines used by x_0_dtld.
y_idx (None | int | ArrayLike) –
None : y_idx_dtld iterates over all output component indices y_idx.
int : Select a specific output component index y_idx.
ArrayLike : Select multiple output component indices y_idx.
n_iter (int) – Number of iterations, i.e. the number of random sequences of input component indices enabled one after the other.
x_0_batch_size (None | int) –
None : Set x_0_bsz = n_x_0.
int : Set x_0_bsz.
x_seed (None | int) – Seed associated with x_dtld.
x_0_seed (None | int) – Seed associated with x_0_dtld.
check_error (bool) – If True, the mean absolute error of BS approximations is reported. For each input, baseline, and output component, the completeness property of BS states that the sum of input component attributions must be equal to the difference between the model predictions associated with the input and baseline under scrutiny.
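The roles of n_iter and check_error can be illustrated with a minimal, pure-Python sketch of permutation sampling for a single input and baseline. It is not the igc implementation: the names baseline_shapley, completeness_error, and model are hypothetical, and the toy model takes a list of floats rather than a tensor.

```python
import random


def baseline_shapley(f, x, x_0, n_iter=8, seed=None):
    """Permutation-sampling estimate of Baseline Shapley for one input.

    f   : callable mapping a list of floats to a float (toy scalar model)
    x   : input under scrutiny
    x_0 : baseline of the same length as x
    """
    rng = random.Random(seed)
    n = len(x)
    attributions = [0.0] * n
    for _ in range(n_iter):
        order = list(range(n))
        rng.shuffle(order)      # one random sequence of component indices
        z = list(x_0)           # start from the baseline
        previous = f(z)
        for i in order:
            z[i] = x[i]         # enable component i
            current = f(z)
            # Marginal contribution of component i, averaged over iterations.
            attributions[i] += (current - previous) / n_iter
            previous = current
    return attributions


def completeness_error(f, x, x_0, attributions):
    """Absolute gap between the sum of attributions and f(x) - f(x_0)."""
    return abs(sum(attributions) - (f(x) - f(x_0)))


# Toy model with an interaction term, so the ordering of components matters.
def model(z):
    return z[0] * z[1] + 2.0 * z[2]


x = [1.0, 2.0, 3.0]
x_0 = [0.0, 0.0, 0.0]           # zero baseline, analogous to x_0=None
attr = baseline_shapley(model, x, x_0, n_iter=8, seed=0)
error = completeness_error(model, x, x_0, attr)
```

In this sketch each random sequence telescopes from f(x_0) to f(x), so the completeness gap stays at floating-point level. A mean absolute error in the sense of check_error would average such gaps over inputs, baselines, and output components.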
- Returns:
ArrayLike : sampled inputs
ArrayLike : corresponding true outputs
ArrayLike : model predictions of sampled inputs
ArrayLike : model predictions of baselines
ArrayLike : BS attributions of shape (n_x, n_y_idx, *unbatched x shape).
- Return type:
tuple
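As a rough aid for reading the attribution shape, the sketch below composes the expected dimensions. The helper attribution_shape is hypothetical and not part of igc. It assumes that, for a categorical input, the number of embedding categories is appended as a trailing dimension, following the class-level note.

```python
def attribution_shape(n_x, n_y_idx, x_shape, n_embedding_categories=None):
    """Expected attribution shape: (n_x, n_y_idx, *unbatched x shape).

    Assumption: for a categorical input, the number of embedding categories
    is appended as a trailing dimension.
    """
    shape = (n_x, n_y_idx, *x_shape)
    if n_embedding_categories is not None:
        shape += (n_embedding_categories,)
    return shape


# 4 inputs, 2 output components, unbatched inputs of shape (3, 5).
numeric = attribution_shape(n_x=4, n_y_idx=2, x_shape=(3, 5))

# Same counts for a categorical input of shape (3,) with 10 categories.
categorical = attribution_shape(4, 2, (3,), n_embedding_categories=10)
```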