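"""
===================
3. Models
===================

Test-set macro F1 scores of the LGBMClassifier models for the CTX-M, OXA48,
TEM and MCR-1 targets under three input scenarios: water quality plus
antibiotics, only water quality, and only antibiotics.
"""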

import matplotlib.pyplot as plt

from SeqMetrics import ClassificationMetrics
from easy_mpl import bar_chart

from utils import return_train_test, \
                  no_gene_CTX_M, no_gene_OXA48, no_gene_TEM, \
                  no_gene_MCR_1, only_wq_CTX_M, \
                  only_wq_OXA48, only_wq_TEM, only_wq_MCR_1, \
                    only_antibiotics_CTX_M, \
                  only_antibiotics_OXA48, only_antibiotics_TEM, \
                    only_antibiotics_MCR_1

from utils import SAVE
from utils import version_info, set_rcParams

# %%

# Environment report from utils.version_info (presumably library versions).
version_info()

# %%

# Project-wide plot settings from utils.set_rcParams (presumably matplotlib
# rcParams), applied before any figure is drawn.
set_rcParams()

# %%

# Model identifier passed to every no_gene_*/only_wq_*/only_antibiotics_*
# helper in the sections below.
model_name = "LGBMClassifier"

# %%

def f1_score_macro(t, p, degenerate_score: float = 0.1) -> float:
    """Return the macro-averaged F1 score of predictions ``p`` against ``t``.

    If every prediction equals the first one (e.g. the model predicted the
    same class for the whole test set), ``degenerate_score`` is returned and
    SeqMetrics is not called.

    Parameters
    ----------
    t : array-like
        True labels.
    p : array-like
        Predicted labels. Plain lists are accepted as well as arrays.
    degenerate_score : float, optional
        Value returned for constant predictions. Defaults to 0.1, the value
        that was previously hard-coded.

    Returns
    -------
    float
        ``ClassificationMetrics(t, p).f1_score(average="macro")``, or
        ``degenerate_score`` for constant predictions.

    Raises
    ------
    ValueError
        If ``p`` is empty, since no score can be computed.
    """
    import numpy as np

    preds = np.asarray(p)
    if preds.size == 0:
        raise ValueError("f1_score_macro: received an empty prediction array")
    # Equality with the first prediction (element-wise for 1-D, row-wise for
    # 2-D), evaluated on an array so plain lists work too.
    if (preds == preds[0]).all():
        # NOTE(review): this sentinel is not the true macro F1 of a constant
        # predictor; presumably it flags/penalizes such models - confirm.
        return degenerate_score
    return ClassificationMetrics(t, p).f1_score(average="macro")

# %%

# Test-set macro F1 scores, one list per target gene. Each scenario section
# below appends one value to every list, so index 0/1/2 holds the score for
# scenario 1/2/3 respectively.
ctxm_f1_test, oxa48_f1_test, tem_f1_test, mcr1_f1_test = [], [], [], []

# %%
# Scenario 1: WQ plus Antibiotics
# ===============================

# Scenario key understood by return_train_test; per the heading above, the
# inputs are water-quality plus antibiotic features (presumably no genes).
scenario = 'no_genes'

# %%
# CTX-M
# -----

target = 'CTX-M'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = no_gene_CTX_M(model_name).predict(x=TestX)

# %%

# Score once and reuse it, so the stored and the printed value are the same
# number and the metric is not computed twice.
test_f1 = f1_score_macro(TestY.values, test_p)
ctxm_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')

# %%
# OXA48
# -----

target = 'OXA48'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = no_gene_OXA48(model_name).predict(x=TestX)

# %%

test_f1 = f1_score_macro(TestY.values, test_p)
oxa48_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')

# %%
# TEM
# ---

target = 'TEM'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = no_gene_TEM(model_name).predict(x=TestX)

# %%

test_f1 = f1_score_macro(TestY.values, test_p)
tem_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')

# %%
# MCR-1
# -----

target = 'MCR-1'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = no_gene_MCR_1(model_name).predict(x=TestX)

# %%

test_f1 = f1_score_macro(TestY.values, test_p)
mcr1_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')

# %%
# Scenario 2: Only Water Quality
# ===============================

# Scenario key understood by return_train_test; per the heading above, only
# water-quality features are used as inputs.
scenario = 'only_wq'

# %%
# CTX-M
# -----

target = 'CTX-M'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = only_wq_CTX_M(model_name).predict(x=TestX)

# %%

# Score once and reuse it, so the stored and the printed value are the same
# number and the metric is not computed twice.
test_f1 = f1_score_macro(TestY.values, test_p)
ctxm_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')

# %%
# OXA48
# -----

target = 'OXA48'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = only_wq_OXA48(model_name).predict(x=TestX)

# %%

test_f1 = f1_score_macro(TestY.values, test_p)
oxa48_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')

# %%
# TEM
# -----

target = 'TEM'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = only_wq_TEM(model_name).predict(x=TestX)

# %%

test_f1 = f1_score_macro(TestY.values, test_p)
tem_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')

# %%
# MCR-1
# -----

target = 'MCR-1'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = only_wq_MCR_1(model_name).predict(x=TestX)

# %%

test_f1 = f1_score_macro(TestY.values, test_p)
mcr1_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')


# %%
# Scenario 3: Only Antibiotics
# ============================

# Scenario key understood by return_train_test; per the heading above, only
# antibiotic features are used as inputs.
scenario = 'only_antibiotics'

# %%
# CTX-M
# -----

target = 'CTX-M'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = only_antibiotics_CTX_M(model_name).predict(x=TestX)

# %%

# Score once and reuse it, so the stored and the printed value are the same
# number and the metric is not computed twice.
test_f1 = f1_score_macro(TestY.values, test_p)
ctxm_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')

# %%
# OXA48
# -----

target = 'OXA48'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = only_antibiotics_OXA48(model_name).predict(x=TestX)

# %%

test_f1 = f1_score_macro(TestY.values, test_p)
oxa48_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')

# %%
# TEM
# -----

target = 'TEM'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = only_antibiotics_TEM(model_name).predict(x=TestX)

# %%

test_f1 = f1_score_macro(TestY.values, test_p)
tem_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')

# %%
# MCR-1
# -----

target = 'MCR-1'

_, _, TestX, TestY, _ = return_train_test(target, scenario)

test_p = only_antibiotics_MCR_1(model_name).predict(x=TestX)

# %%

test_f1 = f1_score_macro(TestY.values, test_p)
mcr1_f1_test.append(test_f1)

print(f'(target = {target}, scenario = {scenario}, test F1 score = '
      f'{test_f1})')

# %%

# One inner list per gene, each holding one score per scenario in the order
# the scenario sections above appended them.
data = [ctxm_f1_test, oxa48_f1_test, tem_f1_test, mcr1_f1_test]
gene_labels = ['CTX-M', 'OXA48', 'TEM', 'MCR-1']

colors = ['c', 'm', 'y']
scenario_labels = ["WQ plus Antibiotics", "Only WQ", "Only Antibiotics"]

# Scenario label -> bar color, paired by position.
legend_colors = dict(zip(scenario_labels, colors))

# NOTE(review): assumes easy_mpl's bar_chart draws one bar per inner-list
# entry and colors them with ``colors`` in order - confirm against easy_mpl.
ax = bar_chart(data, color=colors, orient='v', width=0.15,
               labels=gene_labels, show=False)

# Proxy patches for the legend: one colored rectangle per scenario.
handles = [plt.Rectangle((0, 0), 1, 1, color=legend_colors[label])
           for label in scenario_labels]

plt.legend(handles, scenario_labels, loc='upper left')
# Start the y-axis at 0.5 and keep the current upper limit. Scores below 0.5
# (e.g. the 0.1 returned by f1_score_macro for constant predictions) are not
# visible in this figure.
ax.set_ylim((0.5, ax.get_ylim()[1]))
if SAVE:
    from pathlib import Path

    out_path = Path("results/figures/scenarios_performance.png")
    # savefig does not create missing directories, so ensure they exist.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=600, bbox_inches="tight")
plt.show()
