github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/window_partition_test.go (about)

     1  // Copyright 2022 DoltHub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aggregation
    16  
    17  import (
    18  	"context"
    19  	"io"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/require"
    23  
    24  	"github.com/dolthub/go-mysql-server/memory"
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"github.com/dolthub/go-mysql-server/sql/expression"
    27  	"github.com/dolthub/go-mysql-server/sql/types"
    28  )
    29  
    30  func TestWindowPartitionIter(t *testing.T) {
    31  	tests := []struct {
    32  		Name     string
    33  		Iter     *WindowPartitionIter
    34  		Expected []sql.Row
    35  	}{
    36  		{
    37  			Name: "group by",
    38  			Iter: NewWindowPartitionIter(
    39  				&WindowPartition{
    40  					PartitionBy: partitionByX,
    41  					SortBy:      sortByW,
    42  					Aggs: []*Aggregation{
    43  						NewAggregation(lastX, NewGroupByFramer()),
    44  						NewAggregation(firstY, NewGroupByFramer()),
    45  						NewAggregation(sumZ, NewGroupByFramer()),
    46  					},
    47  				},
    48  			),
    49  			Expected: []sql.Row{
    50  				{"forest", "leaf", float64(27)},
    51  				{"desert", "sand", float64(23)},
    52  			},
    53  		},
    54  		{
    55  			Name: "partition level desc",
    56  			Iter: NewWindowPartitionIter(
    57  				&WindowPartition{
    58  					PartitionBy: partitionByX,
    59  					SortBy:      sortByWDesc,
    60  					Aggs: []*Aggregation{
    61  						NewAggregation(lastX, NewPartitionFramer()),
    62  						NewAggregation(firstY, NewPartitionFramer()),
    63  						NewAggregation(sumZ, NewPartitionFramer()),
    64  					},
    65  				}),
    66  			Expected: []sql.Row{
    67  				{"forest", "wildflower", float64(27)},
    68  				{"forest", "wildflower", float64(27)},
    69  				{"forest", "wildflower", float64(27)},
    70  				{"forest", "wildflower", float64(27)},
    71  				{"forest", "wildflower", float64(27)},
    72  				{"desert", "mummy", float64(23)},
    73  				{"desert", "mummy", float64(23)},
    74  				{"desert", "mummy", float64(23)},
    75  				{"desert", "mummy", float64(23)},
    76  			},
    77  		},
    78  	}
    79  
    80  	for _, tt := range tests {
    81  		t.Run(tt.Name, func(t *testing.T) {
    82  			db := memory.NewDatabase("test")
    83  			pro := memory.NewDBProvider(db)
    84  			ctx := sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), pro)))
    85  
    86  			tt.Iter.child = mustNewRowIter(t, db, ctx)
    87  			res, err := sql.RowIterToRows(ctx, tt.Iter)
    88  			require.NoError(t, err)
    89  			require.Equal(t, tt.Expected, res)
    90  		})
    91  	}
    92  }
    93  
    94  func mustNewRowIter(t *testing.T, db *memory.Database, ctx *sql.Context) sql.RowIter {
    95  	childSchema := sql.NewPrimaryKeySchema(sql.Schema{
    96  		{Name: "w", Type: types.Int64, Nullable: true},
    97  		{Name: "x", Type: types.Text, Nullable: true},
    98  		{Name: "y", Type: types.Text, Nullable: true},
    99  		{Name: "z", Type: types.Int32, Nullable: true},
   100  	})
   101  	table := memory.NewTable(db, "test", childSchema, nil)
   102  
   103  	rows := []sql.Row{
   104  		{int64(1), "forest", "leaf", int32(4)},
   105  		{int64(2), "forest", "bark", int32(4)},
   106  		{int64(3), "forest", "canopy", int32(6)},
   107  		{int64(4), "forest", "bug", int32(3)},
   108  		{int64(5), "forest", "wildflower", int32(10)},
   109  		{int64(6), "desert", "sand", int32(4)},
   110  		{int64(7), "desert", "cactus", int32(6)},
   111  		{int64(8), "desert", "scorpion", int32(8)},
   112  		{int64(9), "desert", "mummy", int32(5)},
   113  	}
   114  
   115  	for _, r := range rows {
   116  		require.NoError(t, table.Insert(ctx, r))
   117  	}
   118  
   119  	pIter, err := table.Partitions(ctx)
   120  	require.NoError(t, err)
   121  
   122  	p, err := pIter.Next(ctx)
   123  	require.NoError(t, err)
   124  
   125  	child, err := table.PartitionRows(ctx, p)
   126  	require.NoError(t, err)
   127  	return child
   128  }
   129  
   130  func TestWindowPartition_MaterializeInput(t *testing.T) {
   131  	db := memory.NewDatabase("test")
   132  	pro := memory.NewDBProvider(db)
   133  	ctx := sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), pro)))
   134  
   135  	i := WindowPartitionIter{
   136  		w: &WindowPartition{
   137  			PartitionBy: partitionByX,
   138  			SortBy:      sortByW,
   139  		},
   140  		child: mustNewRowIter(t, db, ctx),
   141  	}
   142  
   143  	buf, ordering, err := i.materializeInput(ctx)
   144  	require.NoError(t, err)
   145  	expBuf := []sql.Row{
   146  		{int64(1), "forest", "leaf", int32(4)},
   147  		{int64(2), "forest", "bark", int32(4)},
   148  		{int64(3), "forest", "canopy", int32(6)},
   149  		{int64(4), "forest", "bug", int32(3)},
   150  		{int64(5), "forest", "wildflower", int32(10)},
   151  		{int64(6), "desert", "sand", int32(4)},
   152  		{int64(7), "desert", "cactus", int32(6)},
   153  		{int64(8), "desert", "scorpion", int32(8)},
   154  		{int64(9), "desert", "mummy", int32(5)},
   155  	}
   156  	require.ElementsMatch(t, expBuf, buf)
   157  	expOrd := []int{0, 1, 2, 3, 4, 5, 6, 7, 8}
   158  	require.ElementsMatch(t, expOrd, ordering)
   159  }
   160  
   161  func TestWindowPartition_InitializePartitions(t *testing.T) {
   162  	ctx := sql.NewEmptyContext()
   163  	i := NewWindowPartitionIter(
   164  		&WindowPartition{
   165  			PartitionBy: partitionByX,
   166  		})
   167  	i.input = []sql.Row{
   168  		{int64(1), "forest", "leaf", int32(4)},
   169  		{int64(2), "forest", "bark", int32(4)},
   170  		{int64(3), "forest", "canopy", int32(6)},
   171  		{int64(4), "forest", "bug", int32(3)},
   172  		{int64(5), "forest", "wildflower", int32(10)},
   173  		{int64(6), "desert", "sand", int32(4)},
   174  		{int64(7), "desert", "cactus", int32(6)},
   175  		{int64(8), "desert", "scorpion", int32(8)},
   176  		{int64(9), "desert", "mummy", int32(5)},
   177  	}
   178  	partitions, err := i.initializePartitions(ctx)
   179  	require.NoError(t, err)
   180  	expPartitions := []sql.WindowInterval{
   181  		{Start: 0, End: 5},
   182  		{Start: 5, End: 9},
   183  	}
   184  	require.ElementsMatch(t, expPartitions, partitions)
   185  }
   186  
   187  func TestWindowPartition_MaterializeOutput(t *testing.T) {
   188  	t.Run("non nil input", func(t *testing.T) {
   189  		ctx := sql.NewEmptyContext()
   190  		i := NewWindowPartitionIter(
   191  			&WindowPartition{
   192  				PartitionBy: partitionByX,
   193  				Aggs: []*Aggregation{
   194  					NewAggregation(sumZ, NewGroupByFramer()),
   195  				},
   196  			})
   197  		i.input = []sql.Row{
   198  			{int64(1), "forest", "leaf", 4},
   199  			{int64(2), "forest", "bark", 4},
   200  			{int64(3), "forest", "canopy", 6},
   201  			{int64(4), "forest", "bug", 3},
   202  			{int64(5), "forest", "wildflower", 10},
   203  			{int64(6), "desert", "sand", 4},
   204  			{int64(7), "desert", "cactus", 6},
   205  			{int64(8), "desert", "scorpion", 8},
   206  			{int64(9), "desert", "mummy", 5},
   207  		}
   208  		i.partitions = []sql.WindowInterval{{0, 5}, {5, 9}}
   209  		i.outputOrdering = []int{0, 1, 2, 3, 4, 5, 6, 7, 8}
   210  		output, err := i.materializeOutput(ctx)
   211  		require.NoError(t, err)
   212  		expOutput := []sql.Row{
   213  			{float64(27), 0},
   214  			{float64(23), 5},
   215  		}
   216  		require.ElementsMatch(t, expOutput, output)
   217  	})
   218  
   219  	t.Run("nil input with partition by", func(t *testing.T) {
   220  		ctx := sql.NewEmptyContext()
   221  		i := NewWindowPartitionIter(
   222  			&WindowPartition{
   223  				PartitionBy: partitionByX,
   224  				Aggs: []*Aggregation{
   225  					NewAggregation(sumZ, NewGroupByFramer()),
   226  				},
   227  			})
   228  		i.input = []sql.Row{}
   229  		i.partitions = []sql.WindowInterval{{0, 0}}
   230  		i.outputOrdering = []int{}
   231  		output, err := i.materializeOutput(ctx)
   232  		require.Equal(t, io.EOF, err)
   233  		require.ElementsMatch(t, nil, output)
   234  	})
   235  
   236  	t.Run("nil input no partition by", func(t *testing.T) {
   237  		ctx := sql.NewEmptyContext()
   238  		i := NewWindowPartitionIter(
   239  			&WindowPartition{
   240  				PartitionBy: nil,
   241  				Aggs: []*Aggregation{
   242  					NewAggregation(NewCountAgg(expression.NewGetField(0, types.Int64, "z", true)), NewGroupByFramer()),
   243  				},
   244  			})
   245  		i.input = []sql.Row{}
   246  		i.partitions = []sql.WindowInterval{{0, 0}}
   247  		i.outputOrdering = nil
   248  		output, err := i.materializeOutput(ctx)
   249  		require.NoError(t, err)
   250  		expOutput := []sql.Row{{int64(0), nil}}
   251  		require.ElementsMatch(t, expOutput, output)
   252  	})
   253  }
   254  
   255  func TestWindowPartition_SortAndFilterOutput(t *testing.T) {
   256  	tests := []struct {
   257  		Name     string
   258  		Output   []sql.Row
   259  		Expected []sql.Row
   260  	}{
   261  		{
   262  			Name: "no input rows filtered before output, contiguous output indexes",
   263  			Output: []sql.Row{
   264  				{0, 0},
   265  				{3, 3},
   266  				{2, 2},
   267  				{1, 1},
   268  			},
   269  			Expected: []sql.Row{
   270  				{0},
   271  				{1},
   272  				{2},
   273  				{3},
   274  			},
   275  		},
   276  		{
   277  			Name: "input rows filtered before output, non contiguous output indexes",
   278  			Output: []sql.Row{
   279  				{0, 0},
   280  				{3, 3},
   281  				{2, 7},
   282  				{1, 1},
   283  			},
   284  			Expected: []sql.Row{
   285  				{0},
   286  				{1},
   287  				{3},
   288  				{2},
   289  			},
   290  		},
   291  		{
   292  			Name:     "empty output",
   293  			Output:   []sql.Row{},
   294  			Expected: nil,
   295  		},
   296  	}
   297  
   298  	for _, tt := range tests {
   299  		t.Run(tt.Name, func(t *testing.T) {
   300  			i := NewWindowPartitionIter(&WindowPartition{})
   301  			i.output = tt.Output
   302  			err := i.sortAndFilterOutput()
   303  			require.NoError(t, err)
   304  			require.ElementsMatch(t, tt.Expected, i.output)
   305  		})
   306  	}
   307  }