github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/window_test.go (about) 1 // Copyright 2021-2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "testing" 19 20 "github.com/stretchr/testify/require" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/expression" 24 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" 25 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation/window" 26 "github.com/dolthub/go-mysql-server/sql/plan" 27 "github.com/dolthub/go-mysql-server/sql/types" 28 ) 29 30 func TestWindowPlanToIter(t *testing.T) { 31 n1 := window.NewRowNumber().(sql.WindowAggregation).WithWindow( 32 &sql.WindowDefinition{ 33 PartitionBy: []sql.Expression{ 34 expression.NewGetField(2, types.Int64, "c", false)}, 35 OrderBy: nil, 36 }) 37 n2 := aggregation.NewMax( 38 expression.NewGetField(0, types.Int64, "a", false), 39 ).WithWindow( 40 &sql.WindowDefinition{ 41 PartitionBy: []sql.Expression{ 42 expression.NewGetField(1, types.Int64, "b", false)}, 43 OrderBy: nil, 44 }) 45 n3 := expression.NewGetField(0, types.Int64, "a", false) 46 n4 := aggregation.NewMin( 47 expression.NewGetField(0, types.Int64, "a", false), 48 ).WithWindow( 49 &sql.WindowDefinition{ 50 PartitionBy: []sql.Expression{ 51 expression.NewGetField(1, types.Int64, "b", false)}, 52 OrderBy: nil, 53 }) 54 55 fn1, err := n1.NewWindowFunction() 56 require.NoError(t, err) 57 fn2, err := n2.NewWindowFunction() 58 require.NoError(t, err) 59 fn3, err := aggregation.NewLast(n3).NewWindowFunction() 60 fn4, err := n4.NewWindowFunction() 61 require.NoError(t, err) 62 63 agg1 := aggregation.NewAggregation(fn1, fn1.DefaultFramer()) 64 agg2 := aggregation.NewAggregation(fn2, fn2.DefaultFramer()) 65 agg3 := aggregation.NewAggregation(fn3, fn3.DefaultFramer()) 66 agg4 := aggregation.NewAggregation(fn4, fn4.DefaultFramer()) 67 68 window := plan.NewWindow([]sql.Expression{n1, n2, n3, n4}, nil) 69 outputIters, outputOrdinals, err := windowToIter(window) 70 require.NoError(t, err) 71 72 require.Equal(t, len(outputIters), 3) 73 require.Equal(t, len(outputOrdinals), 3) 74 accOrdinals := make([]int, 0) 75 for _, p := range outputOrdinals { 76 accOrdinals = append(accOrdinals, p...) 77 } 78 require.ElementsMatch(t, accOrdinals, []int{0, 1, 2, 3}) 79 80 // check aggs 81 allOutputAggs := make([]*aggregation.Aggregation, 0) 82 for _, i := range outputIters { 83 allOutputAggs = append(allOutputAggs, i.WindowBlock().Aggs...) 84 } 85 require.ElementsMatch(t, allOutputAggs, []*aggregation.Aggregation{agg1, agg2, agg3, agg4}) 86 87 // check partitionBy 88 allPartitionBy := make([][]sql.Expression, 0) 89 for _, i := range outputIters { 90 allPartitionBy = append(allPartitionBy, i.WindowBlock().PartitionBy) 91 } 92 require.ElementsMatch(t, allPartitionBy, [][]sql.Expression{ 93 nil, 94 { 95 expression.NewGetField(1, types.Int64, "b", false), 96 }, { 97 expression.NewGetField(2, types.Int64, "c", false), 98 }}) 99 }