数据集的划分——交叉验证法
本文作者:王 歌
文字编辑:戴 雯
技术总编:张 邯
Python云端培训课程火热招生中~
重大通知!!!爬虫俱乐部于2020年7月11日至14日在线上举行为期四天的Stata编程技术定制培训,招生工作已经圆满结束啦!!!
另外,应广大学员需求,爬虫俱乐部将于2020年7月25日至28日在线上举行Python编程技术训练营,本次培训采用理论与案例相结合的方式,帮助大家在掌握Python基本思想的基础上,学习科学计算技术与网络数据抓取技术,详情可点击《Python云端培训课程开始报名~》,或点击文末阅读原文直接提交报名信息呦~
导读
1方法介绍
2程序实现
我们这里依然使用的是鸢尾花的数据,同时使用Logistic回归训练模型。在sklearn中,通常使用 cross_val_predict
实现k折交叉验证,它返回的是一个使用交叉验证以后的输出值,若要返回准确度评分,可以使用 cross_val_score
。两者参数相同,第一个参数为所使用的分类器,第二个和第三个参数分别是属性值和标签值,最后一个参数 cv
确定折数。我们这里进行5折的交叉验证,程序如下:
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, cross_val_predict
iris_sample = load_iris()
x = iris_sample.data
y = iris_sample.target
lrclf = LogisticRegression()
predicted = cross_val_predict(lrclf, x, y, cv=5) #计算预测值
print('5折交叉验证预测值:', predicted)
scores = cross_val_score(lrclf, x, y, cv=5) #计算模型的评分情况
print('评分:', scores)
print('准确度:', metrics.accuracy_score(predicted, y)) #计算评分的均值
若使用留一法,则要使用 LeaveOneOut
类,没有参数需要设置。具体程序如下:
from sklearn.model_selection import LeaveOneOut
loo = LeaveOneOut()
scores = cross_val_score(lrclf, x, y, cv=loo)
predicted = cross_val_predict(lrclf, x, y, cv=loo)
print('留一法预测值:', predicted)
scores = cross_val_score(lrclf, x, y, cv=loo)
print('评分:', scores)
print('准确度:', metrics.accuracy_score(predicted, y)
假设进行5次5折交叉验证,我们使用 RepeatedKFold
类,有三个参数:
(1) n_split
表示要划分的折数;
(2) n_repeats
表示重复几次;
(3) random_state
设置随机种子。
from sklearn.model_selection import RepeatedKFold
kf = RepeatedKFold(n_splits=5, n_repeats=5, random_state=0) #种子设为0
predicted = cross_val_predict(lrclf, x, y, cv=5)
print('5次5折交叉验证预测值:', predicted)
scores = cross_val_score(lrclf, x, y, cv=kf)
print('评分:', scores)
print('准确度:', metrics.accuracy_score(predicted, y))
在sklearn中还提供了许多其它交叉验证的类,比如使用 ShuffleSplit
类可以随机的把数据打乱,然后分为训练集和测试集;对于时间序列的数据,可以使用 TimeSeriesSplit
;若要实现分层抽样式的交叉验证,可以使用 StratifiedKFold
;分层随机划分可以使用 StratifiedShuffleSplit
,等等,大家可以根据自己的需要来选择合适的交叉验证方式。
PDF文本信息提取(二)
关于我们
微信公众号“Stata and Python数据分析”分享实用的stata、python等软件的数据处理知识,欢迎转载、打赏。我们是由李春涛教授领导下的研究生及本科生组成的大数据处理和分析团队。
1)必须原创,禁止抄袭;
2)必须准确,详细,有例子,有截图;
注意事项:
1)所有投稿都会经过本公众号运营团队成员的审核,审核通过才可录用,一经录用,会在该推文里为作者署名,并有赏金分成。
2)邮件请注明投稿,邮件名称为“投稿+推文名称”。
3)应广大读者要求,现开通有偿问答服务,如果大家遇到有关数据处理、分析等问题,可以在公众号中提出,只需支付少量赏金,我们会在后期的推文里给予解答。