Note
Go to the end to download the full example code
utils
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import matplotlib.colors as mcolors
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from typing import Union
import seaborn as sns
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
import shap
from ai4water import Model
from ai4water.utils.utils import TrainTestSplit
from ai4water.utils.utils import get_version_info
from easy_mpl import plot, boxplot, pie, bar_chart, scatter
from easy_mpl.utils import create_subplots
# Presentation labels for features whose raw column name should not be shown as is.
LABEL_MAP = {'season': 'Season'}

# Module-level flag; presumably controls whether figures are written to disk
# -- TODO confirm against the scripts that import this module.
SAVE = False

# Feature names grouped by the kind of variable they represent.
CATEGORIES = {
    "Physicochemical": [
        'T (℃)',
        'pH',
        'TDS (mg/L)',
        'DO (mg/L)',
        'TN (mg/L)',
        'season',
    ],
    "ARGs": [
        'MCR-1',  # gene
        'CTX-M',  # gene
        'SHV',  # gene
        'TEM',  # gene
        'OXA48',  # gene
        'KPC',  # gene
        'NDM',  # gene
        'IMP',  # gene
        'QnrS',  # gene
        'Qnr A',
    ],
    "Antibiotics": [
        'F',
        'DO',
        'CN',
        'ATM',
        'FOT',
    ],
}
def read_data():
    """
    Read the four seasonal sheets of ``data/meropenem_isolates.xlsx``
    and stack them into a single DataFrame.

    The workbook is looked up relative to the current working directory.
    Sheets 0, 1, 2 and 3 are read with the header on the third row and are
    tagged with a ``season`` code of 0, 1, 2 and 3 respectively.

    :return:
        a DataFrame with a fresh ``0..n-1`` index, ``season`` as the first
        column and leading/trailing whitespace stripped from the gene and
        antibiotic columns.
    """
    fpath = os.path.join(os.getcwd(), "data", "meropenem_isolates.xlsx")

    # a list of sheet positions makes pandas open the workbook only once and
    # return a dict keyed by those positions
    sheets = pd.read_excel(fpath, sheet_name=[0, 1, 2, 3], header=2)

    frames = []
    for season in range(4):
        frame = sheets[season]
        frame['season'] = season
        frames.append(frame)

    # ignore_index discards the per-sheet row labels
    result = pd.concat(frames, ignore_index=True)

    # put the season column at the start
    result.insert(0, "season", result.pop("season"))

    columns = ["Qnr A", 'Qnr B', 'QnrS', 'IMP', 'NDM', 'BIC', 'KPC',
               'OXA48', 'TEM', 'SHV', 'CTX-M', 'MCR-1', 'AMP', 'CAZ',
               'AMC', 'CTX', 'FEP', 'ATM', 'MEM', 'CIP', 'C', 'NA',
               'CN', 'DO', 'F', 'FOT', 'SXT']

    # these columns hold string labels; remove stray whitespace so that
    # identical labels compare equal
    for col in columns:
        result[col] = result[col].str.strip()

    return result
# function for OHE
def _ohe_encoder(df:pd.DataFrame, col_name:str)->tuple:
    """
    One-hot encode a single column of ``df``.

    :param df:
        the frame holding ``col_name``. Note that the new columns are also
        written into this frame in place.
    :param col_name:
        name of the column to encode
    :return:
        a tuple of (frame without ``col_name`` but with the new columns,
        names of the new columns as ``{col_name}_{i}``, the fitted encoder)
    """
    assert isinstance(col_name, str)

    # ``sparse`` was renamed to ``sparse_output`` in scikit-learn 1.2 and the
    # old keyword was removed in 1.4, so try the new spelling first.
    try:
        encoder = OneHotEncoder(sparse_output=False)
    except TypeError:
        encoder = OneHotEncoder(sparse=False)

    ohe_cat = encoder.fit_transform(df[col_name].values.reshape(-1, 1))
    cols_added = [f"{col_name}_{i}" for i in range(ohe_cat.shape[-1])]

    df[cols_added] = ohe_cat
    df = df.drop(columns=col_name)

    return df, cols_added, encoder
def _label_encoder(df:pd.DataFrame, col_name:str)->tuple:
    """
    Label encode a single column of ``df`` in place.

    :param df:
        the frame holding ``col_name``; the column is overwritten with its
        integer codes
    :param col_name:
        name of the column to encode
    :return:
        a tuple of (the same frame, ``col_name``, the fitted encoder)
    """
    assert isinstance(col_name, str)

    encoder = LabelEncoder()
    # LabelEncoder expects a 1-D array; a column vector only works through an
    # implicit ravel that emits a DataConversionWarning.
    le_cat = encoder.fit_transform(df[col_name].values)
    df[col_name] = le_cat

    return df, col_name, encoder
def make_data(
        inputs,
        target,
        encode=True,
        encode_type_ab:str = "le"
):
    """
    Assemble the modelling frame made of ``inputs`` followed by ``target``.

    :param inputs:
        names of input features
    :param target:
        the target to use
    :param encode:
        whether to encode antibiotic and ARGs or not
    :param encode_type_ab:
        the type of encoding to be applied on Antibiotic data, either
        ``"le"`` (label encoding) or ``"ohe"`` (one-hot encoding). ARGs are
        always label encoded.
    :return:
        a tuple of (DataFrame with the target column holding 0 for ``'-'``
        and 1 for ``'+'``, dictionary of fitted encoders keyed by
        ``"{feature}_enc"``)
    """
    _def_inputs = ['season', 'T (℃)', 'pH', 'TDS (mg/L)', 'DO (mg/L)', 'TN (mg/L)']
    _args = ['Qnr A', 'Qnr B', 'QnrS', 'IMP', 'NDM', 'BIC', 'KPC', 'OXA48',
             'TEM', 'SHV', 'CTX-M', 'MCR-1']

    whole_data = read_data()

    assert isinstance(inputs, list)
    assert isinstance(target, str)
    assert encode_type_ab in ["le", "ohe"]

    data = pd.DataFrame()
    encoders = {}

    for para in inputs + [target]:

        # numeric water-quality columns are taken over untouched
        if para in _def_inputs:
            data[para] = whole_data[para].copy()
            continue

        # the target is appended after the loop
        if para == target:
            continue

        if not encode:
            data[para] = whole_data[para].copy()
            continue

        # only non-ARG (antibiotic) features may be one-hot encoded
        if para not in _args and encode_type_ab == "ohe":
            encode_column = _ohe_encoder
        else:
            encode_column = _label_encoder

        encoded_data, cols, encoders[f"{para}_enc"] = encode_column(whole_data, para)
        data[cols] = encoded_data[cols].copy()

    data[target] = whole_data[target]

    # replacing strings in target with binary
    data[target] = data[target].replace(['-'], 0)
    data[target] = data[target].replace(['+'], 1)

    return data, encoders
