《Spark机器学习进阶实战》——3.4.3 训练模型

举报
华章计算机 发表于 2019/06/01 22:35:22 2019/06/01
【摘要】 本书摘自《Spark机器学习进阶实战》——书中的第3章,第3.4.3节,作者是马海平、于俊、吕昕、向海。

3.4.3 训练模型

现在我们已经从数据集中提取了基本的特征并将数据转化成了libsvm文件格式,接下来进入模型训练阶段。为了比较不同模型的性能,将训练朴素贝叶斯和SVM,其他诸如逻辑回归、决策树等留给读者扩展实践。

鉴于MLlib中RDD-based API将逐渐由Pipeline-based API替代,因此本书中所有模型的训练,优先使用Pipeline-based模式。你会发现这两种模式下,每一个模型的训练过程几乎一样,不同的是不同的算法有自己特定的参数。

1. 使用朴素贝叶斯分类器

使用朴素贝叶斯分类器训练分类模型是比较容易的,首先需要读取input目录中的libsvm格式的数据,并根据数据训练模型,详细代码参考ch03/AppClassification.scala。

本地测试参数和值如表3-3所示。

image.png

下面根据具体代码详细介绍如何一步一步地通过训练数据得到最终的贝叶斯分类模型的结果。

val data = spark.read.format("libsvm").load(input)

val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1234L)

// 训练一个贝叶斯模型

val model = new NaiveBayes().fit(trainingData)

因数据量不大,即使本地运行也会非常快地得到结果。

INFO DAGScheduler: ResultStage 9 (parquet at NaiveBayes.scala:262) finished in 1.256 s

INFO DAGScheduler: Job 5 finished: parquet at NaiveBayes.scala:262, took 1.733254 s

2. 使用SVM分类器

SVM是一种典型的二分类器,即它只能回答是正类还是负类,而应用分类是一个多分类问题,因此我们要从二分类器得到多类分类器。这里介绍两种常用方法。

1)一类对其他:每次仍然解一个二分类的问题。比如我们有3个类别,第一次就把类别1的样本定为正样本,其余类别2、3的样本合起来定为负样本,这样会得到一个二分类器,它能够指出一篇文章是否为第1类的;第二次把类别2的样本定为正样本,把类别1、3的样本合起来定为负样本,得到一个分类器。如此下去,便可以得到3个这样的二分类器(总是和类别的数目一致)。

2)一对一分类:每次也是解一个二分类的问题。每次选一个类的样本作为正类样本,而负样本则变成只选一个类。同上面的例子,训练一批分类器来回答“是第1类还是第2类”“是第1类还是第3类”和“是第2类还是第3类”。此时分类器的个数为k(k-1)/2。

我们选择方法2进行SVM分类模型训练,详细代码参考ch03/AppClassificationSVM.scala,本地测试参数和值如表3-4所示。

image.png

训练代码如下:

/* 5 * 4 / 2 = 10 */

val data = MLUtils.loadLibSVMFile(sc, input).cache()

val labels = data.map(_.label).distinct().collect().sorted.combinations(2)

.map(x => (x.mkString("_"), x))

labels.foreach {

case (tag, tuple) =>

val parts = data.filter(lp => tuple.contains(lp.label)).map{

case lp =>

val label = if (lp.label == tuple(0)) 0 else 1

        new LabeledPoint(label, lp.features)

    }

val splits = parts.randomSplit(Array(0.7, 0.3), seed = 11L)

val training = splits(0).cache()

val test = splits(1)

val svmAlg = new SVMWithSGD()

    svmAlg.optimizer

.setNumIterations(100)

      .setRegParam(0.01)

val model = svmAlg.run(training)

}


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。