omnixai.explainers.prediction package

class omnixai.explainers.prediction.PredictionAnalyzer(mode, test_data, test_targets, model=None, preprocess=None, postprocess=None, predict_function=None, **kwargs)

Bases: ExplainerBase

Analyzes the prediction results of a classification or regression model. For example:

from omnixai.explainers.prediction import PredictionAnalyzer

analyzer = PredictionAnalyzer(
    mode="classification",            # The task type: "classification" or "regression"
    test_data=test_data,              # The test dataset (a `Tabular` instance)
    test_targets=test_labels,         # The encoded test labels (a numpy array of ints)
    model=model,                      # The ML model
    preprocess=preprocess_function    # Converts raw features into the model inputs
)
prediction_explanations = analyzer.explain()
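
In the example above, raw test features pass through preprocess, then the model, then postprocess; when a predict_function is supplied, it is used on its own and model is ignored (see predict_function under Parameters). The sketch below illustrates that contract in plain Python. It is a minimal sketch under stated assumptions, not OmniXAI's internal implementation: resolve_predict, scale, toy_model and to_probs are hypothetical names, and the model is assumed to be a plain callable.

```python
from typing import Any, Callable, Optional


def resolve_predict(
    model: Optional[Callable[[Any], Any]] = None,
    preprocess: Optional[Callable[[Any], Any]] = None,
    postprocess: Optional[Callable[[Any], Any]] = None,
    predict_function: Optional[Callable[[Any], Any]] = None,
) -> Callable[[Any], Any]:
    """Hypothetical helper: return a function mapping raw test data to predictions."""
    # A supplied predict_function takes precedence; model and the hooks are ignored.
    if predict_function is not None:
        return predict_function
    if model is None:
        raise ValueError("either model or predict_function must be given")

    def predict(raw_data: Any) -> Any:
        # raw features -> model inputs -> model outputs -> user-specified form
        inputs = preprocess(raw_data) if preprocess is not None else raw_data
        outputs = model(inputs)
        return postprocess(outputs) if postprocess is not None else outputs

    return predict


# Toy stand-ins (hypothetical): each raw row is a list of numbers.
def scale(rows):
    return [[v / 8 for v in row] for row in rows]


def toy_model(rows):
    return [sum(row) for row in rows]  # one positive-class score per row


def to_probs(scores):
    return [[1 - s, s] for s in scores]  # two-column class probabilities


predict = resolve_predict(model=toy_model, preprocess=scale, postprocess=to_probs)
print(predict([[1, 1], [2, 4]]))  # [[0.75, 0.25], [0.25, 0.75]]
```

Passing predict_function=to_probs together with model=toy_model would return to_probs unchanged, mirroring the documented rule that model is ignored in that case.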

Parameters
  • mode (str) – The task type, either "classification" or "regression".

  • test_data – The test data, containing the raw features of the test instances. If test_data is a Tabular with a target/label column, that column is ignored because it holds raw labels that have not been processed by a LabelEncoder; pass the encoded labels through test_targets instead.

  • test_targets – The test labels or targets, which are used to compute metrics and curves. For classification, test_targets should be integer labels (processed by a LabelEncoder) that match the columns of the class probabilities returned by the ML model.

  • model (Optional[Any]) – The machine learning model to analyze, which can be a scikit-learn model, a TensorFlow model, a PyTorch model, or a black-box prediction function.

  • preprocess (Optional[Callable]) – The preprocessing function that converts the raw input features into the inputs of model.

  • postprocess (Optional[Callable]) – The postprocessing function that transforms the outputs of model into a user-specified form, e.g., the predicted class probabilities.

  • predict_function (Optional[Callable]) – The prediction function corresponding to the ML model. For classification, its outputs are the class probabilities. If predict_function is not None, PredictionAnalyzer ignores model and uses only predict_function to generate the prediction results.
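
For classification, the integer values in test_targets must index the columns of the model's class probabilities. A minimal sketch with NumPy, assuming hypothetical string labels and hypothetical probabilities whose columns follow the sorted class order: np.unique(..., return_inverse=True) yields the same sorted-class encoding as scikit-learn's LabelEncoder, so the codes can index the probability matrix directly.

```python
import numpy as np

# Raw labels as they might appear in a Tabular target column (hypothetical data).
raw_labels = np.array(["spam", "ham", "spam", "eggs"])

# np.unique sorts the classes and returns each label's index into that sorted
# array, the same encoding scikit-learn's LabelEncoder produces.
classes, test_targets = np.unique(raw_labels, return_inverse=True)
print(classes)       # ['eggs' 'ham' 'spam']
print(test_targets)  # [2 1 2 0]

# Hypothetical class probabilities; column j must correspond to classes[j].
probs = np.array([
    [0.1, 0.2, 0.7],
    [0.2, 0.6, 0.2],
    [0.5, 0.3, 0.2],
    [0.8, 0.1, 0.1],
])

# Because the codes match the columns, they can index the probabilities directly.
true_class_probs = probs[np.arange(len(test_targets)), test_targets]
accuracy = float((probs.argmax(axis=1) == test_targets).mean())
print(true_class_probs)  # [0.7 0.6 0.2 0.8]
print(accuracy)          # 0.75
```

If the model ordered its probability columns differently, the same codes would point at the wrong probabilities, which is why the encoded targets must match the model's output.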

explanation_type = 'prediction'
alias = ['prediction']
explain(**kwargs)
Return type

Dict
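
explain() returns its results as a Dict whose keys are not listed here. As an illustration only of the kind of classification metric that can be computed from test_targets and the predicted class probabilities, the hypothetical helper below builds a confusion matrix with NumPy. It is a sketch, not OmniXAI's implementation, and it does not reproduce the format of the returned Dict.

```python
import numpy as np


def confusion_matrix(test_targets, probs):
    """Hypothetical helper: entry [i, j] counts instances of true class i predicted as j."""
    targets = np.asarray(test_targets, dtype=int)
    probs = np.asarray(probs, dtype=float)
    num_classes = probs.shape[1]
    predicted = probs.argmax(axis=1)
    # Map each (true, predicted) pair to a flat index, then count occurrences.
    flat = targets * num_classes + predicted
    counts = np.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


# Hypothetical probabilities; columns follow the encoded class order.
probs = np.array([
    [0.1, 0.2, 0.7],
    [0.2, 0.6, 0.2],
    [0.5, 0.3, 0.2],
    [0.8, 0.1, 0.1],
])
print(confusion_matrix([2, 1, 2, 0], probs))
# [[1 0 0]
#  [0 1 0]
#  [1 0 1]]
```

The flat-index trick keeps the sketch dependency-free beyond NumPy; scikit-learn's own confusion_matrix takes predicted labels rather than probabilities.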