/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DataType, LongType, StructType}

class DataFrameWindowSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("reuse window partitionBy") {
    val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
    val w = Window.partitionBy("key").orderBy("value")

    checkAnswer(
      df.select(
        lead("key", 1).over(w),
        lead("value", 1).over(w)),
      Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
  }

  test("reuse window orderBy") {
    val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
    val w = Window.orderBy("value").partitionBy("key")

    checkAnswer(
      df.select(
        lead("key", 1).over(w),
        lead("value", 1).over(w)),
      Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
  }

  test("lead") {
    val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
    df.registerTempTable("window_table")
    checkAnswer(
      df.select(
        lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))),
      Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil)
  }

  test("lag") {
    val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
    df.registerTempTable("window_table")
    checkAnswer(
      df.select(
        lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))),
      Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil)
  }

  test("lead with default value") {
    val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), (2, "2"), (1, "1"), (2, "2")).
      toDF("key", "value")
    df.registerTempTable("window_table")
    checkAnswer(
      df.select(
        lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))),
      Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a")))
  }

  test("lag with default value") {
    val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), (2, "2"), (1, "1"), (2, "2")).
      toDF("key", "value")
    df.registerTempTable("window_table")
    checkAnswer(
      df.select(
        lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))),
      Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2")))
  }

  test("rank functions in unspecific window") {
    val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value")
    df.registerTempTable("window_table")
    checkAnswer(
      df.select(
        $"key",
        max("key").over(Window.partitionBy("value").orderBy("key")),
        min("key").over(Window.partitionBy("value").orderBy("key")),
        mean("key").over(Window.partitionBy("value").orderBy("key")),
        count("key").over(Window.partitionBy("value").orderBy("key")),
        sum("key").over(Window.partitionBy("value").orderBy("key")),
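        // The remaining expressions exercise the ranking window functions
        // (ntile, row_number, dense_rank, rank, cume_dist, percent_rank) over the same spec.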
        ntile(2).over(Window.partitionBy("value").orderBy("key")),
        row_number().over(Window.partitionBy("value").orderBy("key")),
        dense_rank().over(Window.partitionBy("value").orderBy("key")),
        rank().over(Window.partitionBy("value").orderBy("key")),
        cume_dist().over(Window.partitionBy("value").orderBy("key")),
        percent_rank().over(Window.partitionBy("value").orderBy("key"))),
      Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) ::
        Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) ::
        Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) ::
        Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil)
  }

  test("aggregation and rows between") {
    val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
    df.registerTempTable("window_table")
    checkAnswer(
      df.select(
        avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))),
      Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d)))
  }

  test("aggregation and range between") {
    val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value")
    df.registerTempTable("window_table")
    checkAnswer(
      df.select(
        avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))),
      Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d),
        Row(2.0d), Row(2.0d)))
  }

  test("aggregation and rows between with unbounded") {
    val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")
    df.registerTempTable("window_table")
    checkAnswer(
      df.select(
        $"key",
        last("key").over(
          Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)),
        last("key").over(
          Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)),
        last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))),
      Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4),
        Row(4, 4, 4, 4)))
  }

  test("aggregation and range between with unbounded") {
    val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value")
    df.registerTempTable("window_table")
    checkAnswer(
      df.select(
        $"key",
        last("value").over(
          Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1))
          .equalTo("2")
          .as("last_v"),
        avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1))
          .as("avg_key1"),
        avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue))
          .as("avg_key2"),
        avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0))
          .as("avg_key3")),
      Seq(Row(3, null, 3.0d, 4.0d, 3.0d),
        Row(5, false, 4.0d, 5.0d, 5.0d),
        Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d),
        Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d),
        Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d),
        Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d)))
  }

  test("reverse sliding range frame") {
    val df = Seq(
      (1, "Thin", "Cell Phone", 6000),
      (2, "Normal", "Tablet", 1500),
      (3, "Mini", "Tablet", 5500),
      (4, "Ultra thin", "Cell Phone", 5500),
      (5, "Very thin", "Cell Phone", 6000),
      (6, "Big", "Tablet", 2500),
      (7, "Bendable", "Cell Phone", 3000),
      (8, "Foldable", "Cell Phone", 3000),
      (9, "Pro", "Tablet", 4500),
      (10, "Pro2", "Tablet", 6500)).
      toDF("id", "product", "category", "revenue")
    val window = Window.
      partitionBy($"category").
      orderBy($"revenue".desc).
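      // With a descending ordering the frame offsets run against the reversed order:
      // -2000 (preceding) reaches toward larger revenues, 1000 (following) toward smaller ones.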
      rangeBetween(-2000L, 1000L)
    checkAnswer(
      df.select(
        $"id",
        avg($"revenue").over(window).cast("int")),
      Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) ::
        Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) ::
        Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) ::
        Row(10, 6000) :: Nil)
  }

  // This is here to illustrate the fact that reverse order also reverses offsets.
  test("reverse unbounded range frame") {
    val df = Seq(1, 2, 4, 3, 2, 1).
      map(Tuple1.apply).
      toDF("value")
    val window = Window.orderBy($"value".desc)
    checkAnswer(
      df.select(
        $"value",
        sum($"value").over(window.rangeBetween(Long.MinValue, 1)),
        sum($"value").over(window.rangeBetween(1, Long.MaxValue))),
      Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) ::
        Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil)
  }

  test("statistical functions") {
    val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)).
      toDF("key", "value")
    val window = Window.partitionBy($"key")
    checkAnswer(
      df.select(
        $"key",
        var_pop($"value").over(window),
        var_samp($"value").over(window),
        approxCountDistinct($"value").over(window)),
      Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2)) ++
        Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3)))
  }

  test("window function with aggregates") {
    val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)).
      toDF("key", "value")
    val window = Window.orderBy()
    checkAnswer(
      df.groupBy($"key")
        .agg(
          sum($"value"),
          sum(sum($"value")).over(window) - sum($"value")),
      Seq(Row("a", 6, 9), Row("b", 9, 6)))
  }

  test("window function with udaf") {
    val udaf = new UserDefinedAggregateFunction {
      def inputSchema: StructType =
        new StructType()
          .add("a", LongType)
          .add("b", LongType)

      def bufferSchema: StructType =
        new StructType()
          .add("product", LongType)

      def dataType: DataType = LongType

      def deterministic: Boolean = true

      def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0L
      }

      def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (!(input.isNullAt(0) || input.isNullAt(1))) {
          buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1)
        }
      }

      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
      }

      def evaluate(buffer: Row): Any =
        buffer.getLong(0)
    }
    val df = Seq(
      ("a", 1, 1),
      ("a", 1, 5),
      ("a", 2, 10),
      ("a", 2, -1),
      ("b", 4, 7),
      ("b", 3, 8),
      ("b", 2, 4))
      .toDF("key", "a", "b")
    val window = Window.partitionBy($"key").orderBy($"a").rangeBetween(Long.MinValue, 0L)
    checkAnswer(
      df.select(
        $"key",
        $"a",
        $"b",
        udaf($"a", $"b").over(window)),
      Seq(
        Row("a", 1, 1, 6),
        Row("a", 1, 5, 6),
        Row("a", 2, 10, 24),
        Row("a", 2, -1, 24),
        Row("b", 4, 7, 60),
        Row("b", 3, 8, 32),
        Row("b", 2, 4, 8)))
  }

  test("null inputs") {
    val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2))
      .toDF("key", "value")
    val window = Window.orderBy()
    checkAnswer(
      df.select(
        $"key",
        $"value",
        avg(lit(null)).over(window),
        sum(lit(null)).over(window)),
      Seq(
        Row("a", 1, null, null),
        Row("a", 1, null, null),
        Row("a", 2, null, null),
        Row("a", 2, null, null),
        Row("b", 4, null, null),
        Row("b", 3, null, null),
        Row("b", 2, null, null)))
  }

  test("last/first with ignoreNulls") {
    val nullStr: String = null
    val df = Seq(
      ("a", 0, nullStr),
      ("a", 1, "x"),
      ("a", 2, "y"),
      ("a", 3, "z"),
      ("a", 4, nullStr),
      ("b", 1, nullStr),
      ("b", 2, nullStr)).
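      // Partition "b" contains only null values, so ignoreNulls = true still yields null there.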
toDF("key", "order", "value") val window = Window.partitionBy($"key").orderBy($"order") checkAnswer( df.select( $"key", $"order", first($"value").over(window), first($"value", ignoreNulls = false).over(window), first($"value", ignoreNulls = true).over(window), last($"value").over(window), last($"value", ignoreNulls = false).over(window), last($"value", ignoreNulls = true).over(window)), Seq( Row("a", 0, null, null, null, null, null, null), Row("a", 1, null, null, "x", "x", "x", "x"), Row("a", 2, null, null, "x", "y", "y", "y"), Row("a", 3, null, null, "x", "z", "z", "z"), Row("a", 4, null, null, "x", null, null, "z"), Row("b", 1, null, null, null, null, null, null), Row("b", 2, null, null, null, null, null, null))) } test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { val src = Seq((0, 3, 5)).toDF("a", "b", "c") .withColumn("Data", struct("a", "b")) .drop("a") .drop("b") val winSpec = Window.partitionBy("Data.a", "Data.b").orderBy($"c".desc) val df = src.select($"*", max("c").over(winSpec) as "max") checkAnswer(df, Row(5, Row(0, 3), 5)) } }