Spark SQL 快速入门系列(6) | 一文教你如何自定义 SparkSQL 函数

举报
不温卜火 发表于 2020/12/03 01:14:47 2020/12/03
【摘要】   大家好,我是不温卜火,是一名计算机学院大数据专业大二的学生,昵称来源于成语—不温不火,本意是希望自己性情温和。作为一名互联网行业的小白,博主写博客一方面是为了记录自己的学习过程,另一方面是总结自己所犯的错误希望能够帮助到很多和自己一样处于起步阶段的萌新。但由于水平有限,博客中难免会有一些错误出现,有纰漏之处恳请各位大佬不吝赐教!暂时只有csdn这一个平台,博客...

  大家好,我是不温卜火,是一名计算机学院大数据专业大二的学生,昵称来源于成语—不温不火,本意是希望自己性情温和。作为一名互联网行业的小白,博主写博客一方面是为了记录自己的学习过程,另一方面是总结自己所犯的错误希望能够帮助到很多和自己一样处于起步阶段的萌新。但由于水平有限,博客中难免会有一些错误出现,有纰漏之处恳请各位大佬不吝赐教!暂时只有csdn这一个平台,博客主页:https://buwenbuhuo.blog.csdn.net/

  本片博文为大家带来的是一文教你如何自定义 SparkSQL 函数。
1


2

一. 自定义 UDF 函数

  在Shell窗口中可以通过spark.udf功能用户可以自定义函数。

scala> val df = spark.read.json("examples/src/main/resources/people.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]

scala> df.show
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+
// 注册一个 udf 函数: toUpper是函数名, 第二个参数是函数的具体实现
scala> spark.udf.register("toUpper", (s: String) => s.toUpperCase)
res1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))

scala> df.createOrReplaceTempView("people")

scala> spark.sql("select toUpper(name), age from people").show
+-----------------+----+
|UDF:toUpper(name)| age|
+-----------------+----+
| MICHAEL|null|
| ANDY|  30|
| JUSTIN|  19|
+-----------------+----+


  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

二. 用户自定义聚合函数

  强类型的Dataset弱类型的DataFrame都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数

2.1 弱类型UDF(求和)

  • 1.源码
package com.buwenbuhuo.spark.sql.day01.udf

import com.buwenbuhuo.spark.sql.day01.Pelple
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

import scala.collection.immutable.Nil

