0


【极简spark教程】spark聚合函数

聚合函数分为两类,一种是spark内置的常用聚合函数,一种是用户自定义聚合函数

UDAF

UDAF的定义

  1. 继承UserDefinedAggregateFunction
  2. 定义输入数据的schema
  3. 定义缓存的数据结构
  4. 聚合函数返回值的数据类型
  5. 定义聚合函数的幂等性,一般为true
  6. 初始化缓存
  7. 更新缓存
  8. 合并缓存
  9. 计算结果
importorg.apache.spark.{SparkConf, SparkContext}importorg.apache.spark.sql.{Row, SparkSession}importorg.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}importorg.apache.spark.sql.types._
importorg.apache.spark.sql.functions._

object avg extends UserDefinedAggregateFunction {// 定义输入数据的schema,需要指定列名,但在实际使用中这里指定的列名没有意义overridedef inputSchema: StructType = StructType(List(StructField("input", LongType)))// 缓存的数据结构,bufferSchema定义了缓存的数据结构具有sum和count两个字段overridedef bufferSchema: StructType = StructType(List(StructField("sum", LongType), StructField("count", LongType)))// 聚合函数返回值的数据类型:返回值的类型必需与下面的evaluate返回类型一致overridedef dataType: DataType = LongType
  // 聚合函数的幂等性,相同输入总是能得到相同输出overridedef deterministic:Boolean=true// 初始化缓存:根据bufferSchema,缓存具有sum和count两个字段,这里会对sum和count两个变量的值进行初始化// tips:缓存buffer是MutableAggregationBuffer类型,你可以简单理解buffer就是一个数组// tips:在这里buffer是具有代表了sum和count数值的二元数组overridedef initialize(buffer: MutableAggregationBuffer):Unit={
    buffer(0)=0L
    buffer(1)=0L}// 更新缓存:接受并处理输入数据,更新buffer// tips:在实际处理中,输入数据是DataFrame,DataFrame是由多个Row组成的,每个Row会逐个传递给update,更新buffer中的值// tips:必须对输入的input进行检查,防止input.getLong(i)出现越界报错ArrayIndexOutOfBoundsExceptionoverridedef update(buffer: MutableAggregationBuffer, input: Row):Unit={if(input.isNullAt(0))return
    buffer(0)= buffer.getLong(0)+ input.getLong(0)
    buffer(1)= buffer.getLong(1)+1}// 合并缓存:对多个buffer进行合并,这里的合并方式类似于reduce,新来的buffer都会和左侧合并后的大buffer进行合并,合并后保留大buffer的值,buffer2会被丢弃overridedef merge(buffer1: MutableAggregationBuffer, buffer2: Row):Unit={
    buffer1(0)= buffer1.getLong(0)+ buffer2.getLong(0)
    buffer1(1)= buffer1.getLong(1)+ buffer2.getLong(1)}// 计算结果:根据所有buffer合并后的值,计算最终的结果// tips:这里所有buffer合并后对值为整体的sum和count,计算整体的sum和count比值,我们得到最终的平均值overridedef evaluate(buffer: Row):Any={
    buffer.getLong(0)/ buffer.getLong(1)}}

UDAF的使用

  1. 在sparkSQL中使用UDAF
  2. 在DataFrame中使用UDAF
def main(args: Array[String]):Unit={val spark = SparkSession.builder().master("local").getOrCreate()// 注册UDAF函数,和UDF函数一样
  spark.udf.register("my_avg", avg)// test.txt文件内容// score|user// 90|Tom// 95|Jerry// 100|Claris// sparkSQL读取文件,创建视图// sparkSQL的第一步:读取文件并创建视图
  spark.read.option("header","true").option("sep","|").csv("test.txt").createOrReplaceTempView("v_user")// sparkSQL的第二步:在spark.sql中调用UDAF,求分数的均值
  spark.sql("select u_avg(score) as avg_score from v_user").show()// DataFrame的第一步:读取文件,创建DataFrameval df1 = spark.read.option("header","true").option("sep","|").csv("data/other/test.txt")// DataFrame的第二步:在df.agg中,使用callUDF调用UDAF函数,求分数的均值val df2 = df1.agg(callUDF("my_avg",col("score")))
  df2.show(false)}
标签: spark 大数据

本文转载自: https://blog.csdn.net/ljp7759325/article/details/124443612
版权归原作者 鱼摆摆 所有, 如有侵权,请联系我们删除。

“【极简spark教程】spark聚合函数”的评论:

还没有评论