github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/window_functions_test.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aggregation
    16  
    17  import (
    18  	"errors"
    19  	"io"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/require"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  func TestGroupedAggFuncs(t *testing.T) {
    30  	tests := []struct {
    31  		Name     string
    32  		Agg      sql.WindowFunction
    33  		Expected sql.Row
    34  	}{
    35  		{
    36  			Name:     "count star",
    37  			Agg:      NewCountAgg(expression.NewStar()),
    38  			Expected: sql.Row{int64(4), int64(4), int64(6)},
    39  		},
    40  		{
    41  			Name:     "count without nulls",
    42  			Agg:      NewCountAgg(expression.NewGetField(1, types.LongText, "x", true)),
    43  			Expected: sql.Row{int64(4), int64(4), int64(6)},
    44  		},
    45  		{
    46  			Name:     "count with nulls",
    47  			Agg:      NewCountAgg(expression.NewGetField(0, types.LongText, "x", true)),
    48  			Expected: sql.Row{int64(3), int64(3), int64(4)},
    49  		},
    50  		{
    51  			Name:     "max ints",
    52  			Agg:      NewMaxAgg(expression.NewGetField(1, types.LongText, "x", true)),
    53  			Expected: sql.Row{4, 4, 6},
    54  		},
    55  		{
    56  			Name:     "max int64",
    57  			Agg:      NewMaxAgg(expression.NewGetField(2, types.LongText, "x", true)),
    58  			Expected: sql.Row{int64(3), int64(3), int64(5)},
    59  		},
    60  		{
    61  			Name:     "max w/ nulls",
    62  			Agg:      NewMaxAgg(expression.NewGetField(0, types.LongText, "x", true)),
    63  			Expected: sql.Row{4, 4, 6},
    64  		},
    65  		{
    66  			Name:     "max w/ float",
    67  			Agg:      NewMaxAgg(expression.NewGetField(3, types.LongText, "x", true)),
    68  			Expected: sql.Row{float64(4), float64(4), float64(6)},
    69  		},
    70  		{
    71  			Name:     "min ints",
    72  			Agg:      NewMinAgg(expression.NewGetField(1, types.LongText, "x", true)),
    73  			Expected: sql.Row{1, 1, 1},
    74  		},
    75  		{
    76  			Name:     "min int64",
    77  			Agg:      NewMinAgg(expression.NewGetField(2, types.LongText, "x", true)),
    78  			Expected: sql.Row{int64(1), int64(1), int64(1)},
    79  		},
    80  		{
    81  			Name:     "min w/ nulls",
    82  			Agg:      NewMinAgg(expression.NewGetField(0, types.LongText, "x", true)),
    83  			Expected: sql.Row{1, 1, 1},
    84  		},
    85  		{
    86  			Name:     "min w/ float",
    87  			Agg:      NewMinAgg(expression.NewGetField(3, types.LongText, "x", true)),
    88  			Expected: sql.Row{float64(1), float64(1), float64(1)},
    89  		},
    90  		{
    91  			Name:     "avg nulls",
    92  			Agg:      NewAvgAgg(expression.NewGetField(0, types.LongText, "x", true)),
    93  			Expected: sql.Row{float64(8) / float64(3), float64(8) / float64(3), float64(14) / float64(4)},
    94  		},
    95  		{
    96  			Name:     "avg int",
    97  			Agg:      NewAvgAgg(expression.NewGetField(1, types.LongText, "x", true)),
    98  			Expected: sql.Row{float64(10) / float64(4), float64(10) / float64(4), float64(21) / float64(6)},
    99  		},
   100  		{
   101  			Name:     "avg int64",
   102  			Agg:      NewAvgAgg(expression.NewGetField(2, types.LongText, "x", true)),
   103  			Expected: sql.Row{float64(8) / float64(4), float64(8) / float64(4), float64(17) / float64(6)},
   104  		},
   105  		{
   106  			Name:     "avg float",
   107  			Agg:      NewAvgAgg(expression.NewGetField(3, types.LongText, "x", true)),
   108  			Expected: sql.Row{float64(10) / float64(4), float64(10) / float64(4), float64(21) / float64(6)},
   109  		},
   110  		{
   111  			Name:     "sum nulls",
   112  			Agg:      NewSumAgg(expression.NewGetField(0, types.LongText, "x", true)),
   113  			Expected: sql.Row{float64(8), float64(8), float64(14)},
   114  		},
   115  		{
   116  			Name:     "sum ints",
   117  			Agg:      NewSumAgg(expression.NewGetField(1, types.LongText, "x", true)),
   118  			Expected: sql.Row{float64(10), float64(10), float64(21)},
   119  		},
   120  		{
   121  			Name:     "sum int64",
   122  			Agg:      NewSumAgg(expression.NewGetField(2, types.LongText, "x", true)),
   123  			Expected: sql.Row{float64(8), float64(8), float64(17)},
   124  		},
   125  		{
   126  			Name:     "sum float64",
   127  			Agg:      NewSumAgg(expression.NewGetField(3, types.LongText, "x", true)),
   128  			Expected: sql.Row{float64(10), float64(10), float64(21)},
   129  		},
   130  		{
   131  			Name:     "first",
   132  			Agg:      NewFirstAgg(expression.NewGetField(0, types.LongText, "x", true)),
   133  			Expected: sql.Row{1, 1, 1},
   134  		},
   135  		{
   136  			Name:     "last",
   137  			Agg:      NewLastAgg(expression.NewGetField(0, types.LongText, "x", true)),
   138  			Expected: sql.Row{4, 4, 6},
   139  		},
   140  		// list aggregations
   141  		{
   142  			Name:     "group concat null",
   143  			Agg:      NewGroupConcatAgg(NewGroupConcat("", nil, ",", []sql.Expression{expression.NewGetField(0, types.LongText, "x", true)}, 1042)),
   144  			Expected: sql.Row{"1,3,4", "1,3,4", "1,2,5,6"},
   145  		},
   146  		{
   147  			Name:     "group concat int",
   148  			Agg:      NewGroupConcatAgg(NewGroupConcat("", nil, ",", []sql.Expression{expression.NewGetField(1, types.LongText, "x", true)}, 1042)),
   149  			Expected: sql.Row{"1,2,3,4", "1,2,3,4", "1,2,3,4,5,6"},
   150  		},
   151  		{
   152  			Name:     "group concat float",
   153  			Agg:      NewGroupConcatAgg(NewGroupConcat("", nil, ",", []sql.Expression{expression.NewGetField(3, types.LongText, "x", true)}, 1042)),
   154  			Expected: sql.Row{"1,2,3,4", "1,2,3,4", "1,2,3,4,5,6"},
   155  		},
   156  		{
   157  			Name: "json array null",
   158  			Agg:  NewJsonArrayAgg(expression.NewGetField(0, types.LongText, "x", true)),
   159  			Expected: sql.Row{
   160  				types.JSONDocument{Val: []interface{}{1, nil, 3, 4}},
   161  				types.JSONDocument{Val: []interface{}{1, nil, 3, 4}},
   162  				types.JSONDocument{Val: []interface{}{1, 2, nil, nil, 5, 6}},
   163  			},
   164  		},
   165  		{
   166  			Name: "json array int",
   167  			Agg:  NewJsonArrayAgg(expression.NewGetField(1, types.LongText, "x", true)),
   168  			Expected: sql.Row{
   169  				types.JSONDocument{Val: []interface{}{1, 2, 3, 4}},
   170  				types.JSONDocument{Val: []interface{}{1, 2, 3, 4}},
   171  				types.JSONDocument{Val: []interface{}{1, 2, 3, 4, 5, 6}},
   172  			},
   173  		},
   174  		{
   175  			Name: "json array float",
   176  			Agg:  NewJsonArrayAgg(expression.NewGetField(3, types.LongText, "x", true)),
   177  			Expected: sql.Row{
   178  				types.JSONDocument{Val: []interface{}{float64(1), float64(2), float64(3), float64(4)}},
   179  				types.JSONDocument{Val: []interface{}{float64(1), float64(2), float64(3), float64(4)}},
   180  				types.JSONDocument{Val: []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5), float64(6)}},
   181  			},
   182  		},
   183  		{
   184  			Name: "json object null",
   185  			Agg: NewWindowedJSONObjectAgg(
   186  				NewJSONObjectAgg(
   187  					expression.NewGetField(1, types.LongText, "x", true),
   188  					expression.NewGetField(0, types.LongText, "y", true),
   189  				).(*JSONObjectAgg),
   190  			),
   191  			Expected: sql.Row{
   192  				types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": nil, "3": 3, "4": 4}},
   193  				types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": nil, "3": 3, "4": 4}},
   194  				types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": 2, "3": nil, "4": nil, "5": 5, "6": 6}},
   195  			},
   196  		},
   197  		{
   198  			Name: "json object int",
   199  			Agg: NewWindowedJSONObjectAgg(
   200  				NewJSONObjectAgg(
   201  					expression.NewGetField(1, types.LongText, "x", true),
   202  					expression.NewGetField(0, types.LongText, "x", true),
   203  				).(*JSONObjectAgg),
   204  			),
   205  			Expected: sql.Row{
   206  				types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": nil, "3": 3, "4": 4}},
   207  				types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": nil, "3": 3, "4": 4}},
   208  				types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": 2, "3": nil, "4": nil, "5": 5, "6": 6}},
   209  			},
   210  		},
   211  		{
   212  			Name: "json object float",
   213  			Agg: NewWindowedJSONObjectAgg(
   214  				NewJSONObjectAgg(
   215  					expression.NewGetField(1, types.LongText, "x", true),
   216  					expression.NewGetField(3, types.LongText, "x", true),
   217  				).(*JSONObjectAgg),
   218  			),
   219  			Expected: sql.Row{
   220  				types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4)}},
   221  				types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4)}},
   222  				types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4), "5": float64(5), "6": float64(6)}},
   223  			},
   224  		},
   225  		{
   226  			Name: "json object float",
   227  			Agg: NewWindowedJSONObjectAgg(
   228  				NewJSONObjectAgg(
   229  					expression.NewGetField(1, types.LongText, "x", true),
   230  					expression.NewGetField(3, types.LongText, "x", true),
   231  				).(*JSONObjectAgg),
   232  			),
   233  			Expected: sql.Row{
   234  				types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4)}},
   235  				types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4)}},
   236  				types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4), "5": float64(5), "6": float64(6)}},
   237  			},
   238  		},
   239  	}
   240  
   241  	buf := []sql.Row{
   242  		{1, 1, int64(1), float64(1), 1, 1},
   243  		{nil, 2, int64(2), float64(2), 1, 1},
   244  		{3, 3, int64(3), float64(3), 1, 2},
   245  		{4, 4, int64(2), float64(4), 1, 3},
   246  		{1, 1, int64(1), float64(1), 2, 1},
   247  		{nil, 2, int64(2), float64(2), 2, 2},
   248  		{3, 3, int64(3), float64(3), 2, 2},
   249  		{4, 4, int64(2), float64(4), 3},
   250  		{1, 1, int64(1), float64(1), 1},
   251  		{2, 2, int64(2), float64(2), 2},
   252  		{nil, 3, int64(3), float64(3), 2},
   253  		{nil, 4, int64(4), float64(4), 3},
   254  		{5, 5, int64(5), float64(5), 3},
   255  		{6, 6, int64(2), float64(6), 4},
   256  	}
   257  
   258  	partitions := []sql.WindowInterval{
   259  		{Start: 0, End: 4},
   260  		{Start: 4, End: 8},
   261  		{Start: 8, End: 14},
   262  	}
   263  
   264  	for _, tt := range tests {
   265  		t.Run(tt.Name, func(t *testing.T) {
   266  			ctx := sql.NewEmptyContext()
   267  			res := make(sql.Row, len(partitions))
   268  			for i, p := range partitions {
   269  				err := tt.Agg.StartPartition(ctx, p, buf)
   270  				require.NoError(t, err)
   271  				res[i] = tt.Agg.Compute(ctx, p, buf)
   272  			}
   273  			require.Equal(t, tt.Expected, res)
   274  		})
   275  	}
   276  }
   277  
   278  func TestWindowedAggFuncs(t *testing.T) {
   279  	tests := []struct {
   280  		Name     string
   281  		Agg      sql.WindowFunction
   282  		Expected sql.Row
   283  	}{
   284  		{
   285  			Name:     "lag",
   286  			Agg:      NewLag(expression.NewGetField(1, types.LongText, "x", true), nil, 2),
   287  			Expected: sql.Row{nil, nil, 1, 2, nil, nil, 1, 2, nil, nil, 1, 2, 3, 4},
   288  		},
   289  		{
   290  			Name: "lag w/ default",
   291  			Agg: NewLag(
   292  				expression.NewGetField(1, types.LongText, "x", true),
   293  				expression.NewGetField(1, types.LongText, "x", true),
   294  				2,
   295  			),
   296  			Expected: sql.Row{1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4},
   297  		},
   298  		{
   299  			Name: "lag nil",
   300  			Agg: NewLag(
   301  				expression.NewGetField(0, types.LongText, "x", true),
   302  				nil,
   303  				1,
   304  			),
   305  			Expected: sql.Row{nil, 1, nil, 3, nil, 1, nil, 3, nil, 1, 2, nil, nil, 5},
   306  		},
   307  		{
   308  			Name:     "lead",
   309  			Agg:      NewLead(expression.NewGetField(1, types.LongText, "x", true), nil, 2),
   310  			Expected: sql.Row{3, 4, nil, nil, 3, 4, nil, nil, 3, 4, 5, 6, nil, nil},
   311  		},
   312  		{
   313  			Name: "lead w/ default",
   314  			Agg: NewLead(
   315  				expression.NewGetField(1, types.LongText, "x", true),
   316  				expression.NewGetField(1, types.LongText, "x", true),
   317  				2,
   318  			),
   319  			Expected: sql.Row{3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 6, 5, 6},
   320  		},
   321  		{
   322  			Name:     "row number",
   323  			Agg:      NewRowNumber(),
   324  			Expected: sql.Row{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 5, 6},
   325  		},
   326  		{
   327  			Name: "percent rank no peers",
   328  			Agg:  NewPercentRank([]sql.Expression{}),
   329  			Expected: sql.Row{
   330  				float64(0), float64(0), float64(0), float64(0),
   331  				float64(0), float64(0), float64(0), float64(0),
   332  				float64(0), float64(0), float64(0), float64(0), float64(0), float64(0),
   333  			},
   334  		},
   335  		{
   336  			Name: "percent rank peer groups",
   337  			Agg:  NewPercentRank([]sql.Expression{expression.NewGetField(5, types.LongText, "x", true)}),
   338  			Expected: sql.Row{
   339  				float64(0), float64(0) / float64(3), float64(2) / float64(3), float64(3) / float64(3),
   340  				float64(0), float64(1) / float64(3), float64(1) / float64(3), float64(3) / float64(3),
   341  				float64(0), float64(1) / float64(5), float64(1) / float64(5), float64(3) / float64(5), float64(3) / float64(5), float64(1),
   342  			},
   343  		},
   344  	}
   345  
   346  	buf := []sql.Row{
   347  		{1, 1, int64(1), float64(1), 1, 1},
   348  		{nil, 2, int64(2), float64(2), 1, 1},
   349  		{3, 3, int64(3), float64(3), 1, 2},
   350  		{4, 4, int64(2), float64(4), 1, 3},
   351  		{1, 1, int64(1), float64(1), 2, 1},
   352  		{nil, 2, int64(2), float64(2), 2, 2},
   353  		{3, 3, int64(3), float64(3), 2, 2},
   354  		{4, 4, int64(2), float64(4), 2, 3},
   355  		{1, 1, int64(1), float64(1), 3, 3},
   356  		{2, 2, int64(2), float64(2), 3, 4},
   357  		{nil, 3, int64(3), float64(3), 3, 4},
   358  		{nil, 4, int64(4), float64(4), 3, 5},
   359  		{5, 5, int64(5), float64(5), 3, 5},
   360  		{6, 6, int64(2), float64(6), 3, 6},
   361  	}
   362  
   363  	partitions := []sql.WindowInterval{
   364  		{Start: 0, End: 4},
   365  		{Start: 4, End: 8},
   366  		{Start: 8, End: 14},
   367  	}
   368  
   369  	for _, tt := range tests {
   370  		t.Run(tt.Name, func(t *testing.T) {
   371  			ctx := sql.NewEmptyContext()
   372  			res := make(sql.Row, len(buf))
   373  			i := 0
   374  			for _, p := range partitions {
   375  				err := tt.Agg.StartPartition(ctx, p, buf)
   376  				require.NoError(t, err)
   377  				var framer sql.WindowFramer = NewUnboundedPrecedingToCurrentRowFramer()
   378  				framer, err = tt.Agg.DefaultFramer().NewFramer(p)
   379  				require.NoError(t, err)
   380  				for {
   381  					interval, err := framer.Next(ctx, buf)
   382  					if errors.Is(err, io.EOF) {
   383  						break
   384  					}
   385  					res[i] = tt.Agg.Compute(ctx, interval, buf)
   386  					i++
   387  				}
   388  			}
   389  			require.Equal(t, tt.Expected, res)
   390  		})
   391  	}
   392  
   393  }