def return_train_test(target, scenario=None, return_encoders=False)->tuple:
    """
    Prepare the encoded data for ``target`` and split it into training and
    test sets.

    :param target:
        name of the column to predict. It is never part of the inputs.
    :param scenario:
        which group of features to use as inputs

        - ``None`` or ``'all_inputs'`` : water quality, antibiotics and genes
        - ``'no_genes'`` : water quality and antibiotics
        - ``'no_antibiotics'`` : water quality and genes
        - ``'only_wq'`` : water quality only
        - ``'only_antibiotics'`` : antibiotics only
    :param return_encoders:
        whether to also return the encoders fitted by :func:`make_data`
    :return:
        ``(TrainX, TrainY, TestX, TestY, inputs)`` and, if
        ``return_encoders`` is true, the encoders as sixth element
    :raises ValueError:
        if ``scenario`` is not one of the values listed above
    """
    wq = ['season', 'T (℃)', 'pH', 'TDS (mg/L)',
          'DO (mg/L)', 'TN (mg/L)']
    antibiotics = ['F', 'DO', 'CN', 'ATM', 'FOT']
    genes = ['MCR-1', 'CTX-M', 'SHV', 'TEM', 'OXA48',
             'KPC', 'NDM', 'IMP', 'QnrS', 'Qnr A']

    if scenario == 'no_genes':
        excluded = genes
    elif scenario == 'no_antibiotics':
        excluded = antibiotics
    elif scenario == 'only_wq':
        excluded = antibiotics + genes
    elif scenario == 'only_antibiotics':
        excluded = genes + wq
    elif scenario in [None, 'all_inputs']:
        excluded = []
    else:
        raise ValueError(f"Invalid scenario {scenario}")

    # The target is filtered out in every scenario. Removing it only in some
    # of them left it inside ``inputs`` (or raised from ``list.remove``) when
    # the target belonged to a group that the scenario keeps or drops twice.
    inputs = [
        feature for feature in wq + antibiotics + genes
        if feature != target and feature not in excluded
    ]

    data, encoders = make_data(
        inputs=inputs, target=target, encode=True)

    data = data.dropna()

    # make_data puts the target in the last column
    input_features = data.columns.tolist()[0:-1]
    output_features = data.columns.tolist()[-1:]

    TrainX, TestX, TrainY, TestY = TrainTestSplit(seed=313).split_by_random(
        data[input_features],
        data[output_features]
    )

    if return_encoders:
        return TrainX, TrainY, TestX, TestY, inputs, encoders
    return TrainX, TrainY, TestX, TestY, inputs
