Poison

MonotonicallyIncreasingID

最近查阅 Spark SQL 源码时看到了很久之前用过的获取单调递增 id 方法的实现,本文简要记录。之前在离线场景下有给记录生成唯一 id 的需求,当时使用了 Spark SQL 中的 monotonically_increasing_id 方法,其源码位于 functions.scala at v2.4.4:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/**
* A column expression that generates monotonically increasing 64-bit integers.
*
* The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive.
* The current implementation puts the partition ID in the upper 31 bits, and the record number
* within each partition in the lower 33 bits. The assumption is that the data frame has
* less than 1 billion partitions, and each partition has less than 8 billion records.
*
* As an example, consider a `DataFrame` with two partitions, each with 3 records.
* This expression would return the following IDs:
*
* {{{
* 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594.
* }}}
*
* @group normal_funcs
* @since 1.6.0
*/
def monotonically_increasing_id(): Column = withExpr { MonotonicallyIncreasingID() }

追踪源码至 MonotonicallyIncreasingID.scala at v2.4.4:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
/**
* Returns monotonically increasing 64-bit integers.
*
* The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive.
* The current implementation puts the partition ID in the upper 31 bits, and the lower 33 bits
* represent the record number within each partition. The assumption is that the data frame has
* less than 1 billion partitions, and each partition has less than 8 billion records.
*
* Since this expression is stateful, it cannot be a case object.
*/
@ExpressionDescription(
usage = """
_FUNC_() - Returns monotonically increasing 64-bit integers. The generated ID is guaranteed
to be monotonically increasing and unique, but not consecutive. The current implementation
puts the partition ID in the upper 31 bits, and the lower 33 bits represent the record number
within each partition. The assumption is that the data frame has less than 1 billion
partitions, and each partition has less than 8 billion records.
The function is non-deterministic because its result depends on partition IDs.
""")
case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {

/**
* Record ID within each partition. By being transient, count's value is reset to 0 every time
* we serialize and deserialize and initialize it.
*/
@transient private[this] var count: Long = _

@transient private[this] var partitionMask: Long = _

override protected def initializeInternal(partitionIndex: Int): Unit = {
count = 0L
partitionMask = partitionIndex.toLong << 33
}

override def nullable: Boolean = false

override def dataType: DataType = LongType

override protected def evalInternal(input: InternalRow): Long = {
val currentCount = count
count += 1
partitionMask + currentCount
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
val partitionMaskTerm = "partitionMask"
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm)
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = FalseLiteral)
}

override def prettyName: String = "monotonically_increasing_id"

override def sql: String = s"$prettyName()"

override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID()
}

该算法实现其实非常简单,生成的 id 占用 8 个字节,即 64 位,分区 id 占用高 31 位,记录 id 占用低 33 位。使用分区 id 左移 33 位然后再加上当前分区该记录的 id 即可得到该 64 位 id 值。该实现与 Twitter 的 Snowflake 算法很相似,都处于分布式的环境中,均未采用一个单独的发号器来实现。

使用该方法时遇到的一个问题就是该算法的起始值为 0,在之前的业务场景中,我们将离线数据生成了 id 并写入至 Hive 数据表后,同时使用 INSERT INTO 往 MySQL 写入了一份。为了保证数据一致,在 INSERT INTO 至 MySQL 时,使用的该算法生成的 id,最后发现竟然有记录不一致,经过一阵排查,原来是 id 为 0 的数据导致。在 MySQL 中,如果往自增主键写入值 0,此时会被当作写入 NULL 处理,MySQL 会为该字段获取自增值,导致这条 id 为 0 的记录看起来消失了,而多出了一条记录。之前为了处理这个问题,我在代码里将 monotonically_increasing_id 多加了 1 以规避该情况,业务代码如下:

1
2
3
4
5
val bizDF = spark.sql(s"SELECT ...")
.withColumn("id", monotonically_increasing_id)
// The starting value of monotonically_increasing_id is 0, and MySQL treats the 0 value as NULL to generate self-incrementing column values, resulting in data different from Hive
// https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html#sqlmode_no_auto_value_on_zero
.withColumn("id", col("id") + lit(1))

当然也可以通过更改 MySQL 的配置实现,配置选项可以参考:MySQL :: MySQL 5.7 Reference Manual :: 5.1.10 Server SQL Modes,但是在文档中已经提到了不建议在自增列中存储值 0,原文如下:

Storing 0 is not a recommended practice, by the way.

Reference

GitHub - twitter-archive/snowflake at snowflake-2010