Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDocCN
pycaret
提交
3396f1f8
pycaret
项目概览
OpenDocCN
/
pycaret
通知
2
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
pycaret
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
3396f1f8
编写于
6月 05, 2020
作者:
P
PyCaret
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
updated classification.py
上级
3ac8234e
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
56 addition
and
11 deletion
+56
-11
pycaret/classification.py
pycaret/classification.py
+56
-11
未找到文件。
pycaret/classification.py
浏览文件 @
3396f1f8
...
...
@@ -54,6 +54,7 @@ def setup(data,
session_id
=
None
,
experiment_name
=
None
,
#added in pycaret==1.0.1
logging
=
True
,
#added in pycaret==1.0.1
log_plots
=
False
,
#added in pycaret==1.0.1
log_profile
=
False
,
#added in pycaret==1.0.1
silent
=
False
,
verbose
=
True
,
#added in pycaret==1.0.1
...
...
@@ -765,7 +766,7 @@ def setup(data,
#declaring global variables to be accessed by other functions
global
X
,
y
,
X_train
,
X_test
,
y_train
,
y_test
,
seed
,
prep_pipe
,
experiment__
,
\
folds_shuffle_param
,
n_jobs_param
,
create_model_container
,
master_model_container
,
display_container
,
exp_name_log
,
logging_param
folds_shuffle_param
,
n_jobs_param
,
create_model_container
,
master_model_container
,
display_container
,
exp_name_log
,
logging_param
,
log_plots_param
#generate seed to be used globally
if
session_id
is
None
:
...
...
@@ -1170,6 +1171,12 @@ def setup(data,
#create exp_name_log param incase logging is False
exp_name_log
=
'no_logging'
#create an empty log_plots_param
if
log_plots
:
log_plots_param
=
True
else
:
log_plots_param
=
False
#sample estimator
if
sample_estimator
is
None
:
model
=
LogisticRegression
()
...
...
@@ -1667,7 +1674,7 @@ def setup(data,
mlflow
.
log_artifact
(
"Test.csv"
)
return
X
,
y
,
X_train
,
X_test
,
y_train
,
y_test
,
seed
,
prep_pipe
,
experiment__
,
\
folds_shuffle_param
,
n_jobs_param
,
html_param
,
create_model_container
,
master_model_container
,
display_container
,
exp_name_log
,
logging_param
folds_shuffle_param
,
n_jobs_param
,
html_param
,
create_model_container
,
master_model_container
,
display_container
,
exp_name_log
,
logging_param
,
log_plots_param
def
create_model
(
estimator
=
None
,
...
...
@@ -2312,10 +2319,20 @@ def create_model(estimator = None,
holdout_score
.
to_html
(
'Holdout.html'
,
col_space
=
65
,
justify
=
'left'
)
mlflow
.
log_artifact
(
'Holdout.html'
)
# Log AUC and Confusion Matrix plot
if
log_plots_param
:
plot_model
(
model
,
plot
=
'auc'
,
verbose
=
False
,
save
=
True
,
system
=
False
)
mlflow
.
log_artifact
(
'AUC.png'
)
plot_model
(
model
,
plot
=
'confusion_matrix'
,
verbose
=
False
,
save
=
True
,
system
=
False
)
mlflow
.
log_artifact
(
'Confusion Matrix.png'
)
# Log model and transformation pipeline
save_model
(
model
,
'Trained Model'
,
verbose
=
False
)
mlflow
.
log_artifact
(
'Trained Model'
+
'.pkl'
)
clear_output
()
progress
.
value
+=
1
#storing into experiment
...
...
@@ -2973,7 +2990,10 @@ def ensemble_model(estimator,
return
model
def
plot_model
(
estimator
,
plot
=
'auc'
):
plot
=
'auc'
,
save
=
False
,
#added in pycaret 1.0.1
verbose
=
True
,
#added in pycaret 1.0.1
system
=
True
):
#added in pycaret 1.0.1
"""
...
...
@@ -3106,7 +3126,9 @@ def plot_model(estimator,
#progress bar
progress
=
ipw
.
IntProgress
(
value
=
0
,
min
=
0
,
max
=
5
,
step
=
1
,
description
=
'Processing: '
)
display
(
progress
)
if
verbose
:
if
html_param
:
display
(
progress
)
#ignore warnings
import
warnings
...
...
@@ -3135,8 +3157,13 @@ def plot_model(estimator,
visualizer
.
score
(
X_test
,
y_test
)
progress
.
value
+=
1
clear_output
()
visualizer
.
show
(
outpath
=
"image2.png"
)
visualizer
.
poof
()
if
save
or
log_plots_param
:
if
system
:
visualizer
.
show
(
outpath
=
"AUC.png"
)
else
:
visualizer
.
show
(
outpath
=
"AUC.png"
,
clear_figure
=
True
)
else
:
visualizer
.
show
()
elif
plot
==
'threshold'
:
...
...
@@ -3148,8 +3175,14 @@ def plot_model(estimator,
visualizer
.
score
(
X_test
,
y_test
)
progress
.
value
+=
1
clear_output
()
visualizer
.
poof
()
if
save
or
log_plots_param
:
if
system
:
visualizer
.
show
(
outpath
=
"Threshold Curve.png"
)
else
:
visualizer
.
show
(
outpath
=
"Threshold Curve.png"
,
clear_figure
=
True
)
else
:
visualizer
.
show
()
elif
plot
==
'pr'
:
from
yellowbrick.classifier
import
PrecisionRecallCurve
...
...
@@ -3160,7 +3193,13 @@ def plot_model(estimator,
visualizer
.
score
(
X_test
,
y_test
)
progress
.
value
+=
1
clear_output
()
visualizer
.
poof
()
if
save
or
log_plots_param
:
if
system
:
visualizer
.
show
(
outpath
=
"Precision Recall.png"
)
else
:
visualizer
.
show
(
outpath
=
"Precision Recall.png"
,
clear_figure
=
True
)
else
:
visualizer
.
show
()
elif
plot
==
'confusion_matrix'
:
...
...
@@ -3172,8 +3211,14 @@ def plot_model(estimator,
visualizer
.
score
(
X_test
,
y_test
)
progress
.
value
+=
1
clear_output
()
visualizer
.
poof
()
if
save
or
log_plots_param
:
if
system
:
visualizer
.
show
(
outpath
=
"Confusion Matrix.png"
)
else
:
visualizer
.
show
(
outpath
=
"Confusion Matrix.png"
,
clear_figure
=
True
)
else
:
visualizer
.
show
()
elif
plot
==
'error'
:
from
yellowbrick.classifier
import
ClassPredictionError
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录