omnixai.explainers.vision.counterfactual package

omnixai.explainers.vision.counterfactual.ce module

The counterfactual explainer for image classification.

class omnixai.explainers.vision.counterfactual.ce.CounterfactualExplainer(model, preprocess_function, mode='classification', c=10.0, kappa=10.0, binary_search_steps=5, learning_rate=0.01, num_iterations=100, grad_clip=1000.0, **kwargs)

Bases: ExplainerBase

The counterfactual explainer for image classification. If using this explainer, please cite the paper "Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR" by Sandra Wachter, Brent Mittelstadt, and Chris Russell, https://arxiv.org/abs/1711.00399.

Parameters
  • model – The classification model, which can be a torch.nn.Module or a tf.keras.Model.

  • preprocess_function (Callable) – The preprocessing function that converts the raw data into the inputs of the model.

  • mode (str) – The task type; only 'classification' is supported.

  • c – The weight of the hinge loss term.

  • kappa – The margin parameter in the hinge loss function.

  • binary_search_steps – The number of binary-search steps for adjusting the weight c of the hinge loss term.

  • learning_rate – The learning rate.

  • num_iterations – The maximum number of iterations during optimization.

  • grad_clip – The value for clipping gradients.

  • kwargs – Not used.

explanation_type = 'local'
alias = ['ce', 'counterfactual']
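
Below is a minimal sketch of constructing the explainer with a PyTorch classifier. The torchvision model, the transform, and the preprocessing function are illustrative assumptions (the preprocessing follows the pattern in OmniXAI's tutorials, converting each Image to PIL via to_pil()), not part of this class's API.

    import torch
    from torchvision import models, transforms
    from omnixai.explainers.vision import CounterfactualExplainer

    # Any classifier works here: a torch.nn.Module or a tf.keras.Model.
    model = models.resnet50(pretrained=True).eval()

    # Converts a batch of raw Image objects into model inputs. ToTensor
    # rescales pixels from [0, 255] to [0, 1], so the images returned by
    # explain() will also lie in [0, 1].
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    preprocess = lambda batch: torch.stack([transform(img.to_pil()) for img in batch])

    explainer = CounterfactualExplainer(
        model=model,
        preprocess_function=preprocess,
        c=10.0,
        kappa=10.0,
        binary_search_steps=5,
        learning_rate=0.01,
        num_iterations=100,
    )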
explain(X, **kwargs)

Generates the counterfactual explanations for the input images. Note that the returned results, including the original input images and the counterfactual images, have been processed by preprocess_function; e.g., if preprocess_function rescales [0, 255] to [0, 1], the returned results will be in the range [0, 1].

Parameters
  • X (Image) – A batch of input images.

Return type
  CFExplanation

Returns
  The counterfactual explanations for all the input images, e.g., the counterfactual images.
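
A usage sketch for explain is shown below; the test image file and the plotting call are assumptions following the common OmniXAI workflow (explanation objects in OmniXAI typically provide ipython_plot for notebook display).

    from PIL import Image as PILImage
    from omnixai.data.image import Image

    # Wrap a raw image in OmniXAI's Image container (a batch of one).
    img = Image(PILImage.open("test.jpg").convert("RGB"))

    # Both the original and the counterfactual images in the result have
    # been processed by preprocess_function, so their pixel ranges match
    # its output (here, [0, 1]).
    explanations = explainer.explain(img)

    # Display the first explanation, i.e., the original image next to
    # its counterfactual.
    explanations.ipython_plot(index=0)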