/**
 **
 *
 * @author 不温卜火
 * *
 * @create 2020-08-03 12:13
 **
 * MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
object UDAFDemo {
  def main(args: Array[String]): Unit = { // 在sql中,聚合函数如何使用 val spark: SparkSession = SparkSession.builder() .appName("UDAFDemo") .master("local[2]") .getOrCreate() import spark.implicits._ val df: DataFrame = spark.read.json("d:/users.json") df.createOrReplaceTempView("user") // 注册聚合函数 spark.udf.register("mySum",new MySum) spark.sql("select mySum(age) from user").show spark.close()
  }
}

class MySum extends UserDefinedAggregateFunction { // 用来定义输入的数据类型  10.1 12.2 100
  override def inputSchema: StructType = StructType(StructField("ele",DoubleType)::Nil) // 缓冲区的类型
  override def bufferSchema: StructType = StructType(StructField("sum",DoubleType)::Nil) // 最终聚合结果的类型
  override def dataType: DataType = DoubleType // 相同的输入是否返回相同的输出
  override def deterministic: Boolean = true // 对缓冲区初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = { // 在缓冲区集合中初始化和 buffer(0) = 0D  // 等价于buffer.update(0,0D) } // 分区内聚合
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { // input是指的使用聚合函数的时候,缓过来的参数封装到了Row if(!input.isNullAt(0)){ // 考虑到传字段可能是null val v: Double = input.getAs[Double](0)  // getDouble(0) buffer(0) = buffer.getDouble(0) + v }
  } // 分区间的聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { // 把buffer1和buffer2 的缓冲聚合到一起,然后再把值写回到buffer1 buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  } // 返回最初的输出值
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}


  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 2. 运行结果

3

2.2 弱类型UDF(求均值)

  • 1. 源码
package com.buwenbuhuo.spark.sql.day01.udf

import com.buwenbuhuo.spark.sql.day01.Pelple
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

import scala.collection.immutable.Nil

/**
 **
 *
 * @author 不温卜火
 * *
 * @create 2020-08-03 12:13
 **
 * MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
object UDAFDemo1 {
  def main(args: Array[String]): Unit = { // 在sql中,聚合函数如何使用 val spark: SparkSession = SparkSession.builder() .appName("UDAFDemo1") .master("local[2]") .getOrCreate() import spark.implicits._ val df: DataFrame = spark.read.json("d:/users.json") df.createOrReplaceTempView("user") // 注册聚合函数 spark.udf.register("myAvg",new MyAvg) spark.sql("select myAvg(age) from user").show spark.close()
  }
}

class MyAvg extends UserDefinedAggregateFunction { // 用来定义输入的数据类型  10.1 12.2 100
  override def inputSchema: StructType = StructType(StructField("ele",DoubleType)::Nil) // 缓冲区的类型
  override def bufferSchema: StructType = StructType(StructField("sum",DoubleType)::StructField("count",LongType)::Nil) // 最终聚合结果的类型
  override def dataType: DataType = DoubleType // 相同的输入是否返回相同的输出
  override def deterministic: Boolean = true // 对缓冲区初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = { // 在缓冲区集合中初始化和 buffer(0) = 0D  // 等价于buffer.update(0,0D) buffer(1) = 0L } // 分区内聚合
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { // input是指的使用聚合函数的时候,缓过来的参数封装到了Row if(!input.isNullAt(0)){ // 考虑到传字段可能是null val v: Double = input.getAs[Double](0)  // getDouble(0) buffer(0) = buffer.getDouble(0) + v buffer(1) = buffer.getLong(1) + 1L }
  } // 分区间的聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { // 把buffer1和buffer2 的缓冲聚合到一起,然后再把值写回到buffer1 buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0) buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  } // 返回最初的输出值
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)/buffer.getLong(1)
}


  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 2. 运行结果

4

2.3 强类型UDF(求均值)

  • 1. 源码
package com.buwenbuhuo.spark.sql.day01.udf


import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator


/**
 **
 *
 * @author 不温卜火
 * *
 * @create 2020-08-03 12:43
 **
 * MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
case class Dog(name:String,age:Int)

case class AgeAvg(sum:Int,count:Int){
  def avg = sum.toDouble / count
}

object UDAFDemo3 {
  def main(args: Array[String]): Unit = { // 在sql中,聚合函数如何使用 val spark: SparkSession = SparkSession.builder() .appName("UDAFDemo3") .master("local[2]") .getOrCreate() import spark.implicits._ val ds: Dataset[Dog] = List(Dog("大黄", 6), Dog("小黄", 2), Dog("中黄", 4)).toDS() // 强类型的使用方式 val avg: TypedColumn[Dog, Double] = new MyAvg2().toColumn.name("avg") val result: Dataset[Double] = ds.select(avg) result.show() spark.close() }
}
class MyAvg2 extends Aggregator[Dog,AgeAvg,Double]{ // 对缓冲区进行初始化
  override def zero: AgeAvg = AgeAvg(0,0) // 聚合(分区内聚合)
  override def reduce(b: AgeAvg, a: Dog): AgeAvg = a match { // 如果是dog对象,则把年龄相加,个数加1 case Dog(name,age) => AgeAvg(b.sum + age , b.count + 1) // 如果是null,则原封不动返回 case _ => b
  } // 分区间的聚合
  override def merge(b1: AgeAvg, b2: AgeAvg): AgeAvg = { AgeAvg(b1.sum + b2.sum,b1.count + b2.count)
  } // 返回最终的值
  override def finish(reduction: AgeAvg): Double = reduction.avg // 对缓冲区进行编码
  override def bufferEncoder: Encoder[AgeAvg] = Encoders.product // 如果是样例,就直接返回这个编码器就行了 //对返回值进行编码
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
/*
强类型UDF
 */

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 2. 运行结果
    5
      本次的分享就到这里了,

14

  好书不厌读百回,熟读课思子自知。而我想要成为全场最靓的仔,就必须坚持通过学习来获取更多知识,用知识改变命运,用博客见证成长,用行动证明我在努力。
  如果我的博客对你有帮助、如果你喜欢我的博客内容,请“点赞” “评论”“收藏”一键三连哦!听说点赞的人运气不会太差,每一天都会元气满满呦!如果实在要白嫖的话,那祝你开心每一天,欢迎常来我博客看看。
  码字不易,大家的支持就是我坚持下去的动力。点赞后不要忘了关注我哦!

15

16

文章来源: buwenbuhuo.blog.csdn.net,作者:不温卜火,版权归原作者所有,如需转载,请联系作者。

原文链接:buwenbuhuo.blog.csdn.net/article/details/107766831

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。