# utils

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import matplotlib.colors as mcolors
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from typing import Union

import seaborn as sns
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

import shap

from ai4water import Model
from ai4water.utils.utils import TrainTestSplit
from ai4water.utils.utils import get_version_info

from easy_mpl import plot, boxplot, pie, bar_chart, scatter
from easy_mpl.utils import create_subplots
# Maps a raw column name to the label used for display.
LABEL_MAP = dict(season='Season')

# Module-level switch; presumably controls whether outputs are written to
# disk -- TODO confirm, it is not read in this part of the file.
SAVE = False

# Feature names grouped by the category they belong to.
CATEGORIES = {
    "Physicochemical": [
        'T  (℃)',
        'pH',
        'TDS (mg/L)',
        'DO (mg/L)',
        'TN (mg/L)',
        'season',
    ],
    # antibiotic resistance genes
    "ARGs": [
        'MCR-1', 'CTX-M', 'SHV', 'TEM', 'OXA48',
        'KPC', 'NDM', 'IMP', 'QnrS', 'Qnr A',
    ],
    "Antibiotics": ['F', 'DO', 'CN', 'ATM', 'FOT'],
}
def read_data():
    """
    Load ``data/meropenem_isolates.xlsx`` (resolved against the current
    working directory) and stack its first four sheets into one DataFrame.

    Each sheet is read with the third row as header. The position of the
    sheet (0, 1, 2, 3) is stored in a ``season`` column, which is moved to
    the front of the result. Presumably the sheets are fall 2020, winter
    2021, spring 2021 and summer 2021 in that order -- verify against the
    workbook.

    :return:
        a DataFrame with a fresh 0..n-1 index in which the gene and
        antibiotic columns have surrounding whitespace stripped
    """
    fpath = os.path.join(os.getcwd(), "data", "meropenem_isolates.xlsx")

    seasonal_frames = []
    for season_code in range(4):
        # the sheet position doubles as the season code
        sheet_df = pd.read_excel(fpath, sheet_name=season_code, header=2)
        sheet_df['season'] = season_code
        seasonal_frames.append(sheet_df)

    combined = pd.concat(seasonal_frames).reset_index(drop=True)

    # put the season column at the start
    combined.insert(0, "season", combined.pop("season"))

    text_columns = (
        "Qnr A", 'Qnr B', 'QnrS', 'IMP', 'NDM', 'BIC', 'KPC',
        'OXA48', 'TEM', 'SHV', 'CTX-M', 'MCR-1', 'AMP', 'CAZ',
        'AMC', 'CTX', 'FEP', 'ATM', 'MEM', 'CIP', 'C', 'NA',
        'CN', 'DO', 'F', 'FOT', 'SXT',
    )
    for name in text_columns:
        combined[name] = combined[name].str.strip()

    return combined
# helper for one-hot encoding (OHE)
def _ohe_encoder(df:pd.DataFrame, col_name:str)->tuple:
    """
    One-hot encode a single column of ``df``.

    :param df:
        the data containing ``col_name``. Note that the encoded columns are
        also written into this DataFrame, so the caller's object is modified.
    :param col_name:
        name of the column to encode
    :return:
        a tuple of (a new DataFrame without ``col_name`` but with the encoded
        columns, the list of added column names ``<col_name>_<i>``, the
        fitted encoder)
    """
    assert isinstance(col_name, str)

    # ``sparse`` was renamed to ``sparse_output`` in scikit-learn 1.2 and
    # removed in 1.4; try the current keyword first and fall back for old
    # installations where it is not accepted.
    try:
        encoder = OneHotEncoder(sparse_output=False)
    except TypeError:
        encoder = OneHotEncoder(sparse=False)

    # the encoder expects a 2-D input, hence the reshape to a column vector
    ohe_cat = encoder.fit_transform(df[col_name].values.reshape(-1, 1))
    cols_added = [f"{col_name}_{i}" for i in range(ohe_cat.shape[-1])]

    df[cols_added] = ohe_cat

    df = df.drop(columns=col_name)

    return df, cols_added, encoder
def _label_encoder(df:pd.DataFrame, col_name:str)->tuple:
    """
    Replace a column of ``df`` by its integer label encoding.

    :param df:
        the data containing ``col_name``; the column is overwritten in place
    :param col_name:
        name of the column to encode
    :return:
        a tuple of (``df`` itself, ``col_name``, the fitted encoder)
    """
    assert isinstance(col_name, str)

    encoder = LabelEncoder()
    # LabelEncoder works on 1-D labels; handing it a column vector only
    # triggers a DataConversionWarning before it is flattened again.
    le_cat = encoder.fit_transform(df[col_name].values)

    df[col_name] = le_cat

    return df, col_name, encoder


def make_data(
        inputs,
        target,
        encode=True,
        encode_type_ab:str = "le"
):
    """
    Assemble the modelling table from the raw data returned by ``read_data``.

    :param inputs:
        list with the names of input features
    :param target:
        name of the column to use as target; it becomes the last column
        and its ``'-'``/``'+'`` entries are replaced by 0/1
    :param encode:
        whether to encode antibiotic and ARG columns or not
    :param encode_type_ab:
        the type of encoding to be applied on antibiotic data, either
        ``"le"`` (label encoding) or ``"ohe"`` (one-hot encoding). ARG
        columns are always label encoded.
    :return:
        a tuple of the assembled DataFrame and a dictionary which maps
        ``"<feature>_enc"`` to the encoder fitted for that feature
    """
    _def_inputs = ['season', 'T  (℃)', 'pH', 'TDS (mg/L)', 'DO (mg/L)', 'TN (mg/L)']
    _args = ['Qnr A', 'Qnr B', 'QnrS', 'IMP', 'NDM', 'BIC',  'KPC', 'OXA48',
             'TEM', 'SHV', 'CTX-M', 'MCR-1']

    whole_data = read_data()

    assert isinstance(inputs, list)
    assert isinstance(target, str)
    assert encode_type_ab in ["le", "ohe"]

    data = pd.DataFrame()
    encoders = {}
    for para in inputs + [target]:
        # season and water quality columns are taken over unchanged
        if para in _def_inputs:
            data[para] = whole_data[para].copy()
            continue

        # the target is appended after the loop
        if para == target:
            continue

        if not encode:
            data[para] = whole_data[para].copy()
            continue

        # only columns that are not ARGs may be one-hot encoded
        if para not in _args and encode_type_ab == "ohe":
            encode_fn = _ohe_encoder
        else:
            encode_fn = _label_encoder

        encoded_data, cols, encoders[f"{para}_enc"] = encode_fn(whole_data, para)
        data[cols] = encoded_data[cols].copy()

    data[target] = whole_data[target]

    # replacing strings in target with binary
    data[target] = data[target].replace(['-'], 0).replace(['+'], 1)

    return data, encoders
