0


【Spark原理系列】自定义聚合函数 UserDefinedAggregateFunction 原理用法示例源码分析

Spark 自定义聚合函数(UDAF)UserDefinedAggregateFunction 原理用法示例源码分析

文章目录

原理

UserDefinedAggregateFunction

是 Spark SQL 中用于实现用户自定义聚合函数(UDAF)的抽象类。通过继承该类并实现其中的方法,可以创建自定义的聚合函数,并在 Spark SQL 中使用。

UserDefinedAggregateFunction

原理是基于 Spark SQL 的聚合操作流程。当一个 UDAF 被应用到 DataFrame 上时,Spark SQL 会将 UDAF 转化为一个

AggregateExpression

对象,其中包含了对应的

ScalaUDAF

实例和聚合操作类型。然后,Spark SQL 会对数据进行分组、聚合等操作,并调用 UDAF 中的方法来执行具体的聚合逻辑。

在具体实现中,

UserDefinedAggregateFunction

提供了一系列方法,如

inputSchema

bufferSchema

dataType

等,用于定义输入参数的数据类型、缓冲区中值的数据类型以及返回值的数据类型。同时,它还提供了

initialize

update

merge

evaluate

方法,用于初始化聚合缓冲区、更新缓冲区、合并缓冲区以及计算最终结果。此外,

UserDefinedAggregateFunction

还提供了

apply

distinct

方法,用于创建

Column

对象,方便在 DataFrame 中使用自定义聚合函数。

总的来说,

UserDefinedAggregateFunction

通过定义一系列方法,使得用户可以灵活地实现自定义的聚合逻辑,并将其应用到 Spark SQL 的聚合操作中。通过这种方式,用户可以扩展 Spark SQL 中的聚合能力,满足特定的业务需求。

用法

方法名描述

inputSchema

返回聚合函数的输入参数的数据类型

StructType

bufferSchema

返回聚合缓冲区中值的数据类型的

StructType

dataType

返回聚合函数的返回值数据类型

deterministic

返回布尔值,指示此函数是否是确定性的。

initialize(buffer)

初始化给定的聚合缓冲区。

update(buffer, input)

使用新的输入数据更新聚合缓冲区。

merge(buffer1, buffer2)

合并两个聚合缓冲区。

evaluate(buffer)

根据给定的聚合缓冲区计算最终结果

apply(exprs)

使用给定的

Column

参数创建一个

Column

对象来调用 UDAF。

distinct(exprs)

使用给定的不同值的

Column

参数创建一个

Column

对象来调用 UDAF。

update(i, value)

更新可变聚合缓冲区的第 i 个值。

示例

packageorg.example.sparkimportorg.apache.spark.sql.{Row, SparkSession}importorg.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}importorg.apache.spark.sql.types._
object AverageVecDemo {// 创建自定义聚合函数class MyAverage extends UserDefinedAggregateFunction {// 输入参数的数据类型def inputSchema: StructType =new StructType().add("value", DoubleType)// 聚合缓冲区中值的数据类型def bufferSchema: StructType =new StructType().add("sum", DoubleType).add("count", LongType)// 返回值的数据类型def dataType: DataType = DoubleType

    // 是否是确定性的def deterministic:Boolean=true// 初始化聚合缓冲区def initialize(buffer: MutableAggregationBuffer):Unit={
      buffer(0)=0.0// sum
      buffer(1)=0L// count}// 更新聚合缓冲区def update(buffer: MutableAggregationBuffer, input: Row):Unit={if(!input.isNullAt(0)){val value = input.getDouble(0)
        buffer(0)= buffer.getDouble(0)+ value
        buffer(1)= buffer.getLong(1)+1}}// 合并两个聚合缓冲区def merge(buffer1: MutableAggregationBuffer, buffer2: Row):Unit={
      buffer1(0)= buffer1.getDouble(0)+ buffer2.getDouble(0)
      buffer1(1)= buffer1.getLong(1)+ buffer2.getLong(1)}// 计算最终结果def evaluate(buffer: Row):Any={
      buffer.getDouble(0)/ buffer.getLong(1)}}def main(args: Array[String]):Unit={val spark = SparkSession.builder().appName("UDAFDemo").master("local[*]").getOrCreate()importspark.implicits._

    // 创建一个 DataFrameval data = Seq(1.0,2.0,3.0,4.0,5.0).toDF("value")// 注册自定义聚合函数
    spark.udf.register("myAverage",new MyAverage)// 使用自定义聚合函数进行聚合操作val result = data.selectExpr("myAverage(value) as average")

    result.show()

    spark.stop()}}//+-------+//|average|//+-------+//|    3.0|//+-------+

这个示例中,我们创建了一个自定义聚合函数

MyAverage

