机器学习的练功方式(五)——模型选择及调优

举报
ArimaMisaki 发表于 2022/08/08 22:54:30 2022/08/08
【摘要】 文章目录 5 模型选择及调优5.1 数据增强5.2 过拟合5.3 交叉验证5.4 超参数搜索——网格搜索 5 模型选择及调优 5.1 数据增强 有时候,你和你的老板说你数据不够,它是...

5 模型选择及调优

5.1 数据增强

有时候,你和你的老板说你数据不够,它是不会理你的。老板会发问:为什么你是做机器学习的要那么多数据干嘛,让机器去做不就行了。

对于这种问题有时候即使无语但你也不能正面拆穿,否则你的工作就不用干了。而为了解决数据集不足的问题,我们通常会采用数据增强

这个名词看似高大上,实际上就是把数据集经过某些变换,从而产生新的数据集。

这种方法多用于图片识别上,将图片结果左右对称变换,或反转,或偏移角度来达到拥有更多数据集的目的。

5.2 过拟合

统计学习的目的是使学到的模型不仅对已知数据而且对未知数据都能有很好的预测能力。不同的学习方法会给出不同的模型。当损失函数给定时,基于损失函数的模型的训练误差和模型的测试误差就自然成为学习方法评估的标准。注意,统计学习方法具体采用的损失函数未必是评估时使用的损失函数。当然,让两者一致是比较理想的。

训练误差的大小,对判断给定的问题是不是一个容易学习的问题是有意义的,但本质上不重要。测试误差反映了学习方法对未知的测试数据集的预测能力,是学习中的重要概念。显然,给定两种学习方法,测试误差小的方法具有更好的预测能力,是更有效的方法。

通常将学习方法对未知数据的预测能力称为泛化能力

当假设空间含有不同复杂度的模型时,就要面临模型选择的问题。我们希望选择或学习一个合适的模型。如果在假设空间中存在真模型,那么所选择的模型应该逼近真模型。具体地,所选择的模型要与真模型的参数个数相同,所选择的模型的参数向量与真模型的参数向量相近。

如果一味追求提高对训练数据的预测能力,所选模型的复杂度则往往会比真模型更高。这种现象称为过拟合。过拟合是指学习时选择的模型所包含的参数过多,以至于出现这一模型对已知数据预测的很好,但对未知数据预测得很差的现象。可以说模型选择旨在避免过拟合并提高模型的预测能力。

也就是说,上述的话翻译成人话就是,我们不要那种能够完全贴合训练集的函数,那种函数训练出来训练集在上面跑挺牛逼,一到测试集就不行了。我们需要的是那种在训练集跑的差不多,对于测试集跑出来效果也很好的那种函数。

image-20220306172629610

现在我们有以上的数据集,我们要选择一个模型去拟合真模型,也就是M=0时图中画的曲线,那条曲线即为真模型。当然了,根据我们上面所说,我们要的是做到“差不多”即可,我们不要精度完全一样或者超过真模型。

当我们M=1,选择的是一条直线,这种模型其实是罔顾事实的做法,我们完全不考虑拟合的效果,一上来就乱套模型,这样会导致拟合数据的效果贼差。这种在古老的文献中称为“欠拟合”现象。

当M=3时,我们选择的模型已经接近数据所对应的真模型了,已经几乎拟合了,这时候的模型符合测试误差最小的学习目的了。

当M=9时,这时候就是所谓的过拟合现象了,由于参数设置过多,导致这条曲线几乎穿过了我们已知的所有的数据点。的确,他对已知数据预测很好(穿过了嘛),但是他对未知数据却预测很差(说不定下一个点不在这条线上,这就导致前面预测很准,后面误差越来越大)。

简单来说,想解决过拟合,实际上无非就是选择复杂度适当的模型,以达到使测试误差最小的学习目的。我们常用的模型选择方法:正则化交叉验证。关于正则化的学习我们在后面的学习中会接触到,我们这里要提到的是关于交叉验证

5.3 交叉验证

交叉验证(cross validation)简单来说就是将拿到的训练数据,再次分为训练集和验证集。其中验证集和测试集的功能一样,都是对训练集训练出来的模型进行评估。而交叉验证方法就是将训练集划分为训练集+测试集,测试集通常占1份,而训练集占k-1份,通过四次测试,每次更换不同的验证集来达到在有限的数据集中得出不同精度,得出4组模型结果;得出结果后取平均值作为最终结果。我们把上述的做法称为K折交叉验证。

K折交叉验证示意图如下:

image-20220306173643643

虽然K折交叉验证能够在k均值算法中起到优化K值的效果,那么如何来选取K值呢?什么时候才是最好呢?这就要进入我们的下一小节了。

5.4 超参数搜索——网格搜索

通常情况下,有很多参数时需要手动指定的(如K-近邻算法中的K值),这种叫做超参数。但是手动指定不准且计算复杂,所以我们要对模型预设几种超参数组合,每组超参数都采用交叉验证来进行评估,最后选出最优参数组合建立模型。

知道原理了就是动手写代码的时刻,我们看一下sklearn中有哪些库供我们调用。

sklearn.model_selection.GridSearchCV(estimator,param_grid = None,cv = None)

该API可以对估计器的指定参数值进行详尽搜索

  • estimator:估计器对象
  • param——grid:估计器参数,对应到knn中可以传入多个k值建立多个模型,以此评估哪个模型最好,传入时要用字典形式,如{“n_neighbors”:[1,3,5]}
  • cv:指定几折交叉验证,常用10折交叉验证

fit():输入训练数据

score():准确率

通过调用以下属性可以查看结果:

  • 最佳参数:best_params_
  • 最佳结果:best_score_
  • 最佳估计器:best_estimator_
  • 交叉验证结果:cv_results_

知道上面的原理,让我们对前一讲的KNN分类鸢尾花代码优化一下吧!

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV


def knn_iris():
    """用KNN算法对鸢尾花进行分类"""
    # 1 导入数据集
    iris = load_iris()

    # 2 划分数据集
    x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=6)

    # 3 特征工程:标准化
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)

    # 4 实例化KNN算法预估器
    estimator = KNeighborsClassifier()

    # 选用合适的K值来选择多个模型
    param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11]}

    # 加入超参数网格搜索和交叉验证
    estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)
    estimator.fit(x_train, y_train)

    # 5 模型评估
    # 方法1 直接比对真实值和预测值
    y_predict = estimator.predict(x_test)
    print("y_predict:\n", y_predict)
    print("直接对比真实值和预测值:\n", y_test == y_predict)

    # 方法2 计算准确率
    score = estimator.score(x_test, y_test)
    print("准确率为:\n", score)

    # 查看最佳参数
    print("KNN模型最佳参数:\n", estimator.best_params_)
    print("最佳结果:\n", estimator.best_score_)
    print("最佳估计器:\n", estimator.best_estimator_)
    print("交叉验证结果:\n", estimator.cv_results_)


# 调用方法
knn_iris()

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49

文章来源: blog.csdn.net,作者:ArimaMisaki,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/chengyuhaomei520/article/details/123327676

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。