github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/window_iter_test.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggregation 16 17 import ( 18 "context" 19 "testing" 20 21 "github.com/stretchr/testify/require" 22 23 "github.com/dolthub/go-mysql-server/memory" 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression" 26 "github.com/dolthub/go-mysql-server/sql/types" 27 ) 28 29 var ( 30 partitionByX = []sql.Expression{ 31 expression.NewGetFieldWithTable(1, 0, types.Text, "mydb", "a", "x", false), 32 } 33 sortByW = sql.SortFields{{ 34 Column: expression.NewGetFieldWithTable(0, 0, types.Int64, "mydb", "a", "w", false), 35 }} 36 sortByWDesc = sql.SortFields{ 37 { 38 Column: expression.NewGetFieldWithTable(0, 0, types.Int64, "mydb", "a", "w", false), 39 Order: sql.Descending, 40 }, 41 } 42 lastX = NewLastAgg(expression.NewGetField(1, types.Text, "x", true)) 43 firstY = NewFirstAgg(expression.NewGetField(2, types.Text, "y", true)) 44 sumZ = NewSumAgg(expression.NewGetField(3, types.Int64, "z", true)) 45 ) 46 47 func TestWindowIter(t *testing.T) { 48 tests := []struct { 49 Name string 50 PartitionIters []*WindowPartitionIter 51 OutputOrdinals [][]int 52 Expected []sql.Row 53 }{ 54 { 55 Name: "unbounded preceding to current row", 56 PartitionIters: []*WindowPartitionIter{ 57 NewWindowPartitionIter( 58 &WindowPartition{ 59 PartitionBy: partitionByX, 60 SortBy: sortByW, 61 Aggs: []*Aggregation{ 62 NewAggregation(lastX, NewUnboundedPrecedingToCurrentRowFramer()), 63 NewAggregation(sumZ, NewUnboundedPrecedingToCurrentRowFramer()), 64 }, 65 }), 66 NewWindowPartitionIter( 67 &WindowPartition{ 68 PartitionBy: partitionByX, 69 SortBy: sortByWDesc, 70 Aggs: []*Aggregation{ 71 NewAggregation(firstY, NewUnboundedPrecedingToCurrentRowFramer()), 72 }, 73 }), 74 }, 75 OutputOrdinals: [][]int{{0, 1}, {2}}, 76 Expected: []sql.Row{ 77 {"forest", float64(4), "wildflower"}, 78 {"forest", float64(8), "wildflower"}, 79 {"forest", float64(14), "wildflower"}, 80 {"forest", float64(17), "wildflower"}, 81 {"forest", float64(27), "wildflower"}, 82 {"desert", float64(4), "mummy"}, 83 {"desert", float64(10), "mummy"}, 84 {"desert", float64(18), "mummy"}, 85 {"desert", float64(23), "mummy"}, 86 }, 87 }, 88 } 89 90 for _, tt := range tests { 91 t.Run(tt.Name, func(t *testing.T) { 92 db := memory.NewDatabase("test") 93 pro := memory.NewDBProvider(db) 94 ctx := sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), pro))) 95 96 i := NewWindowIter(tt.PartitionIters, tt.OutputOrdinals, mustNewRowIter(t, db, ctx)) 97 res, err := sql.RowIterToRows(ctx, i) 98 require.NoError(t, err) 99 require.Equal(t, tt.Expected, res) 100 }) 101 } 102 }