,用于计算输入数据列的平均值。然后,我们将该函数注册到 Spark 的 UDF(用户定义函数)中,并在 DataFrame 中使用

selectExpr

方法调用它进行聚合操作。最后,我们展示了聚合结果。

源码

importorg.apache.spark.annotation.InterfaceStability
importorg.apache.spark.sql.{Column, Row}importorg.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}importorg.apache.spark.sql.execution.aggregate.ScalaUDAF
importorg.apache.spark.sql.types._

/**
 * 实现用户自定义聚合函数(UDAF)的基类。
 *
 * @since 1.5.0
 */@InterfaceStability.Stableabstractclass UserDefinedAggregateFunction extends Serializable {/**
   * `StructType` 表示此聚合函数的输入参数的数据类型。
   * 例如,如果一个[[UserDefinedAggregateFunction]]期望两个输入参数,
   * 分别是`DoubleType`和`LongType`类型,返回的`StructType`将如下所示:
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * 此`StructType`的字段名称仅用于标识对应的输入参数。用户可以选择名称以标识输入参数。
   *
   * @since 1.5.0
   */def inputSchema: StructType

  /**
   * `StructType` 表示聚合缓冲区中值的数据类型。
   * 例如,如果一个[[UserDefinedAggregateFunction]]的缓冲区有两个值
   * (即两个中间值),分别是`DoubleType`和`LongType`类型,
   * 返回的`StructType`将如下所示:
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * 此`StructType`的字段名称仅用于标识对应的缓冲区值。用户可以选择名称以标识输入参数。
   *
   * @since 1.5.0
   */def bufferSchema: StructType

  /**
   * [[UserDefinedAggregateFunction]] 返回值的 `DataType`。
   *
   * @since 1.5.0
   */def dataType: DataType

  /**
   * 如果此函数是确定性的,则返回true,即给定相同的输入,总是返回相同的输出。
   *
   * @since 1.5.0
   */def deterministic:Boolean/**
   * 初始化给定的聚合缓冲区,即聚合缓冲区的初始值。
   *
   * 即应用于两个初始缓冲区的合并函数只应返回初始缓冲区本身,即
   * `merge(initialBuffer, initialBuffer)` 应等于 `initialBuffer`。
   *
   * @since 1.5.0
   */def initialize(buffer: MutableAggregationBuffer):Unit/**
   * 使用来自`input`的新输入数据更新给定的聚合缓冲区`buffer`。
   *
   * 每行输入调用一次此方法。
   *
   * @since 1.5.0
   */def update(buffer: MutableAggregationBuffer, input: Row):Unit/**
   * 合并两个聚合缓冲区,并将更新后的缓冲区值存储回`buffer1`。
   *
   * 当我们合并两个部分聚合的数据时,会调用此方法。
   *
   * @since 1.5.0
   */def merge(buffer1: MutableAggregationBuffer, buffer2: Row):Unit/**
   * 根据给定的聚合缓冲区计算此[[UserDefinedAggregateFunction]]的最终结果。
   *
   * @since 1.5.0
   */def evaluate(buffer: Row):Any/**
   * 使用给定的`Column`s作为输入参数创建此UDAF的`Column`。
   *
   * @since 1.5.0
   */@scala.annotation.varargsdef apply(exprs: Column*): Column ={val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr),this),
        Complete,
        isDistinct =false)
    Column(aggregateExpression)}/**
   * 使用给定的`Column`s的不同值作为输入参数创建此UDAF的`Column`。
   *
   * @since 1.5.0
   */@scala.annotation.varargsdef distinct(exprs: Column*): Column ={val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr),this),
        Complete,
        isDistinct =true)
    Column(aggregateExpression)}}/**
 * 表示可变聚合缓冲区的`Row`。
 *
 * 不建议在Spark之外扩展它。
 *
 * @since 1.5.0
 */@InterfaceStability.Stableabstractclass MutableAggregationBuffer extends Row {/** 更新此缓冲区的第i个值。 */def update(i:Int, value:Any):Unit}
gregateExpression)}}/**
 * 表示可变聚合缓冲区的`Row`。
 *
 * 不建议在Spark之外扩展它。
 *
 * @since 1.5.0
 */@InterfaceStability.Stableabstractclass MutableAggregationBuffer extends Row {/** 更新此缓冲区的第i个值。 */def update(i:Int, value:Any):Unit}

参考链接

https://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html

标签: spark 大数据 scala

本文转载自: https://blog.csdn.net/wang2leee/article/details/135136485
版权归原作者 BigDataMLApplication 所有, 如有侵权,请联系我们删除。

“【Spark原理系列】自定义聚合函数 UserDefinedAggregateFunction 原理用法示例源码分析”的评论:

还没有评论