"""
=============
with autotab
=============
"""

import pandas as pd

from autotab import OptimizePipeline
from autotab._main import METRIC_TYPES
from ai4water.utils.utils import TrainTestSplit
from autotab.utils import Callbacks

from sklearn.metrics import cohen_kappa_score
from SeqMetrics import ClassificationMetrics

from utils import make_data, return_train_test


def f1_score_macro(t, p) -> float:
    """Macro-averaged F1 score of predictions ``p`` against targets ``t``.

    When every element of ``p`` equals its first element, the score is not
    computed and the fixed value ``0.1`` is returned instead.

    NOTE(review): the other metric helpers in this file return 0.0 for
    constant predictions -- confirm that 0.1 is intentional here.
    """
    constant_prediction = (p == p[0]).all()
    if not constant_prediction:
        return ClassificationMetrics(t, p).f1_score(average="macro")
    return 0.1


def sensitivity(t, p) -> float:
    """Macro-averaged recall of predictions ``p`` against targets ``t``.

    Returns ``0.0`` without computing anything when every element of ``p``
    equals its first element.
    """
    constant_prediction = (p == p[0]).all()
    if not constant_prediction:
        return ClassificationMetrics(t, p).recall(average="macro")
    return 0.0


def specificity(t, p) -> float:
    """Macro-averaged specificity of predictions ``p`` against targets ``t``.

    Returns ``0.0`` without computing anything when every element of ``p``
    equals its first element.
    """
    constant_prediction = (p == p[0]).all()
    if not constant_prediction:
        return ClassificationMetrics(t, p).specificity(average="macro")
    return 0.0

def kappa(t, p) -> float:
    """Cohen's kappa between targets ``t`` and predictions ``p``.

    Returns ``0.0`` without calling ``cohen_kappa_score`` when every element
    of ``p`` equals its first element.
    """
    constant_prediction = (p == p[0]).all()
    if not constant_prediction:
        return cohen_kappa_score(t, p)
    return 0.0


# Register the custom metric names in autotab's METRIC_TYPES table with the
# value "max" (presumably meaning "higher is better" -- confirm in autotab).
METRIC_TYPES.update({
    'f1_score_macro': "max",
    'sensitivity': "max",
    'specificity': "max",
    'kappa': "max",
})

# Target name handed to the data loader and to the pipeline settings.
target = 'MCR-1'

# Feature names listed under "inputs_to_transform" in the pipeline settings.
tr_features = ['T  (℃)', 'pH', 'TDS (mg/L)', 'DO (mg/L)', 'TN (mg/L)']

# Load the train/test partitions and the list of input feature names.
TrainX, TrainY, TestX, TestY, inputs = return_train_test(target, 'no_genes')

# Carve a validation subset out of the training partition (fixed seed).
train_x, val_x, train_y, val_y = TrainTestSplit(seed=313).split_by_random(TrainX, TrainY)

import wandb

# Test-set macro F1 scores; one entry is appended per ``on_eval_begin`` call.
test_metrics = []


class MyCallbacks(Callbacks):
    """Callback that records the test-set macro F1 score of a model."""

    def on_eval_begin(self, model,
                      iter_num=None,
                      x=None, y=None,
                      validation_data=None) -> None:
        """Refit ``model`` on ``TrainX``/``TrainY`` and log its test F1 score.

        The score is appended to the module-level ``test_metrics`` list and
        sent to wandb under the key ``"test_f1"``. The arguments ``iter_num``,
        ``x``, ``y`` and ``validation_data`` are accepted but not used.

        NOTE(review): ``TrainX``/``TrainY`` is the set from which the
        validation split ``val_x``/``val_y`` was drawn. If the caller scores
        this same model object on that validation split after this hook
        runs, the model has already been fitted on those samples -- confirm
        this is intended.
        """
        model.fit(x=TrainX, y=TrainY.values)
        predicted = model.predict(TestX)
        score = f1_score_macro(TestY.values, predicted)
        test_metrics.append(score)
        wandb.log({"test_f1": score})

