4. Prediction Performance
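This example evaluates the prediction performance of LGBMClassifier for the four targets (MCR-1, OXA48, TEM and CTX-M) using ROC curves, precision-recall curves and confusion matrices.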
import numpy as np
import matplotlib.pyplot as plt
from utils import SAVE
from utils import version_info
from utils import set_rcParams
from utils import return_train_test
from utils import get_no_gene_model
from ai4water.postprocessing import ProcessPredictions
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay
version_info()
python 3.7.17 (default, Jun 25 2023, 21:20:48)
[GCC 11.3.0]
os posix
ai4water 1.07
lightgbm 4.0.0
easy_mpl 0.21.4
SeqMetrics 1.3.4
tensorflow 1.15.0
tensorflow.python.keras.api._v1.keras 2.2.4-tf
numpy 1.19.5
pandas 1.3.5
matplotlib 3.5.3
h5py 2.10.0
sklearn 1.0.2
optuna 3.2.0
skopt 0.9.0
seaborn 0.12.2
shap 0.41.0
set_rcParams()
roc_func = RocCurveDisplay.from_estimator
pr_func = PrecisionRecallDisplay.from_estimator
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2,
                                             figsize=(9, 9))

for target in ["MCR-1", "OXA48", "TEM", "CTX-M"]:

    TrainX, TrainY, TestX, TestY, inputs = return_train_test(target, 'no_genes')

    model = get_no_gene_model(target, True, 'LGBMClassifier')

    # ROC curve on the training data (top-left panel)
    tr_kws = {'estimator': model,
              'X': TrainX,
              'y': TrainY.values,
              'ax': ax1,
              'name': target
              }
    roc_func(**tr_kws)
    ax1.grid(ls='--', color='lightgrey')
    ax1.set_title("Training")

    # ROC curve on the test data (top-right panel)
    test_kws = {'estimator': model,
                'X': TestX,
                'y': TestY.values,
                'ax': ax2,
                'name': target
                }
    roc_func(**test_kws)
    ax2.grid(ls='--', color='lightgrey')
    ax2.set_title("Test")
    ax2.set_ylabel('')

    # precision-recall curve on the training data (bottom-left panel)
    tr_kws['ax'] = ax3
    pr_func(**tr_kws)
    ax3.grid(ls='--', color='lightgrey')

    # precision-recall curve on the test data (bottom-right panel)
    test_kws['ax'] = ax4
    pr_func(**test_kws)
    ax4.set_ylabel('')
    ax4.grid(ls='--', color='lightgrey')

if SAVE:
    plt.savefig("results/figures/roc_pr.png", dpi=600, bbox_inches="tight")
plt.show()

A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().
[LightGBM] [Warning] Found boosting=goss. For backwards compatibility reasons, LightGBM interprets this as boosting=gbdt, data_sample_strategy=goss.To suppress this warning, set data_sample_strategy=goss instead.
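The curves above come from scikit-learn's display helpers. To make clear what the panels show, the following is a minimal pure-Python sketch of the underlying idea: sweep the decision threshold over the predicted scores and record the false positive rate, the true positive rate (recall) and the precision at each step. The function names `threshold_sweep` and `roc_auc` and the toy labels and scores are hypothetical and used only for illustration; this is not scikit-learn's implementation, and it assumes binary 0/1 labels with both classes present.

```python
def threshold_sweep(y_true, scores):
    """Sweep the decision threshold from high to low and collect curve points.

    Assumes binary labels coded as 0/1 with at least one sample of each class.
    """
    n_pos = sum(y_true)
    n_neg = len(y_true) - n_pos
    # Visit samples from the highest score to the lowest.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    points = []
    tp = fp = 0
    for rank, i in enumerate(order):
        if y_true[i] == 1:
            tp += 1
        else:
            fp += 1
        # Samples with tied scores share one threshold, so emit a point
        # only after the last sample of each tie group.
        is_last = rank == len(order) - 1
        if is_last or scores[order[rank + 1]] != scores[i]:
            points.append({
                "fpr": fp / n_neg,
                "tpr": tp / n_pos,  # identical to recall
                "precision": tp / (tp + fp),
            })
    return points


def roc_auc(points):
    """Area under the ROC curve by the trapezoidal rule, starting at (0, 0)."""
    xs = [0.0] + [p["fpr"] for p in points]
    ys = [0.0] + [p["tpr"] for p in points]
    return sum((xs[k] - xs[k - 1]) * (ys[k] + ys[k - 1]) / 2
               for k in range(1, len(xs)))


# Made-up labels and scores, for illustration only.
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]

points = threshold_sweep(y_true, scores)
for p in points:
    print(p)
print("ROC AUC:", roc_auc(points))
```

Plotting `tpr` against `fpr` gives a ROC curve, and plotting `precision` against `tpr` gives a precision-recall curve. A curve that stays close to the top-left corner (ROC) or the top-right corner (precision-recall) indicates better separation of the two classes.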
fig, axes = plt.subplots(2, 2, sharex="all",
                         sharey="all")

targets = ["MCR-1", "OXA48", "TEM", "CTX-M"]

for idx, (target, ax) in enumerate(zip(targets, axes.flat)):

    TrainX, TrainY, TestX, TestY, inputs = return_train_test(target, 'no_genes')

    # the confusion matrix is computed on training and test samples combined
    X = np.concatenate([TrainX.values, TestX.values])
    y = np.concatenate([TrainY.values, TestY.values])

    model = get_no_gene_model(target, True, 'LGBMClassifier')

    pred = model.predict(X)

    processor = ProcessPredictions('classification',
                                   show=False,
                                   save=False)

    im = processor.confusion_matrix(
        y, pred,
        ax=ax,
        cbar_params={"border": False},
        annotate_kws={'fontsize': 20, "fmt": '%.f', 'ha': "center"})

    ax_ = im.axes
    ax_.set_title(target)

    # hide the y-label on the right column and the x-label on the top row
    if idx in [1, 3]:
        ax_.set_ylabel('')
    if idx in [0, 1]:
        ax_.set_xlabel('')

if SAVE:
    plt.savefig("results/figures/confusion_metrics.png", dpi=600, bbox_inches="tight")
plt.show()

A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().
[LightGBM] [Warning] Found boosting=goss. For backwards compatibility reasons, LightGBM interprets this as boosting=gbdt, data_sample_strategy=goss.To suppress this warning, set data_sample_strategy=goss instead.
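Each panel above is a 2 x 2 table of counts. As a rough illustration of how such a table is built and read, the pure-Python sketch below tallies true versus predicted labels and derives a few summary scores. The helper name `confusion_matrix` and the toy labels are hypothetical; this is not the code that `ProcessPredictions.confusion_matrix` runs internally, and it assumes binary labels coded as 0/1.

```python
def confusion_matrix(y_true, y_pred, labels=(0, 1)):
    """Count samples per (true label, predicted label) pair.

    Rows correspond to the true label and columns to the predicted label,
    both in the order given by ``labels``.
    """
    index = {label: k for k, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


# Made-up labels, for illustration only.
y_true = [0, 0, 1, 1, 1, 0]
y_pred = [0, 1, 1, 1, 0, 0]

cm = confusion_matrix(y_true, y_pred)
(tn, fp), (fn, tp) = cm

accuracy = (tp + tn) / len(y_true)
precision = tp / (tp + fp)
recall = tp / (tp + fn)

print(cm)
print("accuracy:", accuracy, "precision:", precision, "recall:", recall)
```

Because the example above concatenates the training and test samples before predicting, the counts in its panels include samples the model was fitted on, so they should not be read as test-only performance.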
Total running time of the script: ( 0 minutes 10.007 seconds)