def all_inputs_CTX_M(model_name):
    """
    Build and fit the CTX-M classifier that uses all input features.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the ``Model`` after fitting on the training split
    :raises ValueError:
        for any other ``model_name``
    """
    target = 'CTX-M'
    TrainX, TrainY, _, _, inputs = return_train_test(target)

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': {'n_estimators': 139,
                                        'boosting_type': 'gbdt',
                                        'num_leaves': 494,
                                        'learning_rate': 0.09328089700948024}}
        methods = {
            'T (℃)': 'box-cox',
            'pH': 'vast',
            'TDS (mg/L)': 'box-cox',
            'DO (mg/L)': 'quantile_normal',
            'TN (mg/L)': 'quantile',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = {"HistGradientBoostingClassifier": {
            "learning_rate": 0.2714740258492424,
            "max_iter": 428,
            "max_depth": 85,
            "max_leaf_nodes": 26,
            "min_samples_leaf": 17,
            "l2_regularization": 0.4170610185230958,
            "random_state": 313
        }}
        methods = {
            'T (℃)': 'zscore',
            'pH': 'scale',
            'TDS (mg/L)': 'log2',
            'DO (mg/L)': 'yeo-johnson',
            'TN (mg/L)': 'log2',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def all_inputs_OXA48(model_name):
    """
    Build and fit the OXA48 classifier that uses all input features.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the ``Model`` after fitting on the training split
    :raises ValueError:
        for any other ``model_name``
    """
    target = 'OXA48'
    TrainX, TrainY, _, _, inputs = return_train_test(target)

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        # 'DO (mg/L)' is deliberately left without a transformation
        methods = {
            'T (℃)': 'vast',
            'pH': 'vast',
            'TDS (mg/L)': 'log',
            'TN (mg/L)': 'log2',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T (℃)': 'log',
            'pH': 'box-cox',
            'TDS (mg/L)': 'quantile_normal',
            'DO (mg/L)': 'minmax',
            'TN (mg/L)': 'sqrt',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def all_inputs_TEM(model_name):
    """
    Build and fit the TEM classifier that uses all input features.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the ``Model`` after fitting on the training split
    :raises ValueError:
        for any other ``model_name``
    """
    target = 'TEM'
    TrainX, TrainY, _, _, inputs = return_train_test(target)

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': {'n_estimators': 144,
                                        'boosting_type': 'gbdt',
                                        'num_leaves': 423,
                                        'learning_rate': 0.09147939344768526}}
        methods = {
            'T (℃)': 'minmax',
            'pH': 'yeo-johnson',
            'TDS (mg/L)': 'log10',
            'DO (mg/L)': 'box-cox',
            'TN (mg/L)': 'quantile',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        # 'T (℃)' is deliberately left without a transformation
        methods = {
            'pH': 'quantile_normal',
            'TDS (mg/L)': 'center',
            'DO (mg/L)': 'center',
            'TN (mg/L)': 'robust',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def all_inputs_MCR_1(model_name):
    """
    Build and fit the MCR-1 classifier that uses all input features.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the ``Model`` after fitting on the training split
    :raises ValueError:
        for any other ``model_name``
    """
    target = 'MCR-1'
    TrainX, TrainY, _, _, inputs = return_train_test(target)

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': {'boosting_type': 'gbdt',
                                        'num_leaves': 185,
                                        'learning_rate': 0.07853514838993073,
                                        'n_estimators': 85,
                                        }}
        methods = {
            'T (℃)': 'vast',
            'pH': 'scale',
            'TDS (mg/L)': 'quantile',
            'DO (mg/L)': 'zscore',
            'TN (mg/L)': 'log2',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T (℃)': 'log10',
            'pH': 'center',
            'TDS (mg/L)': 'robust',
            'DO (mg/L)': 'log10',
            'TN (mg/L)': 'quantile_normal',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def get_all_inputs_model(
        target:str,
        model_name:str = "LGBMClassifier"
):
    """
    Return the fitted all-inputs model for ``target``.

    :param target:
        one of ``'MCR-1'``, ``'TEM'``, ``'CTX-M'`` or ``'OXA48'``
    :param model_name:
        name of the classifier, forwarded to the target specific builder
    :return:
        whatever the ``all_inputs_*`` function of that target returns
    :raises ValueError:
        if ``target`` is not one of the supported targets
    """
    supported = ('MCR-1', 'TEM', 'CTX-M', 'OXA48')
    if target not in supported:
        # name the offending value instead of raising an empty ValueError
        raise ValueError(
            f"Invalid target {target!r}, expected one of {supported}")

    builders = {
        'MCR-1': all_inputs_MCR_1,
        'TEM': all_inputs_TEM,
        'CTX-M': all_inputs_CTX_M,
        'OXA48': all_inputs_OXA48,
    }
    return builders[target](model_name)
def no_gene_CTX_M(model_name, from_config=False):
    """
    Return the CTX-M classifier for the ``'no_genes'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``; only
        looked at when ``from_config`` is false
    :param from_config:
        if true, the model is created from ``results/20230627_130840/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'CTX-M'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'no_genes')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230627_130840', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': {'boosting_type': 'goss',
                                        'num_leaves': 377,
                                        'learning_rate': 0.01511065501651556,
                                        'n_estimators': 91,
                                        }}
        # 'DO (mg/L)' is deliberately left without a transformation
        methods = {
            'T (℃)': 'center',
            'pH': 'quantile',
            'TDS (mg/L)': 'yeo-johnson',
            'TN (mg/L)': 'log',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T (℃)': 'zscore',
            'pH': 'box-cox',
            'TDS (mg/L)': 'quantile_normal',
            'DO (mg/L)': 'vast',
            'TN (mg/L)': 'quantile_normal',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def no_gene_OXA48(model_name, from_config=False):
    """
    Return the OXA48 classifier for the ``'no_genes'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``; only
        looked at when ``from_config`` is false
    :param from_config:
        if true, the model is created from ``results/20230627_130714/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'OXA48'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'no_genes')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230627_130714', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        methods = {
            'T (℃)': 'scale',
            'pH': 'robust',
            'TDS (mg/L)': 'zscore',
            'DO (mg/L)': 'minmax',
            'TN (mg/L)': 'quantile_normal',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T (℃)': 'box-cox',
            'pH': 'quantile_normal',
            'TDS (mg/L)': 'minmax',
            'DO (mg/L)': 'scale',
            'TN (mg/L)': 'scale',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def no_gene_TEM(model_name, from_config=False):
    """
    Return the TEM classifier for the ``'no_genes'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``; only
        looked at when ``from_config`` is false
    :param from_config:
        if true, the model is created from ``results/20230627_130553/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'TEM'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'no_genes')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230627_130553', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        # 'TDS (mg/L)' is deliberately left without a transformation
        methods = {
            'T (℃)': 'robust',
            'pH': 'center',
            'DO (mg/L)': 'log2',
            'TN (mg/L)': 'yeo-johnson',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        # 'T (℃)' is deliberately left without a transformation
        methods = {
            'pH': 'center',
            'TDS (mg/L)': 'minmax',
            'DO (mg/L)': 'minmax',
            'TN (mg/L)': 'center',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def no_gene_MCR_1(model_name, from_config=False):
    """
    Return the MCR-1 classifier for the ``'no_genes'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``; only
        looked at when ``from_config`` is false
    :param from_config:
        if true, the model is created from ``results/20230627_130424/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'MCR-1'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'no_genes')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230627_130424', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': {'boosting_type': 'goss',
                                        'num_leaves': 245,
                                        'learning_rate': 0.08783749114657222,
                                        'n_estimators': 139,
                                        }}
        methods = {
            'T (℃)': 'box-cox',
            'pH': 'vast',
            'TDS (mg/L)': 'box-cox',
            'DO (mg/L)': 'quantile_normal',
            'TN (mg/L)': 'quantile',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T (℃)': 'pareto',
            'pH': 'vast',
            'TDS (mg/L)': 'quantile',
            'DO (mg/L)': 'minmax',
            'TN (mg/L)': 'minmax',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def get_no_gene_model(
        target:str,
        from_config,
        model_name:str = "LGBMClassifier"
):
    """
    Return the no-genes model for ``target``.

    :param target:
        one of ``'MCR-1'``, ``'TEM'``, ``'CTX-M'`` or ``'OXA48'``
    :param from_config:
        forwarded to the target specific ``no_gene_*`` function
    :param model_name:
        name of the classifier, forwarded to the target specific builder
    :return:
        whatever the ``no_gene_*`` function of that target returns
    :raises ValueError:
        if ``target`` is not one of the supported targets
    """
    supported = ('MCR-1', 'TEM', 'CTX-M', 'OXA48')
    if target not in supported:
        # name the offending value instead of raising an empty ValueError
        raise ValueError(
            f"Invalid target {target!r}, expected one of {supported}")

    builders = {
        'MCR-1': no_gene_MCR_1,
        'TEM': no_gene_TEM,
        'CTX-M': no_gene_CTX_M,
        'OXA48': no_gene_OXA48,
    }
    return builders[target](model_name, from_config)
def no_antibiotic_CTX_M(model_name):
    """
    Build and fit the CTX-M classifier for the ``'no_antibiotics'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the ``Model`` after fitting on the training split
    :raises ValueError:
        for any other ``model_name``
    """
    target = 'CTX-M'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'no_antibiotics')

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = {'LGBMClassifier': {'boosting_type': 'gbdt',
                                        'num_leaves': 283,
                                        'learning_rate': 0.08556696555457217,
                                        'n_estimators': 141,
                                        }}
        methods = {
            'T (℃)': 'box-cox',
            'pH': 'scale',
            'TDS (mg/L)': 'log10',
            'DO (mg/L)': 'vast',
            'TN (mg/L)': 'log2',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T (℃)': 'log2',
            'pH': 'log',
            'TDS (mg/L)': 'yeo-johnson',
            'DO (mg/L)': 'robust',
            'TN (mg/L)': 'pareto',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def no_antibiotic_OXA48(model_name):
    """
    Build and fit the OXA48 classifier for the ``'no_antibiotics'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the ``Model`` after fitting on the training split
    :raises ValueError:
        for any other ``model_name``
    """
    target = 'OXA48'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'no_antibiotics')

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        # 'DO (mg/L)' is deliberately left without a transformation
        methods = {
            'T (℃)': 'scale',
            'pH': 'pareto',
            'TDS (mg/L)': 'pareto',
            'TN (mg/L)': 'log2',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T (℃)': 'log',
            'pH': 'sqrt',
            'TDS (mg/L)': 'quantile',
            'DO (mg/L)': 'quantile_normal',
            'TN (mg/L)': 'log10',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def no_antibiotic_TEM(model_name):
    """
    Build and fit the TEM classifier for the ``'no_antibiotics'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the ``Model`` after fitting on the training split
    :raises ValueError:
        for any other ``model_name``
    """
    target = 'TEM'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'no_antibiotics')

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        methods = {
            'T (℃)': 'yeo-johnson',
            'pH': 'log10',
            'TDS (mg/L)': 'zscore',
            'DO (mg/L)': 'log',
            'TN (mg/L)': 'sqrt',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        # 'TN (mg/L)' is deliberately left without a transformation
        methods = {
            'T (℃)': 'minmax',
            'pH': 'center',
            'TDS (mg/L)': 'quantile',
            'DO (mg/L)': 'minmax',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def no_antibiotic_MCR_1(model_name):
    """
    Build and fit the MCR-1 classifier for the ``'no_antibiotics'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``
    :return:
        the ``Model`` after fitting on the training split
    :raises ValueError:
        for any other ``model_name``
    """
    target = 'MCR-1'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'no_antibiotics')

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        methods = {
            'T (℃)': 'box-cox',
            'pH': 'log2',
            'TDS (mg/L)': 'scale',
            'DO (mg/L)': 'log',
            'TN (mg/L)': 'log10',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        # 'pH' is deliberately left without a transformation
        methods = {
            'T (℃)': 'yeo-johnson',
            'TDS (mg/L)': 'quantile_normal',
            'DO (mg/L)': 'log',
            'TN (mg/L)': 'pareto',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def only_wq_CTX_M(model_name, from_config=False):
    """
    Return the CTX-M classifier for the ``'only_wq'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``; only
        looked at when ``from_config`` is false
    :param from_config:
        if true, the model is created from ``results/20230627_130241/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'CTX-M'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'only_wq')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230627_130241', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        methods = {
            'T (℃)': 'sqrt',
            'pH': 'quantile',
            'TDS (mg/L)': 'log2',
            'DO (mg/L)': 'log2',
            'TN (mg/L)': 'quantile',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T (℃)': 'zscore',
            'pH': 'zscore',
            'TDS (mg/L)': 'log10',
            'DO (mg/L)': 'log10',
            'TN (mg/L)': 'zscore',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def only_wq_OXA48(model_name, from_config=False):
    """
    Return the OXA48 classifier for the ``'only_wq'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``; only
        looked at when ``from_config`` is false
    :param from_config:
        if true, the model is created from ``results/20230627_130130/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'OXA48'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'only_wq')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230627_130130', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        methods = {
            'T (℃)': 'box-cox',
            'pH': 'quantile',
            'TDS (mg/L)': 'vast',
            'DO (mg/L)': 'quantile',
            'TN (mg/L)': 'center',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T (℃)': 'log10',
            'pH': 'log2',
            'TDS (mg/L)': 'sqrt',
            'DO (mg/L)': 'scale',
            'TN (mg/L)': 'scale',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def only_wq_TEM(model_name, from_config=False):
    """
    Return the TEM classifier for the ``'only_wq'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``; only
        looked at when ``from_config`` is false
    :param from_config:
        if true, the model is created from ``results/20230627_125951/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'TEM'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'only_wq')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230627_125951', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        methods = {
            'T (℃)': 'vast',
            'pH': 'scale',
            'TDS (mg/L)': 'center',
            'DO (mg/L)': 'sqrt',
            'TN (mg/L)': 'pareto',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T (℃)': 'robust',
            'pH': 'log2',
            'TDS (mg/L)': 'quantile',
            'DO (mg/L)': 'quantile',
            'TN (mg/L)': 'minmax',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def only_wq_MCR_1(model_name, from_config=False):
    """
    Return the MCR-1 classifier for the ``'only_wq'`` scenario.

    :param model_name:
        ``'LGBMClassifier'`` or ``'HistGradientBoostingClassifier'``; only
        looked at when ``from_config`` is false
    :param from_config:
        if true, the model is created from ``results/20230627_125012/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'MCR-1'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'only_wq')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230627_125012', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    # feature -> transformation applied to it, in the order they are passed on
    if model_name == 'LGBMClassifier':
        estimator = 'LGBMClassifier'
        methods = {
            'T (℃)': 'yeo-johnson',
            'pH': 'pareto',
            'TDS (mg/L)': 'log10',
            'DO (mg/L)': 'box-cox',
            'TN (mg/L)': 'sqrt',
        }
    elif model_name == 'HistGradientBoostingClassifier':
        estimator = "HistGradientBoostingClassifier"
        methods = {
            'T (℃)': 'yeo-johnson',
            'pH': 'robust',
            'TDS (mg/L)': 'log',
            'DO (mg/L)': 'quantile_normal',
            'TN (mg/L)': 'log2',
        }
    else:
        raise ValueError(f'Invalid {model_name}')

    model = Model(
        model=estimator,
        input_features=inputs,
        output_features=target,
        verbosity=0,
        x_transformation=[
            {'method': method, 'features': [feature]}
            for feature, method in methods.items()
        ],
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def only_antibiotics_MCR_1(model_name, from_config=False):
    """
    Return the MCR-1 classifier for the ``'only_antibiotics'`` scenario.

    :param model_name:
        must be ``'LGBMClassifier'``; only looked at when ``from_config``
        is false
    :param from_config:
        if true, the model is created from ``results/20230705_164018/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'MCR-1'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'only_antibiotics')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230705_164018', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    if model_name != 'LGBMClassifier':
        raise ValueError(f'Invalid {model_name}')

    hyperparameters = {'n_estimators': 150,
                       'boosting_type': 'goss',
                       'num_leaves': 500,
                       'learning_rate': 0.09734360597562812}
    model = Model(
        model={'LGBMClassifier': hyperparameters},
        input_features=inputs,
        output_features=target,
        verbosity=0,
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def only_antibiotics_CTX_M(model_name, from_config=False):
    """
    Return the CTX-M classifier for the ``'only_antibiotics'`` scenario.

    :param model_name:
        must be ``'LGBMClassifier'``; only looked at when ``from_config``
        is false
    :param from_config:
        if true, the model is created from ``results/20230705_154741/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'CTX-M'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'only_antibiotics')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230705_154741', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    if model_name != 'LGBMClassifier':
        raise ValueError(f'Invalid {model_name}')

    hyperparameters = {'n_estimators': 44,
                       'boosting_type': 'goss',
                       'num_leaves': 458,
                       'learning_rate': 0.045723247279341225}
    model = Model(
        model={'LGBMClassifier': hyperparameters},
        input_features=inputs,
        output_features=target,
        verbosity=0,
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def only_antibiotics_TEM(model_name, from_config=False):
    """
    Return the TEM classifier for the ``'only_antibiotics'`` scenario.

    :param model_name:
        must be ``'LGBMClassifier'``; only looked at when ``from_config``
        is false
    :param from_config:
        if true, the model is created from ``results/20230705_154645/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'TEM'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'only_antibiotics')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230705_154645', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    if model_name != 'LGBMClassifier':
        raise ValueError(f'Invalid {model_name}')

    hyperparameters = {'n_estimators': 149,
                       'boosting_type': 'goss',
                       'num_leaves': 483,
                       'learning_rate': 0.053619829014340085}
    model = Model(
        model={'LGBMClassifier': hyperparameters},
        input_features=inputs,
        output_features=target,
        verbosity=0,
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def only_antibiotics_OXA48(model_name, from_config=False):
    """
    Return the OXA48 classifier for the ``'only_antibiotics'`` scenario.

    :param model_name:
        must be ``'LGBMClassifier'``; only looked at when ``from_config``
        is false
    :param from_config:
        if true, the model is created from ``results/20230705_164158/config.json``
        located next to this file and ``update_weights`` is called on it
    :return:
        the ``Model``
    :raises ValueError:
        for an unsupported ``model_name`` when ``from_config`` is false
    """
    target = 'OXA48'
    TrainX, TrainY, _, _, inputs = return_train_test(target, 'only_antibiotics')

    if from_config:
        cpath = os.path.join(os.path.dirname(__file__), 'results',
                             '20230705_164158', 'config.json')
        model = Model.from_config_file(config_path=cpath)
        model.update_weights()
        # NOTE(review): indentation was lost in SOURCE; seeding and fitting
        # are assumed to belong to the non-config branch only -- confirm.
        return model

    if model_name != 'LGBMClassifier':
        raise ValueError(f'Invalid {model_name}')

    hyperparameters = {'n_estimators': 59,
                       'boosting_type': 'goss',
                       'num_leaves': 494,
                       'learning_rate': 0.07385784480075765}
    model = Model(
        model={'LGBMClassifier': hyperparameters},
        input_features=inputs,
        output_features=target,
        verbosity=0,
    )

    model.seed_everything(313)
    model.fit(x=TrainX, y=TrainY.values)

    return model
def plot_line(target, value, save=False):
    """
    Draw one line plot per physicochemical parameter, with one line per
    season, for the rows where ``target`` equals ``value``.

    :param target:
        name of the column used to select rows
    :param value:
        rows with ``data[target] == value`` are plotted
    :param save:
        whether to save the figure as
        ``results/figures/line_{target}_{value}``
    :return:
        None
    """
    num_columns_without_seaon = ['T (℃)', 'pH', 'TDS (mg/L)',
                                 'DO (mg/L)', 'TN (mg/L)']
    SEASONS = {
        0: "Fall",
        1: "Winter",
        2: "Spring",
        3: "Summer"
    }
    colors = ['#588F07', '#E79D33', '#34ACAF', '#DD6F6D']

    data = read_data()
    amp_r = data.loc[data[target] == value]
    groups = amp_r.groupby('season')

    fig, axes = create_subplots(5, sharex="all")
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])

    for col, ax in zip(num_columns_without_seaon, axes.flatten()):

        # only the axes of TN carries the legend, so only its lines are labelled
        with_legend = col == 'TN (mg/L)'

        for (label, grp), color in zip(groups, colors):
            legend = SEASONS[label] if with_legend else '__nolabel__'
            plot(grp[col].values, label=legend, color=color, ax=ax, show=False)

        ax.set_ylabel(col)

        if with_legend:
            ax.legend(loc=(1.2, 0.2))

    plt.suptitle(f'{target}_{value}')
    plt.tight_layout()

    if save:
        # the file name previously ended in the literal text "_value", so the
        # figures of different values overwrote each other
        plt.savefig(f"results/figures/line_{target}_{value}", bbox_inches="tight", dpi=600)

    plt.show()
    return
def plot_boxplot(target, value, save=False):
    """
    Draw one box plot per physicochemical parameter, with one box per
    season, for the rows where ``target`` equals ``value``.

    :param target:
        name of the column used to select rows
    :param value:
        rows with ``data[target] == value`` are plotted
    :param save:
        whether to save the figure as
        ``results/figures/box_{target}_{value}``
    :return:
        None
    """
    num_columns_without_seaon = ['T (℃)', 'pH', 'TDS (mg/L)',
                                 'DO (mg/L)', 'TN (mg/L)']
    SEASONS = {
        0: "Fall",
        1: "Winter",
        2: "Spring",
        3: "Summer"
    }
    palette = ['#588F07', '#E79D33', '#34ACAF', '#DD6F6D']

    data = read_data()
    amp_r = data.loc[data[target] == value]
    groups = amp_r.groupby('season')

    f, axes = create_subplots(5)

    # one colour per season present in the selection; a single season gets
    # a plain colour string instead of a list
    n_groups = groups.size().size
    if n_groups == 1:
        colors = palette[0]
    elif n_groups in (2, 3):
        colors = palette[:n_groups]
    else:
        colors = palette

    for col, ax in zip(num_columns_without_seaon, axes.flatten()):
        data_ = [grp[col].values for _, grp in groups]
        labels = [SEASONS[label] for label, _ in groups]

        boxplot(data_, labels=labels, ax=ax, show=False,
                line_color=colors,
                flierprops=dict(ms=2.0))
        ax.set_ylabel(col)

    plt.suptitle(f'{target}_{value}')
    plt.tight_layout()

    if save:
        # the file name previously ended in the literal text "_value", so the
        # figures of different values overwrote each other
        plt.savefig(f"results/figures/box_{target}_{value}", bbox_inches="tight", dpi=600)

    plt.show()
    return
def hist():
    """
    Show a histogram of every physicochemical parameter and of the season
    code, one subplot each.

    :return:
        None
    """
    data = read_data()

    num_columns = ['T (℃)', 'pH', 'TDS (mg/L)',
                   'DO (mg/L)', 'TN (mg/L)', 'season']

    fig, axes = create_subplots(len(num_columns))
    # a lone Axes is wrapped so that ``.flat`` is always available
    axes = axes if isinstance(axes, np.ndarray) else np.array([axes])

    for ax, col in zip(axes.flat, num_columns):
        sns.histplot(data[col], ax=ax)

    plt.tight_layout()
    plt.show()
    return
def lineplot(palette=None, save=False, name=''):
    """
    Show a line plot of every physicochemical parameter and of the season
    code, one subplot each.

    :param palette:
        forwarded to ``sns.lineplot``
    :param save:
        whether to save the figure as ``results/figures/line_{name}``
    :param name:
        suffix of the file name used when ``save`` is true
    :return:
        None
    """
    num_columns = ['T (℃)', 'pH', 'TDS (mg/L)',
                   'DO (mg/L)', 'TN (mg/L)', 'season']

    data = read_data()

    fig, axes = create_subplots(len(num_columns))

    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])

    # the unused local colour list of the earlier version has been dropped
    for ax, col in zip(axes.flat, num_columns):
        sns.lineplot(data[col], ax=ax, palette=palette)

    plt.tight_layout()

    if save:
        plt.savefig(f"results/figures/line_{name}", bbox_inches="tight", dpi=600)

    plt.show()
    return
def season_line(season_data, title):
    """
    Draw a line plot of each numeric water-quality column of
    ``season_data`` and put ``title`` above the figure.

    Parameters
    ----------
    season_data :
        object indexable by the column names listed below; each column is
        handed to ``seaborn.lineplot``.
    title :
        text passed to ``plt.suptitle``.
    """
    columns = ['T (℃)', 'pH', 'TDS (mg/L)',
               'DO (mg/L)', 'TN (mg/L)']

    _, axes = create_subplots(len(columns))
    # a lone Axes is wrapped so that ``.flat`` below always works
    axes = axes if isinstance(axes, np.ndarray) else np.array([axes])

    for axis, column in zip(axes.flat, columns):
        sns.lineplot(season_data[column], ax=axis, palette='Spectral')

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()
    return
def set_rcParams(**kwargs):
    """
    Apply the default matplotlib rcParams used for the figures.

    Any keyword argument overrides (or extends) the defaults; every
    resulting key/value pair is written into ``plt.rcParams``.
    """
    # https://matplotlib.org/stable/tutorials/introductory/customizing.html
    settings = dict([
        ('axes.labelsize', '14'),
        ('xtick.labelsize', '12'),
        ('ytick.labelsize', '12'),
        ('legend.title_fontsize', '12'),
        ('axes.titleweight', 'bold'),
        ('axes.titlesize', '14'),
        ('axes.labelweight', 'bold'),
        ('font.family', 'Times New Roman'),
    ])

    # user supplied values take precedence over the defaults
    settings.update(kwargs)

    for key in settings:
        plt.rcParams[key] = settings[key]

    return
def version_info():
    """
    Print the version information returned by ``get_version_info``,
    extended with the version of ``shap``, one ``name version`` pair
    per line.
    """
    versions = get_version_info()
    versions['shap'] = shap.__version__

    for entry in versions.items():
        print(*entry)

    return
def bar_pie(
        data: pd.DataFrame,
        ax: plt.Axes = None,
        save: bool = True,
        name: str = '',
        show: bool = True,
):
    """
    Draw a horizontal bar chart of mean absolute SHAP values and, as an
    inset, a two-ring pie chart with the share of every feature class.

    Parameters
    ----------
    data : pd.DataFrame
        must contain the columns ``mean_shap``, ``colors``, ``features``
        and ``classes``.
        NOTE(review): the legend and the pie pair the i-th unique color with
        the i-th unique class, i.e. every class is assumed to use exactly
        one color - confirm against the caller.
    ax : plt.Axes
        axes to draw the bar chart on. A new figure of size (7, 9) is
        created when not given.
    save : bool
        if True, the figure is written to
        ``results/figures/shap_bar_{name}.png``.
    name : str
        suffix used in the file name when ``save`` is True.
    show : bool
        if True, ``plt.tight_layout`` and ``plt.show`` are called.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(7, 9))

    sv_bar = data['mean_shap'].copy()
    colors = data['colors'].unique()
    class_names = data['classes'].unique().tolist()
    tick_labels = [LABEL_MAP.get(n, n) for n in data['features'].tolist()]

    ax_ = bar_chart(
        sv_bar,
        tick_labels,
        bar_labels=np.round(sv_bar, 4),
        bar_label_kws={'label_type': 'edge',
                       'fontsize': 10,
                       'weight': 'bold',
                       "fmt": '%.4f',
                       'padding': 1.5
                       },
        show=False,
        sort=True,
        color=data['colors'].to_list(),
        ax=ax
    )

    ax_.spines[['top', 'right']].set_visible(False)
    ax_.set_xlabel(xlabel='mean(|SHAP value|)')
    ax_.set_yticklabels(ax_.get_yticklabels())

    # one legend patch per feature class
    handles = [plt.Rectangle((0, 0), 1, 1, color=color)
               for color, _ in zip(colors, class_names)]
    ax_.legend(handles, class_names, loc='lower right')

    ax_.xaxis.set_major_locator(plt.MaxNLocator(4))

    seg_colors = tuple(colors)
    # Change the saturation of seg_colors to 70% for the interior segments
    rgb = mcolors.to_rgba_array(seg_colors)[:, :-1]
    hsv = mcolors.rgb_to_hsv(rgb)
    hsv[:, 1] = 0.7 * hsv[:, 1]
    interior_colors = mcolors.hsv_to_rgb(hsv)

    # The fractions are computed in the same order as ``class_names`` so
    # that every wedge gets the label and color of its own class, whichever
    # classes are present in ``data`` and in whatever order they appear.
    fractions = np.array(
        [data.loc[data['classes'] == cls, 'mean_shap'].sum()
         for cls in class_names],
        dtype=float,
    )
    fractions /= fractions.sum()

    ax2 = inset_axes(ax, width='50%', height='50%',
                     loc=5)

    # thin outer ring which carries the class labels
    _, texts = pie(fractions=fractions,
                   colors=seg_colors,
                   labels=class_names,
                   wedgeprops=dict(edgecolor="w", width=0.03), radius=1,
                   autopct=None,
                   textprops=dict(fontsize=12),
                   startangle=90, counterclock=False, show=False,
                   ax=ax2)
    texts[0].set_fontsize(12)

    # inner, less saturated disc which carries the percentages
    _, texts, autotexts = pie(fractions=fractions,
                              colors=interior_colors,
                              autopct='%1.0f%%',
                              textprops=dict(fontsize=24),
                              wedgeprops=dict(edgecolor="w"), radius=1 - 2 * 0.03,
                              startangle=90, counterclock=False, ax=ax2, show=False)
    texts[0].set_fontsize(12)

    if save:
        plt.savefig(f"results/figures/shap_bar_{name}.png", dpi=600, bbox_inches="tight")

    if show:
        plt.tight_layout()
        plt.show()

    return
def shap_scatter_plots(
        shap_values: np.ndarray,
        TrainX: pd.DataFrame,
        feature_name: str,
        encoders,
        save: bool = True,
        name: str = '',
):
    """
    Draw one SHAP scatter plot of ``feature_name`` per column of ``TrainX``,
    each of them colored by a different column.

    It is expected that the columns in TrainX and shap_values have same order.

    Parameters
    ----------
    shap_values : np.ndarray
        two dimensional array; the column at the position of
        ``feature_name`` in ``TrainX.columns`` is plotted on the y-axis.
    TrainX : pd.DataFrame
        input data; ``TrainX[feature_name]`` is plotted on the x-axis and
        every column is used once for coloring.
    feature_name : str
        name of the feature whose SHAP values are shown. Must be a column
        of ``TrainX``, otherwise a ValueError is raised.
    encoders :
        container supporting ``in`` and item access with keys of the form
        ``"{feature}_enc"``; the values must provide ``inverse_transform``.
        Categorical features without an entry are used as they are.
    save : bool
        if True, the figure is written to
        ``results/figures/shap_interac_{feature_name}_{name}.png`` with
        spaces removed from, and ``/`` replaced by ``_`` in, the feature name.
    name : str
        suffix used in the file name when ``save`` is True.
    """
    categorical_features = {
        'season', 'F', 'DO', 'CN', 'ATM', 'FOT',
        'MCR-1', 'CTX-M', 'OXA48', 'SHV', 'TEM',
        'KPC', 'NDM', 'IMP', 'QnrS', 'Qnr A',
    }

    f, axes = create_subplots(TrainX.shape[1],
                              figsize=(12, 9))
    # a lone Axes is wrapped so that ``.flat`` below always works
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])

    # everything which describes the plotted feature is loop invariant
    index = TrainX.columns.to_list().index(feature_name)
    feature_shap_values = shap_values[:, index]
    feature_data = TrainX.loc[:, feature_name].values
    feature_label = LABEL_MAP.get(feature_name, feature_name)

    for feature, ax in zip(TrainX.columns, axes.flat):

        is_categorical = feature in categorical_features
        color_name = LABEL_MAP.get(feature, feature)
        column = TrainX.loc[:, feature]

        if is_categorical:
            encoder_key = f'{feature}_enc'
            if feature != 'season' and encoder_key in encoders:
                decoded = pd.Series(
                    encoders[encoder_key].inverse_transform(column.values),
                    name=feature)
            else:
                # 'season', or a categorical feature without a registered
                # encoder: use the stored values directly instead of
                # relying on a variable left over from a previous iteration
                decoded = column.values

            # instead of showing the actual names, we still prefer to
            # label encode them because actual names takes very large
            # space in figure/axes
            color_feature = pd.Series(
                LabelEncoder().fit_transform(decoded),
                name=color_name)
        else:
            # ``rename`` returns a new Series, so the name of the column
            # held by TrainX is never touched
            color_feature = column.rename(color_name)

        ax = shap_scatter(
            feature_shap_values,
            feature_data=feature_data,
            feature_name=feature_label,
            color_feature=color_feature,
            color_feature_is_categorical=is_categorical,
            show=False,
            alpha=0.5,
            ax=ax
        )
        ax.set_ylabel('')

    plt.tight_layout()

    if save:
        # make the feature name usable as part of a file name
        feature_name = feature_name.replace(' ', '')
        feature_name = feature_name.replace('/', '_')
        plt.savefig(f"results/figures/shap_interac_{feature_name}_{name}.png", dpi=600, bbox_inches="tight")

    plt.show()
    return
