《scikit-learn机器学习常用算法原理及编程实战》—3.3.2 交叉验证数据集

举报
华章计算机 发表于 2019/05/31 17:04:36 2019/05/31
【摘要】 本书摘自《scikit-learn机器学习常用算法原理及编程实战》一书中的第3章,第3.3.2节,编著是黄永昌 .

3.3.2  交叉验证数据集

  另外一个更科学的方法是把数据集分成3份,分别是训练数据集、交叉验证数据集和测试数据集,推荐比例是6∶2∶2。

  为什么需要交叉验证数据集呢?以多项式模型选择为例。假设我们用一阶多项式、二阶多项式、三阶多项式……十阶多项式来拟合数据,多项式的阶数记为d。我们把数据集分成训练数据集和测试数据集。先用训练数据集训练出机器学习算法的参数θ(1),θ(2),θ(3),…, θ(10),这些参数分别代表从一阶到十阶多项式的模型参数。这10个模型里,哪个模型更好呢?这个时候我们会用测试数据集算出针对测试数据集的成本Jtest(θ),看哪个模型的测试数据集成本最低,我们就选择这个多项式来拟合数据,但实际上,这是有问题的。测试数据集的最主要功能是测试模型的准确性,需要确保模型“没见过”这些数据。现在我们用测试数据集来选择多项式的阶数d,相当于把测试数据集提前让模型“见过”了。这样选择出来的多项式阶数d本身就是对训练数据集最友好的一个,这样模型的准确性测试就失去了意义。

  为了解决这个问题,我们把数据分成3部分,随机选择60%的数据作为训练数据集,其成本记为J(θ),随机选择20%的数据作为交叉验证数据集(Cross Validation),其成本记为Jcv(θ),剩下的20%作为测试数据集,其成本记为Jtest(θ)。

  在模型选择时,我们使用训练数据集来训练算法参数,用交叉验证数据集来验证参数。选择交叉验证数据集的成本Jcv(θ)最小的多项式来作为数据拟合模型,最后再用测试数据集来测试选择出来的模型针对测试数据集的准确性。

  因为在模型选择过程中,我们使用了交叉验证数据集,所以筛选模型多项式阶数d的过程中,实际上并没有使用测试数据集。这样保证了使用测试数据集来计算成本衡量模型的准确性,我们选择出来的模型是没有“见过”测试数据,即测试数据集没有参与模型选择的过程。

  当然,在实践过程中,很多人直接把数据集分成训练数据集和测试数据集,而没有分出交叉验证数据集。这是因为很多时候并不需要横向去对比不同的模型。在工程上,大多数时候我们最主要的工作不是选择模型,而是获取更多数据、分析数据、挖掘数据。


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。