4. Prediction Performance

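This script evaluates the prediction performance of LGBMClassifier on four targets (MCR-1, OXA48, TEM, and CTX-M) using ROC curves, precision-recall curves, and confusion matrices.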

import numpy as np
import matplotlib.pyplot as plt

from utils import SAVE
from utils import version_info
from utils import set_rcParams
from utils import return_train_test
from utils import get_no_gene_model
from ai4water.postprocessing import ProcessPredictions
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay
version_info()
python 3.7.17 (default, Jun 25 2023, 21:20:48)
[GCC 11.3.0]
os posix
ai4water 1.07
lightgbm 4.0.0
easy_mpl 0.21.4
SeqMetrics 1.3.4
tensorflow 1.15.0
tensorflow.python.keras.api._v1.keras 2.2.4-tf
numpy 1.19.5
pandas 1.3.5
matplotlib 3.5.3
h5py 2.10.0
sklearn 1.0.2
optuna 3.2.0
skopt 0.9.0
seaborn 0.12.2
shap 0.41.0
set_rcParams()

roc_func = RocCurveDisplay.from_estimator
pr_func = PrecisionRecallDisplay.from_estimator

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2,
                               figsize=(9,9))

for target in ["MCR-1", "OXA48", "TEM", "CTX-M"]:

    TrainX, TrainY, TestX, TestY, inputs = return_train_test(target, 'no_genes')

    model = get_no_gene_model(target, True, 'LGBMClassifier')

    # keyword arguments for the training-data curves
    tr_kws = {'estimator': model,
              'X': TrainX,
              'y': TrainY.values,
              'ax': ax1,
              'name': target
              }

    # top-left panel: ROC curve on the training data
    roc_func(**tr_kws)
    ax1.grid(ls='--', color='lightgrey')
    ax1.set_title("Training")

    # keyword arguments for the test-data curves
    test_kws = {'estimator': model,
                'X': TestX,
                'y': TestY.values,
                'ax': ax2,
                'name': target
                }

    # top-right panel: ROC curve on the test data
    roc_func(**test_kws)

    ax2.grid(ls='--', color='lightgrey')
    ax2.set_title("Test")
    ax2.set_ylabel('')

    # bottom-left panel: precision-recall curve on the training data
    tr_kws['ax'] = ax3
    pr_func(**tr_kws)
    ax3.grid(ls='--', color='lightgrey')

    # bottom-right panel: precision-recall curve on the test data
    test_kws['ax'] = ax4
    pr_func(**test_kws)
    ax4.set_ylabel('')
    ax4.grid(ls='--', color='lightgrey')

if SAVE:
    plt.savefig("results/figures/roc_pr.png", dpi=600, bbox_inches="tight")

plt.show()
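Figure: ROC curves (top row) and precision-recall curves (bottom row) on the training (left) and test (right) data.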
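
The ROC panels plot the true-positive rate against the false-positive rate as the decision threshold is lowered, and the area under that curve summarizes it as a single number. As a rough illustration of what is being drawn, the following pure-Python sketch builds the curve for a handful of made-up labels and scores. It is not scikit-learn's implementation and does not use the models from this script; `roc_points` and `auc` are hypothetical helper names, and labels are assumed to be coded as 0 and 1.

```python
def roc_points(y_true, y_score):
    """Return (fpr, tpr) lists obtained by sweeping the decision threshold."""
    n_pos = sum(y_true)
    n_neg = len(y_true) - n_pos
    # Sort samples by decreasing score so that each step lowers the threshold.
    order = sorted(range(len(y_true)), key=lambda i: -y_score[i])
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for rank, i in enumerate(order):
        if y_true[i] == 1:
            tp += 1
        else:
            fp += 1
        # Emit a point only after the last sample of a group of tied scores.
        is_last = rank == len(order) - 1
        if is_last or y_score[order[rank + 1]] != y_score[i]:
            fpr.append(fp / n_neg)
            tpr.append(tp / n_pos)
    return fpr, tpr


def auc(x, y):
    """Area under a piecewise-linear curve, using the trapezoidal rule."""
    return sum((x[k] - x[k - 1]) * (y[k] + y[k - 1]) / 2
               for k in range(1, len(x)))


# Made-up labels and scores, for illustration only.
y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]

fpr, tpr = roc_points(y_true, y_score)
roc_auc = auc(fpr, tpr)
print("FPR:", fpr)
print("TPR:", tpr)
print("AUC:", roc_auc)
```

An area of 1.0 corresponds to perfect ranking of positives above negatives, and 0.5 to random ranking, which is why comparing the training and test panels side by side is a quick check for overfitting.
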
A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().
[LightGBM] [Warning] Found boosting=goss. For backwards compatibility reasons, LightGBM interprets this as boosting=gbdt, data_sample_strategy=goss.To suppress this warning, set data_sample_strategy=goss instead.
fig, axes = plt.subplots(2, 2, sharex="all",
                         sharey="all")

targets = ["MCR-1", "OXA48", "TEM", "CTX-M"]
# enumerate gives the panel index: 0 and 1 are the top row, 1 and 3 the right column
for idx, (target, ax) in enumerate(zip(targets, axes.flat)):

    TrainX, TrainY, TestX, TestY, inputs = return_train_test(target, 'no_genes')

    # evaluate on the training and test samples combined
    X = np.concatenate([TrainX.values, TestX.values])
    y = np.concatenate([TrainY.values, TestY.values])

    model = get_no_gene_model(target, True, 'LGBMClassifier')

    pred = model.predict(X)

    processor = ProcessPredictions('classification',
                                   show=False,
                                   save=False)

    im = processor.confusion_matrix(
        y, pred,
        ax=ax,
        cbar_params={"border": False},
        annotate_kws={'fontsize': 20, "fmt": '%.f', 'ha': "center"})
    ax_ = im.axes

    ax_.set_title(target)

    # right-hand column: drop the redundant y-axis label
    if idx in [1, 3]:
        ax_.set_ylabel('')

    # top row: drop the redundant x-axis label
    if idx in [0, 1]:
        ax_.set_xlabel('')

if SAVE:
    plt.savefig("results/figures/confusion_metrics.png", dpi=600, bbox_inches="tight")
plt.show()
Figure: confusion matrices for MCR-1, OXA48, TEM, and CTX-M.
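
Each confusion matrix above is computed on the training and test samples combined, and its four cells are the counts from which precision, recall, and accuracy follow. A minimal pure-Python sketch of that bookkeeping is given below. The labels and predictions are made up, `confusion_counts` is a hypothetical helper (not part of ai4water), and the classes are assumed to be coded as 0 and 1.

```python
def confusion_counts(y_true, y_pred):
    """Count (tn, fp, fn, tp) for binary labels coded as 0 and 1."""
    tn = fp = fn = tp = 0
    for t, p in zip(y_true, y_pred):
        if t == 1 and p == 1:
            tp += 1      # positive, predicted positive
        elif t == 1:
            fn += 1      # positive, predicted negative
        elif p == 1:
            fp += 1      # negative, predicted positive
        else:
            tn += 1      # negative, predicted negative
    return tn, fp, fn, tp


# Made-up labels and predictions, for illustration only.
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 0, 1, 1, 0]

tn, fp, fn, tp = confusion_counts(y_true, y_pred)
precision = tp / (tp + fp)          # share of predicted positives that are correct
recall = tp / (tp + fn)             # share of actual positives that are found
accuracy = (tp + tn) / len(y_true)  # share of all samples classified correctly

print("tn, fp, fn, tp:", tn, fp, fn, tp)
print("precision:", precision, "recall:", recall, "accuracy:", accuracy)
```
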
A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().
[LightGBM] [Warning] Found boosting=goss. For backwards compatibility reasons, LightGBM interprets this as boosting=gbdt, data_sample_strategy=goss.To suppress this warning, set data_sample_strategy=goss instead.
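
The LightGBM warning above states that `boosting=goss` is now interpreted as `boosting=gbdt` with `data_sample_strategy=goss`, and that setting the latter directly suppresses the message. How the model parameters are stored inside `get_no_gene_model` is not shown here, so the sketch below simply assumes they are available as a plain dictionary; `migrate_goss_params` is a hypothetical helper and the example values are illustrative.

```python
def migrate_goss_params(params):
    """Return a copy of `params` using the newer spelling for GOSS sampling.

    Assumes `params` is a plain dict of LightGBM parameters; the original
    dictionary is left unchanged.
    """
    new = dict(params)
    # `boosting_type` is the scikit-learn wrapper's name for `boosting`.
    for key in ("boosting", "boosting_type"):
        if new.get(key) == "goss":
            new[key] = "gbdt"
            new["data_sample_strategy"] = "goss"
    return new


# Illustrative parameters only; not the values used in this study.
old_params = {"boosting_type": "goss", "n_estimators": 100}
new_params = migrate_goss_params(old_params)
print(new_params)
```

The migrated dictionary could then be passed as keyword arguments when constructing the classifier, provided the installed LightGBM version recognizes `data_sample_strategy` (the warning indicates that version 4.0.0 does).
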

Total running time of the script: ( 0 minutes 10.007 seconds)