def _scenario_inputs(target, scenario=None):
    """Return the ordered input feature names of ``scenario`` without ``target``."""
    wq = ['season', 'T  (℃)', 'pH', 'TDS (mg/L)',
          'DO (mg/L)', 'TN (mg/L)']

    antibiotics = ['F', 'DO', 'CN', 'ATM', 'FOT']

    genes = ['MCR-1', 'CTX-M', 'SHV', 'TEM', 'OXA48',
             'KPC', 'NDM', 'IMP', 'QnrS', 'Qnr A']

    feature_groups = {
        None: wq + antibiotics + genes,
        'all_inputs': wq + antibiotics + genes,
        'no_genes': wq + antibiotics,
        'no_antibiotics': wq + genes,
        'only_wq': wq,
        'only_antibiotics': antibiotics,
    }

    try:
        selected = feature_groups[scenario]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which are just as invalid
        raise ValueError(f"Invalid scenario {scenario}") from None

    # Filtering instead of ``list.remove`` guarantees that the target never
    # leaks into the inputs and that a target which is absent from the
    # selected group does not raise.
    return [feature for feature in selected if feature != target]


def return_train_test(target, scenario=None, return_encoders=False)->tuple:
    """
    Prepare the encoded data for ``target`` and split it into training and
    test sets.

    :param target:
        name of the column to predict. It is always excluded from the inputs.
    :param scenario:
        which group of input features to use. One of ``None``/``'all_inputs'``
        (water quality, antibiotics and genes), ``'no_genes'``,
        ``'no_antibiotics'``, ``'only_wq'`` or ``'only_antibiotics'``.
    :param return_encoders:
        whether to additionally return the encoders fitted by ``make_data``
    :return:
        ``(TrainX, TrainY, TestX, TestY, inputs)`` and, if ``return_encoders``
        is set, the dictionary of encoders as sixth element
    :raises ValueError:
        if ``scenario`` is not one of the values listed above
    """
    inputs = _scenario_inputs(target, scenario)

    data, encoders = make_data(
        inputs=inputs, target=target, encode=True)

    # rows with any missing value are discarded before splitting
    data = data.dropna()

    # make_data puts the target into the last column
    columns = data.columns.tolist()
    input_features = columns[0:-1]
    output_features = columns[-1:]

    TrainX, TestX, TrainY, TestY = TrainTestSplit(seed=313).split_by_random(
        data[input_features],
        data[output_features]
    )

    if return_encoders:
        return TrainX, TrainY, TestX, TestY, inputs, encoders

    return TrainX, TrainY, TestX, TestY, inputs
