"""
=========================
4. Prediction Performance
=========================
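This example evaluates the prediction performance of the LGBMClassifier models built on the ``no_genes`` input set for four targets (MCR-1, OXA48, TEM and CTX-M). It first plots ROC and precision-recall curves for the training and test sets, and then shows confusion matrices computed on the combined training and test data.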

"""
import numpy as np
import matplotlib.pyplot as plt

from utils import SAVE
from utils import version_info
from utils import set_rcParams
from utils import return_train_test
from utils import get_no_gene_model
from ai4water.postprocessing import ProcessPredictions
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay

# %%
version_info()  # helper from utils; presumably reports installed package versions for reproducibility — confirm in utils

# %%

set_rcParams()

roc_func = RocCurveDisplay.from_estimator
pr_func = PrecisionRecallDisplay.from_estimator

targets = ["MCR-1", "OXA48", "TEM", "CTX-M"]

# Panel layout: top row = ROC curves, bottom row = precision-recall curves;
# left column = training data, right column = test data.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 9))

for target in targets:

    TrainX, TrainY, TestX, TestY, inputs = return_train_test(target, 'no_genes')

    model = get_no_gene_model(target, True, 'LGBMClassifier')

    # Arguments shared by the ROC and PR displays of one data split; only the
    # target axes differ, so ``ax`` is passed separately at each call.
    tr_kws = {'estimator': model, 'X': TrainX, 'y': TrainY.values,
              'name': target}
    test_kws = {'estimator': model, 'X': TestX, 'y': TestY.values,
                'name': target}

    # Each call adds one curve, labelled with ``target``, to its panel.
    roc_func(ax=ax1, **tr_kws)
    roc_func(ax=ax2, **test_kws)
    pr_func(ax=ax3, **tr_kws)
    pr_func(ax=ax4, **test_kws)

# Styling is applied once, after all curves are drawn. The sklearn displays
# (re)set their axis labels on every plot call, so the y-labels of the
# right-hand (test) panels are cleared only after the last call.
for ax in (ax1, ax2, ax3, ax4):
    ax.grid(ls='--', color='lightgrey')
ax1.set_title("Training")
ax2.set_title("Test")
ax2.set_ylabel('')
ax4.set_ylabel('')

if SAVE:
    plt.savefig("results/figures/roc_pr.png", dpi=600, bbox_inches="tight")

plt.show()

# %%

fig, axes = plt.subplots(2, 2, sharex="all",
                         sharey="all")
# Grid shape, used below to decide which panels keep their axis labels.
n_rows, n_cols = axes.shape

targets = ["MCR-1", "OXA48", "TEM", "CTX-M"]

# ``axes.flat`` walks the grid row by row, so ``divmod(idx, n_cols)`` below
# recovers each panel's (row, column) position.
for idx, (target, ax) in enumerate(zip(targets, axes.flat)):

    TrainX, TrainY, TestX, TestY, inputs = return_train_test(target, 'no_genes')

    # The confusion matrix is computed on all samples: the training and
    # test sets are stacked into a single array.
    X = np.concatenate([TrainX.values, TestX.values])
    y = np.concatenate([TrainY.values, TestY.values])

    model = get_no_gene_model(target, True, 'LGBMClassifier')

    pred = model.predict(X)

    processor = ProcessPredictions('classification',
                                   show=False,
                                   save=False)

    im = processor.confusion_matrix(
        y, pred,
        ax=ax,
        cbar_params={"border": False},
        annotate_kws={'fontsize': 20, "fmt": '%.f', 'ha': "center"})
    ax_ = im.axes

    ax_.set_title(target)

    # The axes are shared, so labels are kept only on the outer panels:
    # y-label on the first column, x-label on the bottom row.
    row, col = divmod(idx, n_cols)
    if col > 0:
        ax_.set_ylabel('')
    if row < n_rows - 1:
        ax_.set_xlabel('')

if SAVE:
    plt.savefig("results/figures/confusion_metrics.png", dpi=600, bbox_inches="tight")
plt.show()