提交 3396f1f8 编写于 作者: P PyCaret

updated classification.py

上级 3ac8234e
......@@ -54,6 +54,7 @@ def setup(data,
session_id = None,
experiment_name = None, #added in pycaret==1.0.1
logging = True, #added in pycaret==1.0.1
log_plots = False, #added in pycaret==1.0.1
log_profile = False, #added in pycaret==1.0.1
silent=False,
verbose=True, #added in pycaret==1.0.1
......@@ -765,7 +766,7 @@ def setup(data,
#declaring global variables to be accessed by other functions
global X, y, X_train, X_test, y_train, y_test, seed, prep_pipe, experiment__,\
folds_shuffle_param, n_jobs_param, create_model_container, master_model_container, display_container, exp_name_log, logging_param
folds_shuffle_param, n_jobs_param, create_model_container, master_model_container, display_container, exp_name_log, logging_param, log_plots_param
#generate seed to be used globally
if session_id is None:
......@@ -1170,6 +1171,12 @@ def setup(data,
#create exp_name_log param incase logging is False
exp_name_log = 'no_logging'
#create an empty log_plots_param
if log_plots:
log_plots_param = True
else:
log_plots_param = False
#sample estimator
if sample_estimator is None:
model = LogisticRegression()
......@@ -1667,7 +1674,7 @@ def setup(data,
mlflow.log_artifact("Test.csv")
return X, y, X_train, X_test, y_train, y_test, seed, prep_pipe, experiment__,\
folds_shuffle_param, n_jobs_param, html_param, create_model_container, master_model_container, display_container, exp_name_log, logging_param
folds_shuffle_param, n_jobs_param, html_param, create_model_container, master_model_container, display_container, exp_name_log, logging_param, log_plots_param
def create_model(estimator = None,
......@@ -2312,10 +2319,20 @@ def create_model(estimator = None,
holdout_score.to_html('Holdout.html', col_space=65, justify='left')
mlflow.log_artifact('Holdout.html')
# Log AUC and Confusion Matrix plot
if log_plots_param:
plot_model(model, plot = 'auc', verbose=False, save=True, system=False)
mlflow.log_artifact('AUC.png')
plot_model(model, plot = 'confusion_matrix', verbose=False, save=True, system=False)
mlflow.log_artifact('Confusion Matrix.png')
# Log model and transformation pipeline
save_model(model, 'Trained Model', verbose=False)
mlflow.log_artifact('Trained Model' + '.pkl')
clear_output()
progress.value += 1
#storing into experiment
......@@ -2973,7 +2990,10 @@ def ensemble_model(estimator,
return model
def plot_model(estimator,
plot = 'auc'):
plot = 'auc',
save = False, #added in pycaret 1.0.1
verbose = True, #added in pycaret 1.0.1
system = True): #added in pycaret 1.0.1
"""
......@@ -3106,7 +3126,9 @@ def plot_model(estimator,
#progress bar
progress = ipw.IntProgress(value=0, min=0, max=5, step=1 , description='Processing: ')
display(progress)
if verbose:
if html_param:
display(progress)
#ignore warnings
import warnings
......@@ -3135,8 +3157,13 @@ def plot_model(estimator,
visualizer.score(X_test, y_test)
progress.value += 1
clear_output()
visualizer.show(outpath="image2.png")
visualizer.poof()
if save or log_plots_param:
if system:
visualizer.show(outpath="AUC.png")
else:
visualizer.show(outpath="AUC.png", clear_figure=True)
else:
visualizer.show()
elif plot == 'threshold':
......@@ -3148,8 +3175,14 @@ def plot_model(estimator,
visualizer.score(X_test, y_test)
progress.value += 1
clear_output()
visualizer.poof()
if save or log_plots_param:
if system:
visualizer.show(outpath="Threshold Curve.png")
else:
visualizer.show(outpath="Threshold Curve.png", clear_figure=True)
else:
visualizer.show()
elif plot == 'pr':
from yellowbrick.classifier import PrecisionRecallCurve
......@@ -3160,7 +3193,13 @@ def plot_model(estimator,
visualizer.score(X_test, y_test)
progress.value += 1
clear_output()
visualizer.poof()
if save or log_plots_param:
if system:
visualizer.show(outpath="Precision Recall.png")
else:
visualizer.show(outpath="Precision Recall.png", clear_figure=True)
else:
visualizer.show()
elif plot == 'confusion_matrix':
......@@ -3172,8 +3211,14 @@ def plot_model(estimator,
visualizer.score(X_test, y_test)
progress.value += 1
clear_output()
visualizer.poof()
if save or log_plots_param:
if system:
visualizer.show(outpath="Confusion Matrix.png")
else:
visualizer.show(outpath="Confusion Matrix.png", clear_figure=True)
else:
visualizer.show()
elif plot == 'error':
from yellowbrick.classifier import ClassPredictionError
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册