# Pipeline settings; the commented-out run script at the end of this file
# unpacks them as ``MyPipeline(**kws)``.
kws = {
    "mode": "classification",
    "eval_metric": f1_score_macro,
    "inputs_to_transform": tr_features,
    "outputs_to_transform": None,
    # Only the uncommented entry is active; the rest are kept as candidates.
    "models": [
        # "ExtraTreeClassifier",
        # "RandomForestClassifier",
        # "XGBClassifier",
        # "CatBoostClassifier",
        "LGBMClassifier",
        # "GradientBoostingClassifier",
        # "HistGradientBoostingClassifier",
        # "ExtraTreesClassifier",
        # "RidgeClassifier",
        # "SVC",
        # "KNeighborsClassifier",
        # "GaussianProcessClassifier"
    ],
    "parent_iterations": 1000,
    "child_iterations": 40,
    "parent_algorithm": 'tpe',
    "child_algorithm": 'random',
    "monitor": [kappa, sensitivity, specificity],
    "input_features": inputs,
    "output_features": target,
    "wandb_config": {
        "project": f"ab_isl_{target}_no_genes",
        "entity": "aa_si",
        "config": {},
    },
}


def _last_valid_per_column(frame):
    """Return ``{column: last non-NaN value}`` for each column of ``frame``.

    The value is picked by position after dropping missing entries, so the
    result does not depend on the kind of index ``frame`` carries. Columns
    without any valid value are left out of the result.
    """
    summary = {}
    for column, series in frame.items():
        valid = series.dropna()
        if not valid.empty:
            summary[column] = valid.iloc[-1]
    return summary


class MyPipeline(OptimizePipeline):
    """``OptimizePipeline`` whose wandb report also includes ``test_f1``."""

    def wb_finish(self) -> None:
        """Prepare the logs and put them on wandb.

        Does nothing unless ``self.use_wb`` is truthy and
        ``self.parent_iter_`` is greater than zero. Otherwise it

        * logs a ``"result"`` table built from the parent suggestions, the
          module-level ``test_metrics`` list, ``self.metrics_`` and the
          hyper-parameters of every suggestion,
        * logs a ``"child_hpo_results"`` table when ``self.child_iter_ > 0``,
        * stores the text returned by ``self.report(False)`` as run notes,
        * sets the run summary to the last non-NaN value of every column of
          ``self.metrics_best_``, and
        * finishes the wandb run.

        Raises
        ------
        ValueError
            Raised by pandas when ``test_metrics`` does not have one entry
            per row of the suggestions table.
        """
        if not (self.use_wb and self.parent_iter_ > 0):
            return

        # Create a wandb Table to log parent suggestions and metrics.
        raw_suggestions = self._parent_suggestions_
        df = pd.DataFrame(
            [list(val.values()) for val in raw_suggestions.values()],
            columns=list(raw_suggestions[0].keys()),
        )

        df['iterations'] = list(self.parent_suggestions_.keys())
        df['test_f1'] = test_metrics

        df = pd.concat([df, self.metrics_], axis=1)

        # Each suggestion's 'model' entry is a mapping; its first value is
        # stored as the hyper-parameters of that iteration.
        df['hyperparas'] = [
            list(val['model'].values())[0]
            for val in self.parent_suggestions_.values()
        ]

        table = wandb.Table(data=df, allow_mixed_types=True,
                            columns=df.columns.tolist())
        self.wb_run_.log({"result": table})

        if self.child_iter_ > 0:
            child_table = wandb.Table(
                data=pd.DataFrame(self.child_val_scores_),
                allow_mixed_types=True)
            self.wb_run_.log({"child_hpo_results": child_table})

        self.wb_run_.notes = self.report(False)

        # Last non-NaN value of every best-metric column. The previous code
        # passed the *label* returned by ``last_valid_index`` to positional
        # ``.iloc``, which fails for all-NaN columns and non-default indexes.
        # NOTE(review): wandb documents ``run.summary.update(...)``; confirm
        # that plain assignment is supported by the installed version.
        self.wb_run_.summary = _last_valid_per_column(self.metrics_best_)

        self.wb_run_.finish()
        return


# with MyPipeline(**kws) as pl:
#     pl.fit(x=train_x, y=train_y,
#            validation_data=(val_x, val_y),
#            callbacks= [MyCallbacks()],
#            process_results=False,
#            finish_wb=False
#            )
#
# pl.bfe_all_best_models(x=TrainX, y=TrainY, test_data=(TestX, TestY),
#                        metric_name="f1_score_macro"
#                       )
#
# pl.proces_hpo_results(pl.optimizer_, importance=False)
#
# pl.baseline_results(x=TrainX, y=TrainY, test_data=(TestX, TestY))
#
# pl.wb_finish()
