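"""
===================
2. ML Experiments
===================

Compare the machine-learning classifiers of ``ai4water`` for the targets
CTX-M, OXA48, TEM and MCR-1 under three input scenarios: WQ and antibiotics
(``no_genes``), only WQ (``only_wq``) and only antibiotics
(``only_antibiotics``). For every combination, the test-set F1 score of each
model is plotted.
"""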

import os

import matplotlib.pyplot as plt
from SeqMetrics import ClassificationMetrics
from ai4water.experiments import MLClassificationExperiments

from utils import SAVE
from utils import version_info, return_train_test, set_rcParams

# %%
# Record the software environment via ``utils.version_info`` (presumably
# reports the versions of the packages used).

version_info()

# %%
# Apply the project's plotting configuration via ``utils.set_rcParams``
# (presumably matplotlib rcParams) before any figure below is created.

set_rcParams()

# %%

def f1_weighted(t, p) -> float:
    """Return the F1 score of predictions ``p`` against true values ``t``,
    computed with ``average="weighted"``; used as a custom monitor metric."""
    scorer = ClassificationMetrics(t, p)
    return scorer.f1_score(average="weighted")

def f1_micro(t, p) -> float:
    """Return the micro-averaged F1 score of predictions ``p`` against true
    values ``t``; used as a custom monitor metric."""
    averaging = "micro"
    return ClassificationMetrics(t, p).f1_score(average=averaging)

def precision_weighted(t, p) -> float:
    """Return the precision of predictions ``p`` against true values ``t``,
    computed with ``average="weighted"``; used as a custom monitor metric."""
    metrics = ClassificationMetrics(t, p)
    weighted_precision = metrics.precision(average="weighted")
    return weighted_precision

def precision_micro(t, p) -> float:
    """Return the micro-averaged precision of predictions ``p`` against true
    values ``t``; used as a custom monitor metric."""
    averaging = "micro"
    metrics = ClassificationMetrics(t, p)
    return metrics.precision(average=averaging)

def recall_weighted(t, p) -> float:
    """Return the recall of predictions ``p`` against true values ``t``,
    computed with ``average="weighted"``; used as a custom monitor metric."""
    metrics = ClassificationMetrics(t, p)
    weighted_recall = metrics.recall(average="weighted")
    return weighted_recall

def recall_micro(t, p) -> float:
    """Return the micro-averaged recall of predictions ``p`` against true
    values ``t``; used as a custom monitor metric."""
    averaging = "micro"
    return ClassificationMetrics(t, p).recall(average=averaging)

# %%
# Every section below runs the same pipeline for one ``target`` under the
# current ``scenario``: load the train/test split, fit all classifiers except
# those in ``EXCLUDED_MODELS`` while monitoring the metrics in ``MONITOR``,
# then plot the F1 score of each model on the test set. The pipeline is
# defined once here so that the sections only differ in their inputs.

# Metrics tracked for every fitted model (the callables are defined above).
MONITOR = ['f1_score', 'accuracy',
           'precision', 'recall',
           f1_micro, f1_weighted,
           precision_micro, precision_weighted,
           recall_micro, recall_weighted,
           ]

# Classifiers skipped in every experiment.
EXCLUDED_MODELS = ['LinearDiscriminantAnalysis',
                   'BaggingClassifier',
                   'HistGradientBoostingClassifier',
                   'DecisionTreeClassifier',
                   'GaussianProcessClassifier']


def run_experiment(target: str, scenario: str) -> MLClassificationExperiments:
    """Fit and compare the classifiers for one ``target`` under one ``scenario``.

    Loads the split with ``return_train_test(target, scenario)``, fits every
    classifier except ``EXCLUDED_MODELS``, then plots the test-set F1 score of
    each model. When ``SAVE`` is true, the figure is written to
    ``results/figures/exp_{target}_{scenario}``.

    Parameters
    ----------
    target : str
        Output feature to predict, e.g. ``'CTX-M'``.
    scenario : str
        Input scenario forwarded to ``return_train_test``, e.g. ``'no_genes'``.

    Returns
    -------
    MLClassificationExperiments
        The fitted experiments object, so it can be inspected after the run.
    """
    TrainX, TrainY, TestX, TestY, inputs = return_train_test(target, scenario)

    comparisons = MLClassificationExperiments(
        input_features=inputs,
        output_features=target,
        # Fresh list per experiment so no instance shares a mutable list.
        monitor=list(MONITOR),
        verbosity=0,
        show=False
    )

    comparisons.fit(x=TrainX.values,
                    y=TrainY.values,
                    exclude=list(EXCLUDED_MODELS))

    # NOTE(review): training uses ``.values`` while the test features are
    # passed as a DataFrame; confirm ai4water treats both forms identically.
    _ = comparisons.compare_errors(
        'f1_score',
        x=TestX,
        y=TestY.values,
        colors=['salmon', 'cadetblue'],
        label_bars=True,
        bar_label_kws={"color": 'black', 'label_type': 'center'},
        figsize=(10, 8))
    plt.tight_layout()
    if SAVE:
        # savefig does not create missing directories.
        os.makedirs("results/figures", exist_ok=True)
        plt.savefig(f"results/figures/exp_{target}_{scenario}", bbox_inches="tight", dpi=600)
    plt.show()
    return comparisons


# %%
# WQ and Antibiotics
# ==================
scenario = 'no_genes'

# %%
# CTX-M
# ------

target = 'CTX-M'
comparisons = run_experiment(target, scenario)

# %%
# OXA48
# -------

target = 'OXA48'
comparisons = run_experiment(target, scenario)

# %%
# TEM
# ------

target = 'TEM'
comparisons = run_experiment(target, scenario)

# %%
# MCR-1
# ------

target = 'MCR-1'
comparisons = run_experiment(target, scenario)

# %%
# Only WQ
# =========
scenario = 'only_wq'


# %%
# CTX-M
# ------

target = 'CTX-M'
comparisons = run_experiment(target, scenario)

# %%
# OXA48
# -------

target = 'OXA48'
comparisons = run_experiment(target, scenario)

# %%
# TEM
# ------

target = 'TEM'
comparisons = run_experiment(target, scenario)

# %%
# MCR-1
# ------

target = 'MCR-1'
comparisons = run_experiment(target, scenario)

# %%
# Only Antibiotics
# ================
scenario = 'only_antibiotics'


# %%
# CTX-M
# ------

target = 'CTX-M'
comparisons = run_experiment(target, scenario)

# %%
# OXA48
# -------

target = 'OXA48'
comparisons = run_experiment(target, scenario)

# %%
# TEM
# ------

target = 'TEM'
comparisons = run_experiment(target, scenario)

# %%
# MCR-1
# ------

target = 'MCR-1'
comparisons = run_experiment(target, scenario)