github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/window_test.go (about)

     1  // Copyright 2021-2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/expression"
    24  	"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
    25  	"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation/window"
    26  	"github.com/dolthub/go-mysql-server/sql/plan"
    27  	"github.com/dolthub/go-mysql-server/sql/types"
    28  )
    29  
    30  func TestWindowPlanToIter(t *testing.T) {
    31  	n1 := window.NewRowNumber().(sql.WindowAggregation).WithWindow(
    32  		&sql.WindowDefinition{
    33  			PartitionBy: []sql.Expression{
    34  				expression.NewGetField(2, types.Int64, "c", false)},
    35  			OrderBy: nil,
    36  		})
    37  	n2 := aggregation.NewMax(
    38  		expression.NewGetField(0, types.Int64, "a", false),
    39  	).WithWindow(
    40  		&sql.WindowDefinition{
    41  			PartitionBy: []sql.Expression{
    42  				expression.NewGetField(1, types.Int64, "b", false)},
    43  			OrderBy: nil,
    44  		})
    45  	n3 := expression.NewGetField(0, types.Int64, "a", false)
    46  	n4 := aggregation.NewMin(
    47  		expression.NewGetField(0, types.Int64, "a", false),
    48  	).WithWindow(
    49  		&sql.WindowDefinition{
    50  			PartitionBy: []sql.Expression{
    51  				expression.NewGetField(1, types.Int64, "b", false)},
    52  			OrderBy: nil,
    53  		})
    54  
    55  	fn1, err := n1.NewWindowFunction()
    56  	require.NoError(t, err)
    57  	fn2, err := n2.NewWindowFunction()
    58  	require.NoError(t, err)
    59  	fn3, err := aggregation.NewLast(n3).NewWindowFunction()
    60  	fn4, err := n4.NewWindowFunction()
    61  	require.NoError(t, err)
    62  
    63  	agg1 := aggregation.NewAggregation(fn1, fn1.DefaultFramer())
    64  	agg2 := aggregation.NewAggregation(fn2, fn2.DefaultFramer())
    65  	agg3 := aggregation.NewAggregation(fn3, fn3.DefaultFramer())
    66  	agg4 := aggregation.NewAggregation(fn4, fn4.DefaultFramer())
    67  
    68  	window := plan.NewWindow([]sql.Expression{n1, n2, n3, n4}, nil)
    69  	outputIters, outputOrdinals, err := windowToIter(window)
    70  	require.NoError(t, err)
    71  
    72  	require.Equal(t, len(outputIters), 3)
    73  	require.Equal(t, len(outputOrdinals), 3)
    74  	accOrdinals := make([]int, 0)
    75  	for _, p := range outputOrdinals {
    76  		accOrdinals = append(accOrdinals, p...)
    77  	}
    78  	require.ElementsMatch(t, accOrdinals, []int{0, 1, 2, 3})
    79  
    80  	// check aggs
    81  	allOutputAggs := make([]*aggregation.Aggregation, 0)
    82  	for _, i := range outputIters {
    83  		allOutputAggs = append(allOutputAggs, i.WindowBlock().Aggs...)
    84  	}
    85  	require.ElementsMatch(t, allOutputAggs, []*aggregation.Aggregation{agg1, agg2, agg3, agg4})
    86  
    87  	// check partitionBy
    88  	allPartitionBy := make([][]sql.Expression, 0)
    89  	for _, i := range outputIters {
    90  		allPartitionBy = append(allPartitionBy, i.WindowBlock().PartitionBy)
    91  	}
    92  	require.ElementsMatch(t, allPartitionBy, [][]sql.Expression{
    93  		nil,
    94  		{
    95  			expression.NewGetField(1, types.Int64, "b", false),
    96  		}, {
    97  			expression.NewGetField(2, types.Int64, "c", false),
    98  		}})
    99  }