omnixai.explainers.prediction package
- class omnixai.explainers.prediction.PredictionAnalyzer(mode, test_data, test_targets, model=None, preprocess=None, postprocess=None, predict_function=None, **kwargs)
Bases: ExplainerBase
Analyzes the prediction results of a classification or regression model, for example:
    analyzer = PredictionAnalyzer(
        mode="classification",          # The task type, e.g., "classification" or "regression"
        test_data=test_data,            # The test dataset (a `Tabular` instance)
        test_targets=test_labels,       # The test labels (a numpy array)
        model=model,                    # The ML model
        preprocess=preprocess_function  # Converts raw features into the model inputs
    )
    prediction_explanations = analyzer.explain()
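The example above passes `model` and `preprocess`. Going by the parameter descriptions, raw features flow through `preprocess`, then `model`, then `postprocess`, unless `predict_function` is given, in which case `model` is ignored. The sketch below models that precedence with plain Python callables. `build_predictor`, `scale`, `toy_model` and `to_probs` are hypothetical illustrations, not OmniXAI's internal implementation. It is also an assumption that `predict_function` consumes raw features directly, so `preprocess` and `postprocess` are skipped as well.

```python
from typing import Any, Callable, Optional


def build_predictor(
    model: Optional[Callable[[Any], Any]] = None,
    preprocess: Optional[Callable[[Any], Any]] = None,
    postprocess: Optional[Callable[[Any], Any]] = None,
    predict_function: Optional[Callable[[Any], Any]] = None,
) -> Callable[[Any], Any]:
    """Hypothetical helper: return one callable mapping raw features to outputs."""
    if predict_function is not None:
        # Documented rule: predict_function takes precedence and model is ignored.
        # Assumption: it consumes raw features, so preprocess/postprocess are skipped too.
        return predict_function
    if model is None:
        raise ValueError("Either model or predict_function must be provided.")

    def predict(raw: Any) -> Any:
        x = preprocess(raw) if preprocess is not None else raw  # raw features -> model inputs
        y = model(x)
        return postprocess(y) if postprocess is not None else y  # model outputs -> user form

    return predict


# Hypothetical toy pipeline for a binary classifier.
def scale(rows):
    """preprocess: divide each raw feature by 10."""
    return [[v / 10.0 for v in row] for row in rows]


def toy_model(inputs):
    """model: one score per row (the sum of its inputs)."""
    return [sum(row) for row in inputs]


def to_probs(scores):
    """postprocess: clip each score to [0, 1] and emit [P(class 0), P(class 1)]."""
    probs = []
    for s in scores:
        p1 = min(max(s, 0.0), 1.0)
        probs.append([1.0 - p1, p1])
    return probs


predictor = build_predictor(model=toy_model, preprocess=scale, postprocess=to_probs)
print(predictor([[2, 3], [9, 9]]))
```

Returning a single callable keeps the precedence decision in one place, so downstream metric code never needs to know which of the two configurations was used.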
- Parameters
  mode (str) – The task type, either "classification" or "regression".
  test_data – The test data, containing the raw features of the test instances. If test_data is a Tabular with a target/label column, this column is ignored, because its values are raw labels that have not been processed by a LabelEncoder.
  test_targets – The test labels or targets, used to compute metrics and curves. For classification, test_targets should be integers (processed by a LabelEncoder) that index the columns of the class probabilities returned by the ML model.
  model (Optional[Any]) – The machine learning model to analyze, which can be a scikit-learn model, a TensorFlow model, a PyTorch model, or a black-box prediction function.
  preprocess (Optional[Callable]) – The preprocessing function that converts the raw input features into the inputs of model.
  postprocess (Optional[Callable]) – The postprocessing function that transforms the outputs of model into a user-specific form, e.g., the predicted class probabilities.
  predict_function (Optional[Callable]) – The prediction function corresponding to the ML model, whose outputs are the class probabilities. If predict_function is not None, PredictionAnalyzer ignores model and uses only predict_function to generate prediction results.
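The requirement that classification targets be LabelEncoder integers matching the class probabilities can be illustrated with a minimal sketch. It assumes that probability column j corresponds to encoded class j, which holds for scikit-learn classifiers trained on LabelEncoder outputs because both order classes by sorted value. `encode_labels` and `accuracy` are hypothetical helpers, and accuracy is only an example of a metric. The metrics and curves PredictionAnalyzer actually reports may differ.

```python
from typing import Dict, List, Sequence, Tuple


def encode_labels(raw_labels: Sequence[str]) -> Tuple[List[int], List[str]]:
    """Map raw labels to integer ids, mirroring scikit-learn's LabelEncoder,
    which assigns ids in sorted order of the unique classes."""
    classes = sorted(set(raw_labels))
    index: Dict[str, int] = {c: i for i, c in enumerate(classes)}
    return [index[c] for c in raw_labels], classes


def accuracy(probs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of rows whose highest-probability column equals the target id.
    Only meaningful if column j of `probs` is the probability of class id j."""
    if len(probs) != len(targets):
        raise ValueError("probs and targets must have the same length")
    correct = 0
    for row, t in zip(probs, targets):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax over columns
        correct += int(predicted == t)
    return correct / len(targets)


raw = ["cat", "dog", "cat", "bird"]
targets, classes = encode_labels(raw)
# classes == ["bird", "cat", "dog"], so column 0 = bird, 1 = cat, 2 = dog
probs = [
    [0.1, 0.7, 0.2],  # predicts cat
    [0.2, 0.3, 0.5],  # predicts dog
    [0.6, 0.3, 0.1],  # predicts bird (wrong: the true label is cat)
    [0.8, 0.1, 0.1],  # predicts bird
]
print(classes, targets, accuracy(probs, targets))
```

Passing the raw strings, or integers encoded in a different order, would silently compare each prediction against the wrong probability column. This is why the targets must be encoded consistently with the model's outputs.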
- explanation_type = 'prediction'
- alias = ['prediction']
- explain(**kwargs)
  Return type: Dict