def shap_scatter(
        feature_shap_values: np.ndarray,
        feature_data: Union[pd.DataFrame, np.ndarray, pd.Series],
        color_feature: pd.Series = None,
        color_feature_is_categorical: bool = False,
        feature_name: str = '',
        show_hist: bool = True,
        palette_name="tab10",
        s: int = 70,
        ax: plt.Axes = None,
        edgecolors='black',
        linewidth=0.8,
        alpha=0.8,
        show: bool = True,
        **scatter_kws,
):
    """
    Scatter plot of the SHAP values of a feature against the values of that
    feature, optionally colored by a second feature.

    :param feature_shap_values:
        SHAP values of the feature, plotted on the y-axis.
    :param feature_data:
        values of the feature, plotted on the x-axis.
    :param color_feature:
        values used to color the markers. Its ``name`` (with underscores
        replaced by spaces) is used as title of the legend/colorbar.
    :param color_feature_is_categorical:
        whether the color feature is categorical or not. If categorical then the
        array ``color_feature`` is supposed to contain categorical (either string or numerical) values which
        are then mapped to the color and are used to prepare the legend box.
        Otherwise a colorbar is drawn.
    :param feature_name:
        used for the x-axis label and in the y-axis label.
    :param show_hist:
        whether to draw a light histogram of ``feature_data`` on a twin axes.
    :param palette_name:
        only relevant if ``color_feature_is_categorical`` is True. Either
        a name understood by ``seaborn.color_palette`` or a list/tuple
        with exactly one color per unique value of ``color_feature``.
    :param s:
        marker size, forwarded to ``scatter``.
    :param ax:
        axes to draw on. A new figure is created when not given.
    :param edgecolors:
        forwarded to ``scatter``.
    :param linewidth:
        forwarded to ``scatter``.
    :param alpha:
        forwarded to ``scatter``.
    :param show:
        whether to call ``plt.show`` at the end.
    :param scatter_kws:
        any additional keyword arguments for ``scatter``.
    :return:
        the axes on which the scatter plot is drawn.
    """
    if ax is None:
        _, ax = plt.subplots()

    color_map = None
    if color_feature is None:
        c = None
    elif color_feature_is_categorical:
        categories = color_feature.unique()
        if isinstance(palette_name, (tuple, list)):
            assert len(palette_name) == len(categories), (
                f"{len(palette_name)} colors given for "
                f"{len(categories)} categories")
            rgb_values = palette_name
        else:
            rgb_values = sns.color_palette(palette_name, len(categories))

        color_map = dict(zip(categories, rgb_values))
        c = color_feature.map(color_map)
    else:
        c = color_feature.values.reshape(-1,)

    _, pc = scatter(
        feature_data,
        feature_shap_values,
        c=c,
        s=s,
        marker="o",
        edgecolors=edgecolors,
        linewidth=linewidth,
        alpha=alpha,
        ax=ax,
        show=False,
        **scatter_kws
    )

    if color_feature is not None:
        # an unnamed Series gets an empty title and a non-string name is
        # converted, instead of failing on ``.split``
        if color_feature.name is None:
            feature_wrt_name = ''
        else:
            feature_wrt_name = ' '.join(str(color_feature.name).split('_'))

        if color_feature_is_categorical:
            # add a legend
            handles = [Line2D([0], [0],
                              marker='o',
                              color='w',
                              markerfacecolor=v,
                              label=k, markersize=8) for k, v in color_map.items()]
            ax.legend(title=feature_wrt_name,
                      handles=handles, bbox_to_anchor=(1.05, 1),
                      loc='upper left',
                      title_fontsize=14
                      )
        else:
            fig = ax.get_figure()
            # increasing aspect will make the colorbar thin
            cbar = fig.colorbar(pc, ax=ax, aspect=20)
            cbar.ax.set_ylabel(feature_wrt_name,
                               rotation=90, labelpad=14)
            cbar.set_alpha(1)
            cbar.outline.set_visible(False)

    ax.set_xlabel(feature_name)
    ax.set_ylabel(f"SHAP value for {feature_name}")
    ax.axhline(0, color='grey', linewidth=1.3, alpha=0.3, linestyle='--')

    if show_hist:
        # works for Series, DataFrame, ndarray and plain sequences alike
        x = np.asarray(feature_data)

        # more points allow for a finer histogram
        n_points = len(x)
        if n_points >= 500:
            bin_edges = 50
        elif n_points >= 200:
            bin_edges = 20
        elif n_points >= 100:
            bin_edges = 10
        else:
            bin_edges = 5

        ax2 = ax.twinx()
        xlim = ax.get_xlim()
        ax2.hist(x.reshape(-1,), bin_edges,
                 range=(xlim[0], xlim[1]),
                 density=False, facecolor='#000000', alpha=0.1, zorder=-1)
        ax2.set_ylim(0, n_points)
        ax2.set_yticks([])

    if show:
        plt.show()

    return ax
Total running time of the script: ( 0 minutes 0.036 seconds)