SPARK ML 出租车数据分析

举报
yk02901 发表于 2021/05/15 12:05:55 2021/05/15
【摘要】 通过分析出租车数据,然后使用KMeans对经纬度进行聚类,然后按照(类别,时间)进行分类,再统计每个类别每个时段的次数。数据格式以及意义:111,30.655325,104.072573,173749111,30.655346,104.072363,173828111,30.655377,104.120252,124057111,30.655439,104.088812,142016列一:出...

通过分析出租车数据,然后使用KMeans对经纬度进行聚类,然后按照(类别,时间)进行分类,再统计每个类别每个时段的次数。

数据格式以及意义:

111,30.655325,104.072573,173749
111,30.655346,104.072363,173828
111,30.655377,104.120252,124057
111,30.655439,104.088812,142016

列一:出租车ID

列二:经度

列三:纬度

列四:时间(例如:142016表示14点20分16秒)

步骤:

1.整理数据,分割成训练数据和测试数据,且使其符合KMeans模型训练的格式

2.使用训练好的模型对测试数据进行预测,然后对结果以(类别,小时时间 )进行count统计,结果为每个类别每个小时的总次数。

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
object Tax1 {
  def main(arg:Array[String]):Unit={
    val spark = SparkSession.builder().appName("Taxi1").master("local[*]").getOrCreate()
//为读取的数据创建schema
    val taxiSchema = StructType(Array(
      StructField("id",IntegerType,true),
      StructField("tid",DoubleType,true),
      StructField("lat",DoubleType,true),
      StructField("time",StringType,true)
    ))
    val path = "/home/enche/data/taxi.csv"
    val data = spark.read.schema(taxiSchema).csv(path)
    //将tid和lat转换成训练使用的Vector类型
    val assembler = new VectorAssembler()
    val tid_lat = Array("tid","lat")
    assembler.setInputCols(tid_lat).setOutputCol("feature").transform(data)
    //按照8:2的比例随即分割数据,分别用于训练和测试
    val splitRait = Array(0.8, 0.2)
    val Array(train, test) = data.randomSplit(splitRait)
    //建立Kmeans,设置类别数为10 
    val kmeans = new KMeans().setK(10).setFeaturesCol("feature").setPredictionCol("prediction")
   //模型训练
    val model = kmeans.fit(train)
    //使用模型预测 测试数据
    val testResult = model.transform(test)
    //导入隐式转换,不然$"time"会出现错误 $ not e member of StringContext
    import spark.implicits._
    val time_prediction = testResult.select(substring($"time", 0, 2).alias("hour"), $"prediction")
    time_prediction.groupBy("hour","prediction").agg(("prediction","count")).orderBy(desc("count")).filter(row=>row.getAs(0)==15).take(10)
  }
}
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。