omnixai.explainers.vision.counterfactual package
omnixai.explainers.vision.counterfactual.ce module
The counterfactual explainer for image classification.
- class omnixai.explainers.vision.counterfactual.ce.CounterfactualExplainer(model, preprocess_function, mode='classification', c=10.0, kappa=10.0, binary_search_steps=5, learning_rate=0.01, num_iterations=100, grad_clip=1000.0, **kwargs)
Bases: ExplainerBase
The counterfactual explainer for image classification. If using this explainer, please cite the paper "Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR" by Sandra Wachter, Brent Mittelstadt, and Chris Russell, https://arxiv.org/abs/1711.00399.
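For reference, the cited paper formulates the search for a counterfactual x' of an input x_i, given a classifier f_w, a desired output y', and a distance function d, as the objective below. Judging from the parameter list (c weights a hinge loss term, kappa parameterizes that hinge loss), this explainer appears to replace the squared prediction term with a weighted hinge loss; that reading is an assumption, not a description of the implementation.

```latex
\arg\min_{x'} \max_{\lambda} \; \lambda \left( f_w(x') - y' \right)^2 + d(x_i, x')
```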
- Parameters
model – The classification model which can be torch.nn.Module or tf.keras.Model.
preprocess_function (Callable) – The preprocessing function that converts the raw data into the inputs of model.
mode (str) – The task type. Only "classification" is supported.
c – The weight of the hinge loss term.
kappa – The parameter in the hinge loss function.
binary_search_steps – The number of binary search steps used to adjust c, the weight of the hinge loss term.
learning_rate – The learning rate.
num_iterations – The maximum number of iterations during optimization.
grad_clip – The value for clipping gradients.
kwargs – Not used.
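The parameters above can be illustrated with a minimal pure-Python sketch. It assumes a Carlini-Wagner-style objective c * max(z_y - max_{j != y} z_j, -kappa) + ||x' - x||^2, where z are the logits and y is the originally predicted class, minimized by gradient descent (learning_rate, num_iterations, grad_clip), with a binary search over c (binary_search_steps) that shrinks c after a successful flip and grows it after a failure. The objective, the search schedule, and the toy linear classifier standing in for model are all assumptions for illustration; this is not OmniXAI's implementation, and every function name here is hypothetical.

```python
from typing import List, Optional

Vector = List[float]


def logits(W: List[Vector], bias: Vector, x: Vector) -> Vector:
    # Hypothetical toy linear classifier standing in for `model`: z_k = W[k] . x + bias[k]
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) + b_k for row, b_k in zip(W, bias)]


def argmax(v: Vector) -> int:
    return max(range(len(v)), key=lambda k: v[k])


def squared_distance(a: Vector, b: Vector) -> float:
    return sum((a_i - b_i) ** 2 for a_i, b_i in zip(a, b))


def optimize(W, bias, x, label, c, kappa, learning_rate, num_iterations, grad_clip) -> Vector:
    """Gradient descent on c * max(z_label - max_{j != label} z_j, -kappa) + ||x' - x||^2."""
    x_cf = list(x)  # start the search at the original input
    for _ in range(num_iterations):
        z = logits(W, bias, x_cf)
        other = max((k for k in range(len(z)) if k != label), key=lambda k: z[k])
        margin = z[label] - z[other]
        # The hinge term has zero gradient once the margin has reached -kappa
        if margin > -kappa:
            hinge_grad = [c * (wl - wo) for wl, wo in zip(W[label], W[other])]
        else:
            hinge_grad = [0.0] * len(x)
        # Add the gradient of the squared L2 distance term
        grad = [h + 2.0 * (xc - x0) for h, xc, x0 in zip(hinge_grad, x_cf, x)]
        # Clip each gradient component to [-grad_clip, grad_clip]
        grad = [max(-grad_clip, min(grad_clip, g)) for g in grad]
        x_cf = [xc - learning_rate * g for xc, g in zip(x_cf, grad)]
    return x_cf


def counterfactual_search(W, bias, x, c=10.0, kappa=10.0, binary_search_steps=5,
                          learning_rate=0.01, num_iterations=100,
                          grad_clip=1000.0) -> Optional[Vector]:
    label = argmax(logits(W, bias, x))
    lower, upper = 0.0, float("inf")
    best: Optional[Vector] = None
    best_dist = float("inf")
    for _ in range(binary_search_steps):
        x_cf = optimize(W, bias, x, label, c, kappa, learning_rate, num_iterations, grad_clip)
        if argmax(logits(W, bias, x_cf)) != label:
            # Prediction flipped: keep the closest counterfactual, then try a smaller weight
            dist = squared_distance(x_cf, x)
            if dist < best_dist:
                best, best_dist = x_cf, dist
            upper = min(upper, c)
        else:
            # Prediction unchanged: the hinge term needs more weight
            lower = max(lower, c)
        c = (lower + upper) / 2 if upper < float("inf") else c * 10
    return best
```

For example, with W = [[1.0, 0.0], [0.0, 1.0]], bias = [0.0, 0.0] and x = [2.0, 1.0] (predicted class 0), counterfactual_search returns a point predicted as class 1. Because the binary search keeps lowering c while the flip still succeeds, the returned point is much closer to x than the one found with the initial c = 10.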
- explanation_type = 'local'
- alias = ['ce', 'counterfactual']
- explain(X, **kwargs)
Generates the counterfactual explanations for the input images. Note that the returned results, including both the original input images and the counterfactual images, have been processed by preprocess_function. For example, if preprocess_function rescales pixel values from [0, 255] to [0, 1], the returned images will be in the range [0, 1].
- Parameters
X (Image) – A batch of the input images.
- Return type
- Returns
The counterfactual explanations for all the images, e.g., counterfactual images.
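Because explain returns images in the preprocessed space, they may need to be mapped back before display. The sketch below assumes a hypothetical preprocess_function that only rescales [0, 255] to [0, 1], and represents grayscale images as nested Python lists rather than the library's Image class. It inverts that rescaling and marks which pixels the counterfactual changed; the helper names and the 0.05 threshold are illustrative assumptions.

```python
from typing import List

Pixels = List[List[float]]


def preprocess(raw: List[List[int]]) -> Pixels:
    # Hypothetical preprocess_function: rescale 8-bit pixel values from [0, 255] to [0, 1]
    return [[p / 255.0 for p in row] for row in raw]


def to_display(pixels: Pixels) -> List[List[int]]:
    # Invert the rescaling above and clamp, since counterfactual pixels may fall slightly outside [0, 1]
    return [[max(0, min(255, round(p * 255.0))) for p in row] for row in pixels]


def change_mask(original: Pixels, counterfactual: Pixels,
                threshold: float = 0.05) -> List[List[int]]:
    # 1 marks a pixel whose value changed by more than `threshold` in the counterfactual
    return [[int(abs(a - b) > threshold) for a, b in zip(r0, r1)]
            for r0, r1 in zip(original, counterfactual)]
```

If preprocess_function also normalizes or reorders channels, to_display would have to undo those steps as well.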