def all_inputs_CTX_M(model_name):
    """
    Build and fit a classifier for ``CTX-M`` on the training split of the
    default (all inputs) scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the fitted ``Model``
    :raises ValueError:
        for any other ``model_name`` (after the data has been prepared)
    """
    target = 'CTX-M'

    train_x, train_y, _, _, inputs = return_train_test(target)

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': dict(
            n_estimators=139,
            boosting_type='gbdt',
            num_leaves=494,
            learning_rate=0.09328089700948024,
        )}
        methods = {
            'T  (℃)': 'box-cox',
            'pH': 'vast',
            'TDS (mg/L)': 'box-cox',
            'DO (mg/L)': 'quantile_normal',
            'TN (mg/L)': 'quantile',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = {"HistGradientBoostingClassifier": dict(
            learning_rate=0.2714740258492424,
            max_iter=428,
            max_depth=85,
            max_leaf_nodes=26,
            min_samples_leaf=17,
            l2_regularization=0.4170610185230958,
            random_state=313,
        )}
        methods = {
            'T  (℃)': 'zscore',
            'pH': 'scale',
            'TDS (mg/L)': 'log2',
            'DO (mg/L)': 'yeo-johnson',
            'TN (mg/L)': 'log2',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def all_inputs_OXA48(model_name):
    """
    Build and fit a classifier for ``OXA48`` on the training split of the
    default (all inputs) scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the fitted ``Model``
    :raises ValueError:
        for any other ``model_name`` (after the data has been prepared)
    """
    target = 'OXA48'

    train_x, train_y, _, _, inputs = return_train_test(target)

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        # 'DO (mg/L)' is not transformed for this model
        methods = {
            'T  (℃)': 'vast',
            'pH': 'vast',
            'TDS (mg/L)': 'log',
            'TN (mg/L)': 'log2',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T  (℃)': 'log',
            'pH': 'box-cox',
            'TDS (mg/L)': 'quantile_normal',
            'DO (mg/L)': 'minmax',
            'TN (mg/L)': 'sqrt',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def all_inputs_TEM(model_name):
    """
    Build and fit a classifier for ``TEM`` on the training split of the
    default (all inputs) scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the fitted ``Model``
    :raises ValueError:
        for any other ``model_name`` (after the data has been prepared)
    """
    target = 'TEM'

    train_x, train_y, _, _, inputs = return_train_test(target)

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': dict(
            n_estimators=144,
            boosting_type='gbdt',
            num_leaves=423,
            learning_rate=0.09147939344768526,
        )}
        methods = {
            'T  (℃)': 'minmax',
            'pH': 'yeo-johnson',
            'TDS (mg/L)': 'log10',
            'DO (mg/L)': 'box-cox',
            'TN (mg/L)': 'quantile',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        # 'T  (℃)' is not transformed for this model
        methods = {
            'pH': 'quantile_normal',
            'TDS (mg/L)': 'center',
            'DO (mg/L)': 'center',
            'TN (mg/L)': 'robust',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def all_inputs_MCR_1(model_name):
    """
    Build and fit a classifier for ``MCR-1`` on the training split of the
    default (all inputs) scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the fitted ``Model``
    :raises ValueError:
        for any other ``model_name`` (after the data has been prepared)
    """
    target = 'MCR-1'

    train_x, train_y, _, _, inputs = return_train_test(target)

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': dict(
            boosting_type='gbdt',
            num_leaves=185,
            learning_rate=0.07853514838993073,
            n_estimators=85,
        )}
        methods = {
            'T  (℃)': 'vast',
            'pH': 'scale',
            'TDS (mg/L)': 'quantile',
            'DO (mg/L)': 'zscore',
            'TN (mg/L)': 'log2',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T  (℃)': 'log10',
            'pH': 'center',
            'TDS (mg/L)': 'robust',
            'DO (mg/L)': 'log10',
            'TN (mg/L)': 'quantile_normal',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def get_all_inputs_model(
        target:str,
        model_name:str = "LGBMClassifier"
):
    """
    Return the fitted all-inputs model for ``target``.

    :param target:
        one of ``'MCR-1'``, ``'TEM'``, ``'CTX-M'`` or ``'OXA48'``
    :param model_name:
        name of the classifier, handed to the target specific builder
    :return:
        whatever the target specific ``all_inputs_*`` function returns
    :raises ValueError:
        if no builder exists for ``target``
    """
    if target == 'MCR-1':
        return all_inputs_MCR_1(model_name)
    if target == "TEM":
        return all_inputs_TEM(model_name)
    if target == "CTX-M":
        return all_inputs_CTX_M(model_name)
    if target == "OXA48":
        return all_inputs_OXA48(model_name)

    # name the offending value, as the other functions of this file do
    raise ValueError(
        f"Invalid target {target}; "
        "expected one of 'MCR-1', 'TEM', 'CTX-M', 'OXA48'"
    )
def no_gene_CTX_M(model_name, from_config=False):
    """
    Provide the classifier for ``CTX-M`` in the ``'no_genes'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``;
        ignored when ``from_config`` is set
    :param from_config:
        if true, the model is restored from ``results/20230627_130840``
        next to this file instead of being fitted
    :return:
        the restored or freshly fitted ``Model``
    :raises ValueError:
        for an unknown ``model_name`` when a new model is built
    """
    target = 'CTX-M'

    train_x, train_y, _, _, inputs = return_train_test(target, 'no_genes')

    if from_config:
        cpath = os.path.join(
            os.path.dirname(__file__), 'results', '20230627_130840', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        return model

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': dict(
            boosting_type='goss',
            num_leaves=377,
            learning_rate=0.01511065501651556,
            n_estimators=91,
        )}
        # 'DO (mg/L)' is not transformed for this model
        methods = {
            'T  (℃)': 'center',
            'pH': 'quantile',
            'TDS (mg/L)': 'yeo-johnson',
            'TN (mg/L)': 'log',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T  (℃)': 'zscore',
            'pH': 'box-cox',
            'TDS (mg/L)': 'quantile_normal',
            'DO (mg/L)': 'vast',
            'TN (mg/L)': 'quantile_normal',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def no_gene_OXA48(model_name, from_config=False):
    """
    Provide the classifier for ``OXA48`` in the ``'no_genes'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``;
        ignored when ``from_config`` is set
    :param from_config:
        if true, the model is restored from ``results/20230627_130714``
        next to this file instead of being fitted
    :return:
        the restored or freshly fitted ``Model``
    :raises ValueError:
        for an unknown ``model_name`` when a new model is built
    """
    target = 'OXA48'

    train_x, train_y, _, _, inputs = return_train_test(target, 'no_genes')

    if from_config:
        cpath = os.path.join(
            os.path.dirname(__file__), 'results', '20230627_130714', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        return model

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        methods = {
            'T  (℃)': 'scale',
            'pH': 'robust',
            'TDS (mg/L)': 'zscore',
            'DO (mg/L)': 'minmax',
            'TN (mg/L)': 'quantile_normal',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T  (℃)': 'box-cox',
            'pH': 'quantile_normal',
            'TDS (mg/L)': 'minmax',
            'DO (mg/L)': 'scale',
            'TN (mg/L)': 'scale',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def no_gene_TEM(model_name, from_config=False):
    """
    Provide the classifier for ``TEM`` in the ``'no_genes'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``;
        ignored when ``from_config`` is set
    :param from_config:
        if true, the model is restored from ``results/20230627_130553``
        next to this file instead of being fitted
    :return:
        the restored or freshly fitted ``Model``
    :raises ValueError:
        for an unknown ``model_name`` when a new model is built
    """
    target = 'TEM'

    train_x, train_y, _, _, inputs = return_train_test(target, 'no_genes')

    if from_config:
        cpath = os.path.join(
            os.path.dirname(__file__), 'results', '20230627_130553', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        return model

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        # 'TDS (mg/L)' is not transformed for this model
        methods = {
            'T  (℃)': 'robust',
            'pH': 'center',
            'DO (mg/L)': 'log2',
            'TN (mg/L)': 'yeo-johnson',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        # 'T  (℃)' is not transformed for this model
        methods = {
            'pH': 'center',
            'TDS (mg/L)': 'minmax',
            'DO (mg/L)': 'minmax',
            'TN (mg/L)': 'center',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def no_gene_MCR_1(model_name, from_config=False):
    """
    Provide the classifier for ``MCR-1`` in the ``'no_genes'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``;
        ignored when ``from_config`` is set
    :param from_config:
        if true, the model is restored from ``results/20230627_130424``
        next to this file instead of being fitted
    :return:
        the restored or freshly fitted ``Model``
    :raises ValueError:
        for an unknown ``model_name`` when a new model is built
    """
    target = 'MCR-1'

    train_x, train_y, _, _, inputs = return_train_test(target, 'no_genes')

    if from_config:
        cpath = os.path.join(
            os.path.dirname(__file__), 'results', '20230627_130424', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        return model

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': dict(
            boosting_type='goss',
            num_leaves=245,
            learning_rate=0.08783749114657222,
            n_estimators=139,
        )}
        methods = {
            'T  (℃)': 'box-cox',
            'pH': 'vast',
            'TDS (mg/L)': 'box-cox',
            'DO (mg/L)': 'quantile_normal',
            'TN (mg/L)': 'quantile',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T  (℃)': 'pareto',
            'pH': 'vast',
            'TDS (mg/L)': 'quantile',
            'DO (mg/L)': 'minmax',
            'TN (mg/L)': 'minmax',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def get_no_gene_model(
        target:str,
        from_config,
        model_name:str = "LGBMClassifier"
):
    """
    Return the no-genes model for ``target``.

    :param target:
        one of ``'MCR-1'``, ``'TEM'``, ``'CTX-M'`` or ``'OXA48'``
    :param from_config:
        handed to the target specific builder, which restores a saved
        model when it is true and fits a new one otherwise
    :param model_name:
        name of the classifier, handed to the target specific builder
    :return:
        whatever the target specific ``no_gene_*`` function returns
    :raises ValueError:
        if no builder exists for ``target``
    """
    if target == 'MCR-1':
        return no_gene_MCR_1(model_name, from_config)
    if target == "TEM":
        return no_gene_TEM(model_name, from_config)
    if target == "CTX-M":
        return no_gene_CTX_M(model_name, from_config)
    if target == "OXA48":
        return no_gene_OXA48(model_name, from_config)

    # name the offending value, as the other functions of this file do
    raise ValueError(
        f"Invalid target {target}; "
        "expected one of 'MCR-1', 'TEM', 'CTX-M', 'OXA48'"
    )
def no_antibiotic_CTX_M(model_name):
    """
    Build and fit a classifier for ``CTX-M`` on the training split of the
    ``'no_antibiotics'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the fitted ``Model``
    :raises ValueError:
        for any other ``model_name`` (after the data has been prepared)
    """
    target = 'CTX-M'

    train_x, train_y, _, _, inputs = return_train_test(target, 'no_antibiotics')

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': dict(
            boosting_type='gbdt',
            num_leaves=283,
            learning_rate=0.08556696555457217,
            n_estimators=141,
        )}
        methods = {
            'T  (℃)': 'box-cox',
            'pH': 'scale',
            'TDS (mg/L)': 'log10',
            'DO (mg/L)': 'vast',
            'TN (mg/L)': 'log2',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T  (℃)': 'log2',
            'pH': 'log',
            'TDS (mg/L)': 'yeo-johnson',
            'DO (mg/L)': 'robust',
            'TN (mg/L)': 'pareto',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def no_antibiotic_OXA48(model_name):
    """
    Build and fit a classifier for ``OXA48`` on the training split of the
    ``'no_antibiotics'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the fitted ``Model``
    :raises ValueError:
        for any other ``model_name`` (after the data has been prepared)
    """
    target = 'OXA48'

    train_x, train_y, _, _, inputs = return_train_test(target, 'no_antibiotics')

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        # 'DO (mg/L)' is not transformed for this model
        methods = {
            'T  (℃)': 'scale',
            'pH': 'pareto',
            'TDS (mg/L)': 'pareto',
            'TN (mg/L)': 'log2',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T  (℃)': 'log',
            'pH': 'sqrt',
            'TDS (mg/L)': 'quantile',
            'DO (mg/L)': 'quantile_normal',
            'TN (mg/L)': 'log10',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def no_antibiotic_TEM(model_name):
    """
    Build and fit a classifier for ``TEM`` on the training split of the
    ``'no_antibiotics'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the fitted ``Model``
    :raises ValueError:
        for any other ``model_name`` (after the data has been prepared)
    """
    target = 'TEM'

    train_x, train_y, _, _, inputs = return_train_test(target, 'no_antibiotics')

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        methods = {
            'T  (℃)': 'yeo-johnson',
            'pH': 'log10',
            'TDS (mg/L)': 'zscore',
            'DO (mg/L)': 'log',
            'TN (mg/L)': 'sqrt',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        # 'TN (mg/L)' is not transformed for this model
        methods = {
            'T  (℃)': 'minmax',
            'pH': 'center',
            'TDS (mg/L)': 'quantile',
            'DO (mg/L)': 'minmax',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model


def no_antibiotic_MCR_1(model_name):
    """
    Build and fit a classifier for ``MCR-1`` on the training split of the
    ``'no_antibiotics'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the fitted ``Model``
    :raises ValueError:
        for any other ``model_name`` (after the data has been prepared)
    """
    target = 'MCR-1'

    train_x, train_y, _, _, inputs = return_train_test(target, 'no_antibiotics')

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        methods = {
            'T  (℃)': 'box-cox',
            'pH': 'log2',
            'TDS (mg/L)': 'scale',
            'DO (mg/L)': 'log',
            'TN (mg/L)': 'log10',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        # 'pH' is not transformed for this model
        methods = {
            'T  (℃)': 'yeo-johnson',
            'TDS (mg/L)': 'quantile_normal',
            'DO (mg/L)': 'log',
            'TN (mg/L)': 'pareto',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def only_wq_CTX_M(model_name, from_config=False):
    """
    Provide the classifier for ``CTX-M`` in the ``'only_wq'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``;
        ignored when ``from_config`` is set
    :param from_config:
        if true, the model is restored from ``results/20230627_130241``
        next to this file instead of being fitted
    :return:
        the restored or freshly fitted ``Model``
    :raises ValueError:
        for an unknown ``model_name`` when a new model is built
    """
    target = 'CTX-M'

    train_x, train_y, _, _, inputs = return_train_test(target, 'only_wq')

    if from_config:
        cpath = os.path.join(
            os.path.dirname(__file__), 'results', '20230627_130241', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        return model

    # ``methods`` maps each transformed feature to its transformation
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        methods = {
            'T  (℃)': 'sqrt',
            'pH': 'quantile',
            'TDS (mg/L)': 'log2',
            'DO (mg/L)': 'log2',
            'TN (mg/L)': 'quantile',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T  (℃)': 'zscore',
            'pH': 'zscore',
            'TDS (mg/L)': 'log10',
            'DO (mg/L)': 'log10',
            'TN (mg/L)': 'zscore',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=train_x, y=train_y.values)

    return model
def only_wq_OXA48(model_name, from_config=False):
    """
    Build (or reload) the classifier for ``OXA48`` using the inputs of the
    ``'only_wq'`` split returned by ``return_train_test``.

    Parameters
    ----------
    model_name : str
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``. It is
        not consulted when ``from_config`` is True.
    from_config : bool
        if True, the model is loaded from
        ``results/20230627_130130/config.json`` (relative to this file) and
        ``update_weights`` is called on it; no fitting is done in that case.

    Returns
    -------
    Model
        the reloaded model, or a new model fitted on the training split.

    Raises
    ------
    ValueError
        if a new model is requested with an unsupported ``model_name``.
    """
    gene = 'OXA48'

    train_x, train_y, _, _, input_names = return_train_test(gene, 'only_wq')

    if from_config:
        run_dir = os.path.join(os.path.dirname(__file__), 'results', '20230627_130130')
        restored = Model.from_config_file(config_path=os.path.join(run_dir, 'config.json'))
        restored.update_weights()
        return restored

    # feature -> transformation method; insertion order fixes the order of
    # the entries in ``x_transformation``
    if model_name == 'LGBMClassifier':
        method_of = {
            'T  (℃)': 'box-cox',
            'pH': 'quantile',
            'TDS (mg/L)': 'vast',
            'DO (mg/L)': 'quantile',
            'TN (mg/L)': 'center',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        method_of = {
            'T  (℃)': 'log10',
            'pH': 'log2',
            'TDS (mg/L)': 'sqrt',
            'DO (mg/L)': 'scale',
            'TN (mg/L)': 'scale',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    clf = Model(
        model=model_name,
        input_features=input_names,
        output_features=gene,
        verbosity=0,
        x_transformation=[{'method': method, 'features': [feat]}
                          for feat, method in method_of.items()],
    )

    clf.seed_everything(313)
    clf.fit(x=train_x, y=train_y.values)

    return clf
def only_wq_TEM(model_name, from_config=False):
    """
    Build (or reload) the classifier for ``TEM`` using the inputs of the
    ``'only_wq'`` split returned by ``return_train_test``.

    Parameters
    ----------
    model_name : str
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``. It is
        not consulted when ``from_config`` is True.
    from_config : bool
        if True, the model is loaded from
        ``results/20230627_125951/config.json`` (relative to this file) and
        ``update_weights`` is called on it; no fitting is done in that case.

    Returns
    -------
    Model
        the reloaded model, or a new model fitted on the training split.

    Raises
    ------
    ValueError
        if a new model is requested with an unsupported ``model_name``.
    """
    gene = 'TEM'

    train_x, train_y, _, _, input_names = return_train_test(gene, 'only_wq')

    if from_config:
        run_dir = os.path.join(os.path.dirname(__file__), 'results', '20230627_125951')
        restored = Model.from_config_file(config_path=os.path.join(run_dir, 'config.json'))
        restored.update_weights()
        return restored

    # feature -> transformation method; insertion order fixes the order of
    # the entries in ``x_transformation``
    if model_name == 'LGBMClassifier':
        method_of = {
            'T  (℃)': 'vast',
            'pH': 'scale',
            'TDS (mg/L)': 'center',
            'DO (mg/L)': 'sqrt',
            'TN (mg/L)': 'pareto',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        method_of = {
            'T  (℃)': 'robust',
            'pH': 'log2',
            'TDS (mg/L)': 'quantile',
            'DO (mg/L)': 'quantile',
            'TN (mg/L)': 'minmax',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    clf = Model(
        model=model_name,
        input_features=input_names,
        output_features=gene,
        verbosity=0,
        x_transformation=[{'method': method, 'features': [feat]}
                          for feat, method in method_of.items()],
    )

    clf.seed_everything(313)
    clf.fit(x=train_x, y=train_y.values)

    return clf
def only_wq_MCR_1(model_name, from_config=False):
    """
    Build (or reload) the classifier for ``MCR-1`` using the inputs of the
    ``'only_wq'`` split returned by ``return_train_test``.

    Parameters
    ----------
    model_name : str
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``. It is
        not consulted when ``from_config`` is True.
    from_config : bool
        if True, the model is loaded from
        ``results/20230627_125012/config.json`` (relative to this file) and
        ``update_weights`` is called on it; no fitting is done in that case.

    Returns
    -------
    Model
        the reloaded model, or a new model fitted on the training split.

    Raises
    ------
    ValueError
        if a new model is requested with an unsupported ``model_name``.
    """
    gene = 'MCR-1'

    train_x, train_y, _, _, input_names = return_train_test(gene, 'only_wq')

    if from_config:
        run_dir = os.path.join(os.path.dirname(__file__), 'results', '20230627_125012')
        restored = Model.from_config_file(config_path=os.path.join(run_dir, 'config.json'))
        restored.update_weights()
        return restored

    # feature -> transformation method; insertion order fixes the order of
    # the entries in ``x_transformation``
    if model_name == 'LGBMClassifier':
        method_of = {
            'T  (℃)': 'yeo-johnson',
            'pH': 'pareto',
            'TDS (mg/L)': 'log10',
            'DO (mg/L)': 'box-cox',
            'TN (mg/L)': 'sqrt',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        method_of = {
            'T  (℃)': 'yeo-johnson',
            'pH': 'robust',
            'TDS (mg/L)': 'log',
            'DO (mg/L)': 'quantile_normal',
            'TN (mg/L)': 'log2',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    clf = Model(
        model=model_name,
        input_features=input_names,
        output_features=gene,
        verbosity=0,
        x_transformation=[{'method': method, 'features': [feat]}
                          for feat, method in method_of.items()],
    )

    clf.seed_everything(313)
    clf.fit(x=train_x, y=train_y.values)

    return clf
def only_antibiotics_MCR_1(model_name, from_config=False):
    """
    Build (or reload) the classifier for ``MCR-1`` using the inputs of the
    ``'only_antibiotics'`` split returned by ``return_train_test``.

    Parameters
    ----------
    model_name : str
        only ``'LGBMClassifier'`` is accepted. It is not consulted when
        ``from_config`` is True.
    from_config : bool
        if True, the model is loaded from
        ``results/20230705_164018/config.json`` (relative to this file) and
        ``update_weights`` is called on it; no fitting is done in that case.

    Returns
    -------
    Model
        the reloaded model, or a new model fitted on the training split.

    Raises
    ------
    ValueError
        if a new model is requested with an unsupported ``model_name``.
    """
    gene = 'MCR-1'

    train_x, train_y, _, _, input_names = return_train_test(gene, 'only_antibiotics')

    if from_config:
        run_dir = os.path.join(os.path.dirname(__file__), 'results', '20230705_164018')
        restored = Model.from_config_file(config_path=os.path.join(run_dir, 'config.json'))
        restored.update_weights()
        return restored

    if model_name != 'LGBMClassifier':
        raise ValueError(f'Invalid {model_name}')

    # hyper-parameters handed to the LGBMClassifier
    lgbm_params = {
        'n_estimators': 150,
        'boosting_type': 'goss',
        'num_leaves': 500,
        'learning_rate': 0.09734360597562812,
    }

    clf = Model(
        model={'LGBMClassifier': lgbm_params},
        input_features=input_names,
        output_features=gene,
        verbosity=0,
    )

    clf.seed_everything(313)
    clf.fit(x=train_x, y=train_y.values)

    return clf
def only_antibiotics_CTX_M(model_name, from_config=False):
    """
    Build (or reload) the classifier for ``CTX-M`` using the inputs of the
    ``'only_antibiotics'`` split returned by ``return_train_test``.

    Parameters
    ----------
    model_name : str
        only ``'LGBMClassifier'`` is accepted. It is not consulted when
        ``from_config`` is True.
    from_config : bool
        if True, the model is loaded from
        ``results/20230705_154741/config.json`` (relative to this file) and
        ``update_weights`` is called on it; no fitting is done in that case.

    Returns
    -------
    Model
        the reloaded model, or a new model fitted on the training split.

    Raises
    ------
    ValueError
        if a new model is requested with an unsupported ``model_name``.
    """
    gene = 'CTX-M'

    train_x, train_y, _, _, input_names = return_train_test(gene, 'only_antibiotics')

    if from_config:
        run_dir = os.path.join(os.path.dirname(__file__), 'results', '20230705_154741')
        restored = Model.from_config_file(config_path=os.path.join(run_dir, 'config.json'))
        restored.update_weights()
        return restored

    if model_name != 'LGBMClassifier':
        raise ValueError(f'Invalid {model_name}')

    # hyper-parameters handed to the LGBMClassifier
    lgbm_params = {
        'n_estimators': 44,
        'boosting_type': 'goss',
        'num_leaves': 458,
        'learning_rate': 0.045723247279341225,
    }

    clf = Model(
        model={'LGBMClassifier': lgbm_params},
        input_features=input_names,
        output_features=gene,
        verbosity=0,
    )

    clf.seed_everything(313)
    clf.fit(x=train_x, y=train_y.values)

    return clf
def only_antibiotics_TEM(model_name, from_config=False):
    """
    Build (or reload) the classifier for ``TEM`` using the inputs of the
    ``'only_antibiotics'`` split returned by ``return_train_test``.

    Parameters
    ----------
    model_name : str
        only ``'LGBMClassifier'`` is accepted. It is not consulted when
        ``from_config`` is True.
    from_config : bool
        if True, the model is loaded from
        ``results/20230705_154645/config.json`` (relative to this file) and
        ``update_weights`` is called on it; no fitting is done in that case.

    Returns
    -------
    Model
        the reloaded model, or a new model fitted on the training split.

    Raises
    ------
    ValueError
        if a new model is requested with an unsupported ``model_name``.
    """
    gene = 'TEM'

    train_x, train_y, _, _, input_names = return_train_test(gene, 'only_antibiotics')

    if from_config:
        run_dir = os.path.join(os.path.dirname(__file__), 'results', '20230705_154645')
        restored = Model.from_config_file(config_path=os.path.join(run_dir, 'config.json'))
        restored.update_weights()
        return restored

    if model_name != 'LGBMClassifier':
        raise ValueError(f'Invalid {model_name}')

    # hyper-parameters handed to the LGBMClassifier
    lgbm_params = {
        'n_estimators': 149,
        'boosting_type': 'goss',
        'num_leaves': 483,
        'learning_rate': 0.053619829014340085,
    }

    clf = Model(
        model={'LGBMClassifier': lgbm_params},
        input_features=input_names,
        output_features=gene,
        verbosity=0,
    )

    clf.seed_everything(313)
    clf.fit(x=train_x, y=train_y.values)

    return clf
def only_antibiotics_OXA48(model_name, from_config=False):
    """
    Build (or reload) the classifier for ``OXA48`` using the inputs of the
    ``'only_antibiotics'`` split returned by ``return_train_test``.

    Parameters
    ----------
    model_name : str
        only ``'LGBMClassifier'`` is accepted. It is not consulted when
        ``from_config`` is True.
    from_config : bool
        if True, the model is loaded from
        ``results/20230705_164158/config.json`` (relative to this file) and
        ``update_weights`` is called on it; no fitting is done in that case.

    Returns
    -------
    Model
        the reloaded model, or a new model fitted on the training split.

    Raises
    ------
    ValueError
        if a new model is requested with an unsupported ``model_name``.
    """
    gene = 'OXA48'

    train_x, train_y, _, _, input_names = return_train_test(gene, 'only_antibiotics')

    if from_config:
        run_dir = os.path.join(os.path.dirname(__file__), 'results', '20230705_164158')
        restored = Model.from_config_file(config_path=os.path.join(run_dir, 'config.json'))
        restored.update_weights()
        return restored

    if model_name != 'LGBMClassifier':
        raise ValueError(f'Invalid {model_name}')

    # hyper-parameters handed to the LGBMClassifier
    lgbm_params = {
        'n_estimators': 59,
        'boosting_type': 'goss',
        'num_leaves': 494,
        'learning_rate': 0.07385784480075765,
    }

    clf = Model(
        model={'LGBMClassifier': lgbm_params},
        input_features=input_names,
        output_features=gene,
        verbosity=0,
    )

    clf.seed_everything(313)
    clf.fit(x=train_x, y=train_y.values)

    return clf
def plot_line(target, value, save=False):
    """
    Line plots of the five physicochemical columns, one line per season,
    restricted to the rows of ``read_data()`` where ``target`` equals ``value``.

    Parameters
    ----------
    target : str
        name of the column used to filter the data.
    value :
        rows are kept where ``data[target] == value``.
    save : bool
        if True, the figure is written to
        ``results/figures/line_{target}_{value}`` before it is shown.
    """
    num_columns_without_seaon = ['T  (℃)', 'pH', 'TDS (mg/L)',
                                 'DO (mg/L)', 'TN (mg/L)']

    data = read_data()

    # integer season codes as assigned in read_data
    SEASONS = {
        0: "Fall",
        1: "Winter",
        2: "Spring",
        3: "Summer"
    }

    subset = data.loc[data[target] == value]

    groups = subset.groupby('season')

    fig, axes = create_subplots(5, sharex="all")

    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])

    colors = ['#588F07', '#E79D33', '#34ACAF', '#DD6F6D']

    for col, ax in zip(num_columns_without_seaon, axes.flatten()):

        # only the TN axes carries the legend, so only its lines get a
        # real label; labels starting with an underscore are left out of
        # matplotlib legends
        with_legend = col == 'TN (mg/L)'

        for (label, grp), color in zip(groups, colors):
            legend = SEASONS[label] if with_legend else '__nolabel__'

            plot(grp[col].values, label=legend, color=color, ax=ax, show=False)

        ax.set_ylabel(col)
        if with_legend:
            ax.legend(loc=(1.2, 0.2))

    plt.suptitle(f'{target}_{value}')
    plt.tight_layout()

    if save:
        # ``value`` must be interpolated here (as in the title), otherwise
        # figures for different values of the same target overwrite each other
        plt.savefig(f"results/figures/line_{target}_{value}", bbox_inches="tight", dpi=600)
    plt.show()

    return
def plot_boxplot(target, value, save=False):
    """
    Box plots of the five physicochemical columns, one box per season,
    restricted to the rows of ``read_data()`` where ``target`` equals ``value``.

    Parameters
    ----------
    target : str
        name of the column used to filter the data.
    value :
        rows are kept where ``data[target] == value``.
    save : bool
        if True, the figure is written to
        ``results/figures/box_{target}_{value}`` before it is shown.
    """
    num_columns_without_seaon = ['T  (℃)', 'pH', 'TDS (mg/L)',
                                 'DO (mg/L)', 'TN (mg/L)']

    data = read_data()

    # integer season codes as assigned in read_data
    SEASONS = {
        0: "Fall",
        1: "Winter",
        2: "Spring",
        3: "Summer"
    }

    subset = data.loc[data[target] == value]

    groups = subset.groupby('season')

    f, axes = create_subplots(5)

    palette = ['#588F07', '#E79D33', '#34ACAF', '#DD6F6D']

    # as many colors as there are seasons in the subset; a single season
    # gets a plain color string instead of a list
    n_groups = groups.size().size
    if n_groups == 1:
        colors = palette[0]
    elif n_groups in (2, 3):
        colors = palette[:n_groups]
    else:
        colors = palette

    for col, ax in zip(num_columns_without_seaon, axes.flatten()):

        data_ = [grp[col].values for _, grp in groups]
        labels = [SEASONS[label] for label, _ in groups]
        boxplot(data_, labels=labels, ax=ax, show=False,
                line_color=colors,
                flierprops=dict(ms=2.0))

        ax.set_ylabel(col)

    plt.suptitle(f'{target}_{value}')
    plt.tight_layout()
    if save:
        # ``value`` must be interpolated here (as in the title), otherwise
        # figures for different values of the same target overwrite each other
        plt.savefig(f"results/figures/box_{target}_{value}", bbox_inches="tight", dpi=600)
    plt.show()

    return
def hist():
    """
    Show a histogram (``sns.histplot``) of each physicochemical column and
    of ``season`` from ``read_data()``, every column on its own axes.
    """
    columns = ['T  (℃)', 'pH', 'TDS (mg/L)',
               'DO (mg/L)', 'TN (mg/L)', 'season']

    frame = read_data()

    _, axes = create_subplots(len(columns))

    # create_subplots may hand back a single Axes instead of an array
    axes = axes if isinstance(axes, np.ndarray) else np.array([axes])

    for axis, column in zip(axes.flat, columns):
        sns.histplot(frame[column], ax=axis)

    plt.tight_layout()
    plt.show()

    return
def lineplot(palette=None, save=False, name=''):
    """
    Show a line plot (``sns.lineplot``) of each physicochemical column and
    of ``season`` from ``read_data()``, every column on its own axes.

    Parameters
    ----------
    palette :
        forwarded to ``sns.lineplot``.
    save : bool
        if True, the figure is written to ``results/figures/line_{name}``
        before it is shown.
    name : str
        suffix of the file name used when ``save`` is True.
    """
    columns = ['T  (℃)', 'pH', 'TDS (mg/L)',
               'DO (mg/L)', 'TN (mg/L)', 'season']

    frame = read_data()

    _, axes = create_subplots(len(columns))

    # create_subplots may hand back a single Axes instead of an array
    axes = axes if isinstance(axes, np.ndarray) else np.array([axes])

    for axis, column in zip(axes.flat, columns):
        sns.lineplot(frame[column], ax=axis, palette=palette)

    plt.tight_layout()

    if save:
        plt.savefig(f"results/figures/line_{name}", bbox_inches="tight", dpi=600)
    plt.show()

    return

def season_line(season_data, title):
    """
    Show a line plot (``sns.lineplot``) of each of the five physicochemical
    columns of ``season_data``, every column on its own axes.

    Parameters
    ----------
    season_data :
        object indexable by the physicochemical column names
        (presumably the data of a single season; confirm against callers).
    title : str
        used as the figure's suptitle.
    """
    columns = ['T  (℃)', 'pH', 'TDS (mg/L)',
               'DO (mg/L)', 'TN (mg/L)']

    _, axes = create_subplots(len(columns))

    # create_subplots may hand back a single Axes instead of an array
    axes = axes if isinstance(axes, np.ndarray) else np.array([axes])

    for axis, column in zip(axes.flat, columns):
        sns.lineplot(season_data[column], ax=axis, palette='Spectral')

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

    return
def set_rcParams(**kwargs):
    """
    Write the default style settings of this project into ``plt.rcParams``.

    Any keyword argument overrides the default of the same key, or is set
    in addition to the defaults when the key is not among them.

    See https://matplotlib.org/stable/tutorials/introductory/customizing.html
    """
    defaults = {
        'axes.labelsize': '14',
        'xtick.labelsize': '12',
        'ytick.labelsize': '12',
        'legend.title_fontsize': '12',
        'axes.titleweight': 'bold',
        'axes.titlesize': '14',
        'axes.labelweight': 'bold',
        'font.family': 'Times New Roman'
    }

    # user supplied values win over the defaults
    settings = {**defaults, **kwargs}

    for key in settings:
        plt.rcParams[key] = settings[key]

    return
def version_info():
    """
    Print every key/value pair returned by ``get_version_info``, with the
    version of ``shap`` added under the key ``'shap'``.
    """
    versions = get_version_info()
    versions['shap'] = shap.__version__

    for library in versions:
        print(library, versions[library])

    return
def bar_pie(
        data:pd.DataFrame,
        ax:plt.Axes=None,
        save:bool = True,
        name:str = '',
        show:bool = True,
):
    """
    Bar chart of mean absolute SHAP values per feature with an inset pie
    showing the share of every feature class in the summed SHAP values.

    Parameters
    ----------
    data : pd.DataFrame
        must contain the columns ``mean_shap``, ``features``, ``classes``
        and ``colors``. The color of a class is taken from the first row
        in which the class appears.
    ax : plt.Axes
        axes to draw the bar chart on. A new figure of size (7, 9) is
        created when it is not given.
    save : bool
        if True, the current figure is written to
        ``results/figures/shap_bar_{name}.png``.
    name : str
        suffix of the file name used when ``save`` is True.
    show : bool
        whether to call ``plt.tight_layout`` and ``plt.show`` at the end.
    """
    if ax is None:
        f, ax = plt.subplots(figsize=(7, 9))

    sv_bar = data['mean_shap'].copy()
    feature_names = data['features'].tolist()

    # One row per class, in order of first appearance. Labels, colors and
    # pie fractions are all derived from this single ordering so that they
    # can never get out of step with each other, whatever the order or
    # number of classes in ``data``.
    per_class = data.drop_duplicates(subset='classes')
    labels = per_class['classes'].tolist()
    colors = per_class['colors'].tolist()

    ax_ = bar_chart(
        sv_bar,
        [LABEL_MAP[n] if n in LABEL_MAP else n for n in feature_names],
        bar_labels=np.round(sv_bar, 4),
        bar_label_kws={'label_type': 'edge',
                       'fontsize': 10,
                       'weight': 'bold',
                       "fmt": '%.4f',
                       'padding': 1.5
                       },
        show=False,
        sort=True,
        color=data['colors'].to_list(),
        ax=ax
    )
    ax_.spines[['top', 'right']].set_visible(False)
    ax_.set_xlabel(xlabel='mean(|SHAP value|)')
    ax_.set_yticklabels(ax_.get_yticklabels())

    # legend with one colored patch per class
    handles = [plt.Rectangle((0, 0), 1, 1, color=clr) for clr in colors]
    ax_.legend(handles, labels, loc='lower right')
    ax_.xaxis.set_major_locator(plt.MaxNLocator(4))

    seg_colors = tuple(colors)
    # Change the saturation of seg_colors to 70% for the interior segments
    rgb = mcolors.to_rgba_array(seg_colors)[:, :-1]
    hsv = mcolors.rgb_to_hsv(rgb)
    hsv[:, 1] = 0.7 * hsv[:, 1]
    interior_colors = mcolors.hsv_to_rgb(hsv)

    # summed SHAP value of every class, in the same order as ``labels``
    fractions = np.array(
        [data.loc[data['classes'] == cls, 'mean_shap'].sum() for cls in labels],
        dtype=float,
    )
    fractions /= fractions.sum()

    ax2 = inset_axes(ax, width='50%', height='50%',
                     loc=5)

    # thin outer ring which carries the class labels
    _, texts = pie(fractions=fractions,
                   colors=seg_colors,
                   labels=labels,
                   wedgeprops=dict(edgecolor="w", width=0.03), radius=1,
                   autopct=None,
                   textprops=dict(fontsize=12),
                   startangle=90, counterclock=False, show=False,
                   ax=ax2)
    texts[0].set_fontsize(12)
    # interior wedges which carry the percentages
    _, texts, autotexts = pie(fractions=fractions,
                              colors=interior_colors,
                              autopct='%1.0f%%',
                              textprops=dict(fontsize=24),
                              wedgeprops=dict(edgecolor="w"), radius=1 - 2 * 0.03,
                              startangle=90, counterclock=False, ax=ax2, show=False)
    texts[0].set_fontsize(12)
    if save:
        plt.savefig(f"results/figures/shap_bar_{name}.png", dpi=600, bbox_inches="tight")

    if show:
        plt.tight_layout()
        plt.show()
    return
def shap_scatter_plots(
        shap_values:np.ndarray,
        TrainX:pd.DataFrame,
        feature_name:str,
        encoders,
        save:bool = True,
        name: str = '',
):
    """
    Draw one SHAP scatter plot of ``feature_name`` per column of ``TrainX``,
    each time using a different column to color the points.

    It is expected that the columns in TrainX and shap_values have same order.

    Parameters
    ----------
    shap_values:
        2d array of SHAP values. The column at the position of
        ``feature_name`` in ``TrainX.columns`` is plotted.
    TrainX:
        the input data. Every column is used once as color feature.
    feature_name:
        name of the column in ``TrainX`` whose values are put on the x-axis
        and whose SHAP values are put on the y-axis.
    encoders :
        container which is queried with ``'<feature>_enc'`` and whose values
        provide ``inverse_transform``. A categorical feature (other than
        ``season``) without an entry is colored by its raw values.
    save :
        if True, the figure is written to
        ``results/figures/shap_interac_{feature_name}_{name}.png`` where
        spaces are removed from ``feature_name`` and ``/`` is replaced by ``_``.
    name : str
        suffix of the file name used when ``save`` is True.

    """

    CAT_FEAT = ['season',  'F',
                             'DO',
                             'CN',
                             'ATM',
                             'FOT',
                             'MCR-1',
                             'CTX-M', 'OXA48',
                             'SHV',
                             'TEM',
                             'KPC',
                             'NDM',
                             'IMP',
                             'QnrS',
                             'Qnr A']
    f, axes = create_subplots(TrainX.shape[1],
                              figsize=(12, 9))

    # a single Axes is returned as such by the other callers' guard; make
    # ``axes.flat`` work in that case here as well
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])

    index = TrainX.columns.to_list().index(feature_name)

    for feature, ax in zip(TrainX.columns, axes.flat):

        clr_f_is_cat = feature in CAT_FEAT

        if clr_f_is_cat:
            # Start from the raw column so that a feature without an encoder
            # is colored by its own values and never by the (stale) decoded
            # values of the feature handled in the previous iteration.
            dec_feature = TrainX.loc[:, feature].values

            feature_enc = f'{feature}_enc'
            if feature != 'season' and feature_enc in encoders:
                dec_feature = pd.Series(
                    encoders[feature_enc].inverse_transform(dec_feature),
                    name=feature)

            # instead of showing the actual names, we still prefer to
            # label encode them because actual names takes very large
            # space in figure/axes
            color_feature = pd.Series(
                LabelEncoder().fit_transform(dec_feature),
                name=feature)
        else:
            color_feature = TrainX.loc[:, feature]

        color_feature.name = LABEL_MAP.get(color_feature.name, color_feature.name)

        ax = shap_scatter(
            shap_values[:, index],
            feature_data=TrainX.loc[:, feature_name].values,
            feature_name=LABEL_MAP.get(feature_name, feature_name),
            color_feature=color_feature,
            color_feature_is_categorical=clr_f_is_cat,
            show=False,
            alpha=0.5,
            ax=ax
        )
        ax.set_ylabel('')

    plt.tight_layout()

    if save:
        # make the feature name usable as part of a file name
        feature_name = feature_name.replace(' ', '')
        feature_name = feature_name.replace('/', '_')
        plt.savefig(f"results/figures/shap_interac_{feature_name}_{name}.png", dpi=600, bbox_inches="tight")

    plt.show()

    return
def shap_scatter(
        feature_shap_values:np.ndarray,
        feature_data:Union[pd.DataFrame, np.ndarray, pd.Series],
        color_feature:pd.Series=None,
        color_feature_is_categorical:bool = False,
        feature_name:str = '',
        show_hist:bool = True,
        palette_name = "tab10",
        s:int = 70,
        ax:plt.Axes = None,
        edgecolors='black',
        linewidth=0.8,
        alpha=0.8,
        show:bool = True,
        **scatter_kws,
):
    """
    Scatter plot of the SHAP values of a feature against the feature's values,
    optionally colored by a second feature and backed by a histogram.

    :param feature_shap_values:
        SHAP values of the feature, drawn on the y-axis
    :param feature_data:
        values of the feature, drawn on the x-axis
    :param color_feature:
        named Series used to color the points. If not given, no color
        array is passed on and neither legend nor colorbar is drawn.
    :param color_feature_is_categorical:
        whether the color feature is categorical or not. If categorical then
        every unique value of ``color_feature`` is mapped to one color and a
        legend is drawn, otherwise a colorbar is drawn.
    :param feature_name:
        used in the x-axis and y-axis labels
    :param show_hist:
        whether to draw a histogram of ``feature_data`` on a twin axes
    :param palette_name:
        only relevant if ``color_feature_is_categorical`` is True. Either a
        name understood by ``sns.color_palette`` or a list/tuple with one
        color per unique value of ``color_feature``.
    :param s:
        marker size
    :param ax:
        axes to draw on, a new figure is created if not given
    :param edgecolors:
        forwarded to ``scatter``
    :param linewidth:
        forwarded to ``scatter``
    :param alpha:
        forwarded to ``scatter``
    :param show:
        whether to call ``plt.show``
    :param scatter_kws:
        any further keyword arguments for ``scatter``
    :return:
        the axes on which the scatter is drawn
    """
    if ax is None:
        _, ax = plt.subplots()

    point_colors = None
    category_colors = {}
    if color_feature is not None:
        if color_feature_is_categorical:
            categories = color_feature.unique()
            if isinstance(palette_name, (tuple, list)):
                assert len(palette_name) == len(categories)
                palette = palette_name
            else:
                palette = sns.color_palette(palette_name, len(categories))
            # category -> color, also needed for the legend further below
            category_colors = dict(zip(categories, palette))
            point_colors = color_feature.map(category_colors)
        else:
            point_colors = color_feature.values.reshape(-1,)

    _, pc = scatter(
        feature_data,
        feature_shap_values,
        c=point_colors,
        s=s,
        marker="o",
        edgecolors=edgecolors,
        linewidth=linewidth,
        alpha=alpha,
        ax=ax,
        show=False,
        **scatter_kws
    )

    if color_feature is not None:
        # underscores in the name are shown as spaces
        colored_by = color_feature.name.replace('_', ' ')

        if color_feature_is_categorical:
            # one legend entry per category
            legend_handles = []
            for category, category_color in category_colors.items():
                legend_handles.append(
                    Line2D([0], [0],
                           marker='o',
                           color='w',
                           markerfacecolor=category_color,
                           label=category, markersize=8)
                )

            ax.legend(title=colored_by,
                      handles=legend_handles,
                      bbox_to_anchor=(1.05, 1),
                      loc='upper left',
                      title_fontsize=14)
        else:
            # increasing aspect will make the colorbar thin
            cbar = ax.get_figure().colorbar(pc, ax=ax, aspect=20)
            cbar.ax.set_ylabel(colored_by, rotation=90, labelpad=14)

            cbar.set_alpha(1)
            cbar.outline.set_visible(False)

    ax.set_xlabel(feature_name)
    ax.set_ylabel(f"SHAP value for {feature_name}")
    ax.axhline(0, color='grey', linewidth=1.3, alpha=0.3, linestyle='--')

    if show_hist:
        if isinstance(feature_data, (pd.Series, pd.DataFrame)):
            feature_data = feature_data.values
        x = feature_data

        # number of bins grows with the number of points:
        # (minimum number of points, number of bins)
        n_points = len(x)
        n_bins = next(
            bins for minimum, bins in ((500, 50), (200, 20), (100, 10), (0, 5))
            if n_points >= minimum
        )

        ax2 = ax.twinx()

        xlim = ax.get_xlim()

        ax2.hist(x.reshape(-1,), n_bins,
                 range=(xlim[0], xlim[1]),
                 density=False, facecolor='#000000', alpha=0.1, zorder=-1)
        ax2.set_ylim(0, n_points)
        ax2.set_yticks([])

    if show:
        plt.show()

    return ax

Total running time of the script: ( 0 minutes 0.036 seconds)

Gallery generated by Sphinx-Gallery