一、Distinct aggregation 算法
包含 distinct 關鍵字的 aggregation 由 4 個物理執行步驟組成。我們使用以下 query 來介紹:
val dataset = Seq(
(1, "a"), (1, "a"), (1, "a"), (2, "b"), (2, "b"), (3, "c"), (3, "c")
).toDF("nr", "letter")
dataset.groupBy($"nr").agg(functions.countDistinct("letter")).explain(true)
① partial aggregation 步驟
第一步是創建一個 partial aggregate,此 partial aggregate 的 grouping key 將不僅包括 query 中定義的 grouping key(nr),還包含 distinct 的列(letter),效果如 group by nr、letter
,執行計劃如下:
HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
② partial merge aggregation 步驟
這一步將通過 shuffle 將具有相同 grouping key(此處為 nr、letter)的數據劃分為同一分區:
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- Exchange hashpartitioning(nr#5, letter#6, 200)
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
③ partial aggregation for distinct 步驟
第三步,Spark 最終開始執行聚合,執行的是 partial aggregate:
+- HashAggregate(keys=[nr#5], functions=[partial_count(distinct letter#6)], output=[nr#5, count#18L])
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- Exchange hashpartitioning(nr#5, letter#6, 200)
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
④ final aggregation 步驟
第四步,partial aggregate(第三步)的結果將合并到最終結果中,并進行返回。它涉及 shuffle:
HashAggregate(keys=[nr#5], functions=[count(distinct letter#6)], output=[nr#5, count(DISTINCT letter)#12L])
+- Exchange hashpartitioning(nr#5, 200)
+- HashAggregate(keys=[nr#5], functions=[partial_count(distinct letter#6)], output=[nr#5, count#18L])
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- Exchange hashpartitioning(nr#5, letter#6, 200)
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
我們用下面的這張圖來總結上述幾個步驟:
二、無 Distinct aggregation 算法
無 Distinct aggregation 會簡單一些,僅包含兩個步驟,我們通過下面的例子來說明:
val dataset = Seq(
(1, "a"), (1, "a"), (1, "a"), (2, "b"), (2, "b"), (3, "c"), (3, "c")
).toDF("nr", "letter")
dataset.groupBy($"nr").count().explain(true)
①、partial aggregations 步驟
第一步即進行局部聚合:
HashAggregate(keys=[nr#5], functions=[partial_count(1)], output=[nr#5, count#17L])
+- PlanLater LocalRelation [nr#5]
②、final aggregation 步驟
第二步,毫無疑問,對部分結果進行了最終匯總:
HashAggregate(keys=[nr#5], functions=[count(1)], output=[nr#5, count#12L])
+- HashAggregate(keys=[nr#5], functions=[partial_count(1)], output=[nr#5, count#17L])
+- PlanLater LocalRelation [nr#5]
三、Hash-based 和 Sort-based aggregation
上述兩種模式都會調用到 createAggregate
方法,該方法為以下 3 種策略創建物理執行計劃:
- hash-based
- object-hash-based
- sort-based
這 3 中策略有一些共性。一個 Spark Sql aggregation 主要由兩部分組成:
- 一個 agg buffer(聚合緩沖區:包含 grouping keys 和 agg value)
- 一個 agg state(聚合狀態:僅 agg value)
每次調用 GROUP BY key
并對其使用一些聚合時,框架都會創建一個聚合緩沖區,保留給定的聚合(GROUP BY key)。指定 key(COUNT,SUM等)所涉及的聚合都在此聚合緩沖區存儲其部分(partial)或最終聚合結果,稱為聚合狀態。該狀態的存儲格式取決于聚合:
- 對于 AVG,它將是2個值,一個是出現次數,另一個是值的總和
- 對于 MIN,它將是到目前為止所看到的最小值
依此類推
hash-based
策略使用可變的、原始的、固定 size 的類型來作為 agg state,包括:
- NullType
- BooleanType
- ByteType
- ShortType
- IntegerType
- LongType
- FloatType
- DoubleType
- DateType
- TimestampType
這里的可變能力非常重要,因為 Spark 會直接修改該值(如對于 count 來說,遇到新的 row,就會把 count 的值(agg state)加上 1)。
對于 agg state 的值是其他類型的情況,使用 object-hash-based
策略,該策略自 2.2.0 版本引入,目的是為了解決 hash-based
策略的局限性(必須使用可變的、原始的、固定 size 的類型來作為 agg state)。在 2.2.0 之前,針對 HashAggregateExec 不支持的其他類型執行的聚合都會轉換為 sort-based
的策略。大部分情況下,sort-based
的性能會比 hash-based
的差,因為在聚合前會進行額外的排序。通過參數 spark.sql.execution.useObjectHashAggregateExec
來控制是否使用 object-hash-based
聚合,默認為 true。我們通過下面的例子來理解 sort-based
和 object-hash-based
的區別:
查詢
val dataset2 = Seq(
(1, "a"), (1, "aa"), (1, "a"), (2, "b"), (2, "b"), (3, "c"), (3, "c")
).toDF("nr", "letter")
dataset2.groupBy("nr").agg(functions.collect_list("letter").as("collected_letters")).explain(true)
如你所見,上圖兩個物理執行計劃均只進行一次 shuffle,但 sort-based
聚合相對于 object-hash-based
額外多了兩次排序,帶來性能開銷。
另一個值得關注的點是,hash-based
和 object-hash-based
運行過程中如果內存不夠用,會切換成 sort-based
聚合。對于 object-hash-based
聚合,通過參數 spark.sql.objectHashAggregate.sortBased.fallbackThreshold
控內存中(一種 hashMap)最多持有多少個 agg buffer(一個 grouping key 的組合一個),若超過該值,則切換為 sort-based
agg,該配置默認值為 128。如果切換為 sort-based
agg,會打印如下日志:
ObjectAggregationIterator: Aggregation hash map reaches threshold capacity (128 entries), spilling and falling back to sort based aggregation. You may change the threshold by adjust option spark.sql.objectHashAggregate.sortBased.fallbackThreshold
對于 hash-based
,該值為 Integer.MaxValue