github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/group_by_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/dolthub/vitess/go/vt/proto/query"
    21  	"github.com/stretchr/testify/require"
    22  
    23  	"github.com/dolthub/go-mysql-server/memory"
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  	"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
    27  	"github.com/dolthub/go-mysql-server/sql/plan"
    28  	"github.com/dolthub/go-mysql-server/sql/types"
    29  )
    30  
    31  func TestGroupBySchema(t *testing.T) {
    32  	require := require.New(t)
    33  
    34  	db := memory.NewDatabase("test")
    35  	child := memory.NewTable(db.BaseDatabase, "test", sql.PrimaryKeySchema{}, nil)
    36  	agg := []sql.Expression{
    37  		expression.NewAlias("c1", expression.NewLiteral("s", types.LongText)),
    38  		expression.NewAlias("c2", aggregation.NewCount(expression.NewStar())),
    39  	}
    40  	gb := plan.NewGroupBy(agg, nil, plan.NewResolvedTable(child, nil, nil))
    41  	require.Equal(sql.Schema{
    42  		{Name: "c1", Type: types.LongText},
    43  		{Name: "c2", Type: types.Int64},
    44  	}, gb.Schema())
    45  }
    46  
    47  func TestGroupByResolved(t *testing.T) {
    48  	require := require.New(t)
    49  
    50  	db := memory.NewDatabase("test")
    51  	child := memory.NewTable(db.BaseDatabase, "test", sql.PrimaryKeySchema{}, nil)
    52  	agg := []sql.Expression{
    53  		expression.NewAlias("c2", aggregation.NewCount(expression.NewStar())),
    54  	}
    55  	gb := plan.NewGroupBy(agg, nil, plan.NewResolvedTable(child, nil, nil))
    56  	require.True(gb.Resolved())
    57  
    58  	agg = []sql.Expression{
    59  		expression.NewStar(),
    60  	}
    61  	gb = plan.NewGroupBy(agg, nil, plan.NewResolvedTable(child, nil, nil))
    62  	require.False(gb.Resolved())
    63  }
    64  
    65  func TestGroupByRowIter(t *testing.T) {
    66  	require := require.New(t)
    67  
    68  	db := memory.NewDatabase("test")
    69  	pro := memory.NewDBProvider(db)
    70  	ctx := newContext(pro)
    71  
    72  	childSchema := sql.Schema{
    73  		{Name: "col1", Type: types.LongText},
    74  		{Name: "col2", Type: types.Int64},
    75  	}
    76  	child := memory.NewTable(db.BaseDatabase, "test", sql.NewPrimaryKeySchema(childSchema), nil)
    77  
    78  	rows := []sql.Row{
    79  		sql.NewRow("col1_1", int64(1111)),
    80  		sql.NewRow("col1_1", int64(1111)),
    81  		sql.NewRow("col1_2", int64(4444)),
    82  		sql.NewRow("col1_1", int64(1111)),
    83  		sql.NewRow("col1_2", int64(4444)),
    84  	}
    85  
    86  	for _, r := range rows {
    87  		require.NoError(child.Insert(ctx, r))
    88  	}
    89  
    90  	p := plan.NewSort(
    91  		[]sql.SortField{
    92  			{
    93  				Column: expression.NewGetField(0, types.LongText, "col1", true),
    94  				Order:  sql.Ascending,
    95  			}, {
    96  				Column: expression.NewGetField(1, types.Int64, "col2", true),
    97  				Order:  sql.Ascending,
    98  			},
    99  		},
   100  		plan.NewGroupBy(
   101  			[]sql.Expression{
   102  				expression.NewGetField(0, types.LongText, "col1", true),
   103  				expression.NewGetField(1, types.Int64, "col2", true),
   104  			},
   105  			[]sql.Expression{
   106  				expression.NewGetField(0, types.LongText, "col1", true),
   107  				expression.NewGetField(1, types.Int64, "col2", true),
   108  			},
   109  			plan.NewResolvedTable(child, nil, nil),
   110  		))
   111  
   112  	require.Equal(1, len(p.Children()))
   113  
   114  	rows, err := NodeToRows(ctx, p)
   115  	require.NoError(err)
   116  	require.Len(rows, 2)
   117  
   118  	require.Equal(sql.NewRow("col1_1", int64(1111)), rows[0])
   119  	require.Equal(sql.NewRow("col1_2", int64(4444)), rows[1])
   120  }
   121  
   122  func TestGroupByAggregationGrouping(t *testing.T) {
   123  	require := require.New(t)
   124  
   125  	db := memory.NewDatabase("test")
   126  	pro := memory.NewDBProvider(db)
   127  	ctx := newContext(pro)
   128  
   129  	childSchema := sql.Schema{
   130  		{Name: "col1", Type: types.LongText},
   131  		{Name: "col2", Type: types.Int64},
   132  	}
   133  
   134  	child := memory.NewTable(db.BaseDatabase, "test", sql.NewPrimaryKeySchema(childSchema), nil)
   135  
   136  	rows := []sql.Row{
   137  		sql.NewRow("col1_1", int64(1111)),
   138  		sql.NewRow("col1_1", int64(1111)),
   139  		sql.NewRow("col1_2", int64(4444)),
   140  		sql.NewRow("col1_1", int64(1111)),
   141  		sql.NewRow("col1_2", int64(4444)),
   142  	}
   143  
   144  	for _, r := range rows {
   145  		require.NoError(child.Insert(ctx, r))
   146  	}
   147  
   148  	p := plan.NewGroupBy(
   149  		[]sql.Expression{
   150  			aggregation.NewCount(expression.NewGetField(0, types.LongText, "col1", true)),
   151  			expression.NewIsNull(expression.NewGetField(1, types.Int64, "col2", true)),
   152  		},
   153  		[]sql.Expression{
   154  			expression.NewGetField(0, types.LongText, "col1", true),
   155  			expression.NewIsNull(expression.NewGetField(1, types.Int64, "col2", true)),
   156  		},
   157  		plan.NewResolvedTable(child, nil, nil),
   158  	)
   159  
   160  	rows, err := NodeToRows(ctx, p)
   161  	require.NoError(err)
   162  
   163  	expected := []sql.Row{
   164  		{int64(3), false},
   165  		{int64(2), false},
   166  	}
   167  
   168  	require.Equal(expected, rows)
   169  }
   170  
   171  func TestGroupByCollations(t *testing.T) {
   172  	tString := types.MustCreateString(query.Type_VARCHAR, 255, sql.Collation_utf8mb4_0900_ai_ci)
   173  	tEnum := types.MustCreateEnumType([]string{"col1_1", "col1_2"}, sql.Collation_utf8mb4_0900_ai_ci)
   174  	tSet := types.MustCreateSetType([]string{"col1_1", "col1_2"}, sql.Collation_utf8mb4_0900_ai_ci)
   175  
   176  	var testCases = []struct {
   177  		Type  sql.Type
   178  		Value func(t *testing.T, v string) any
   179  	}{
   180  		{
   181  			Type:  tString,
   182  			Value: func(t *testing.T, v string) any { return v },
   183  		},
   184  		{
   185  			Type: tEnum,
   186  			Value: func(t *testing.T, v string) any {
   187  				conv, _, err := tEnum.Convert(v)
   188  				require.NoError(t, err)
   189  				return conv
   190  			},
   191  		},
   192  		{
   193  			Type: tSet,
   194  			Value: func(t *testing.T, v string) any {
   195  				conv, _, err := tSet.Convert(v)
   196  				require.NoError(t, err)
   197  				return conv
   198  			},
   199  		},
   200  	}
   201  
   202  	for _, tc := range testCases {
   203  		t.Run(tc.Type.String(), func(t *testing.T) {
   204  			require := require.New(t)
   205  
   206  			childSchema := sql.Schema{
   207  				{Name: "col1", Type: tc.Type},
   208  				{Name: "col2", Type: types.Int64},
   209  			}
   210  
   211  			db := memory.NewDatabase("test")
   212  			pro := memory.NewDBProvider(db)
   213  			ctx := newContext(pro)
   214  
   215  			child := memory.NewTable(db.BaseDatabase, "test", sql.NewPrimaryKeySchema(childSchema), nil)
   216  
   217  			rows := []sql.Row{
   218  				sql.NewRow(tc.Value(t, "col1_1"), int64(1111)),
   219  				sql.NewRow(tc.Value(t, "Col1_1"), int64(1111)),
   220  				sql.NewRow(tc.Value(t, "col1_2"), int64(4444)),
   221  				sql.NewRow(tc.Value(t, "col1_1"), int64(1111)),
   222  				sql.NewRow(tc.Value(t, "Col1_2"), int64(4444)),
   223  			}
   224  
   225  			for _, r := range rows {
   226  				require.NoError(child.Insert(ctx, r))
   227  			}
   228  
   229  			p := plan.NewGroupBy(
   230  				[]sql.Expression{
   231  					aggregation.NewSum(
   232  						expression.NewGetFieldWithTable(1, 1, types.Int64, "", "test", "col2", false),
   233  					),
   234  				},
   235  				[]sql.Expression{
   236  					expression.NewGetFieldWithTable(0, 1, tc.Type, "", "test", "col1", false),
   237  				},
   238  				plan.NewResolvedTable(child, nil, nil),
   239  			)
   240  
   241  			rows, err := NodeToRows(ctx, p)
   242  			require.NoError(err)
   243  
   244  			expected := []sql.Row{
   245  				{float64(3333)},
   246  				{float64(8888)},
   247  			}
   248  
   249  			require.Equal(expected, rows)
   250  		})
   251  	}
   252  }
   253  
   254  func BenchmarkGroupBy(b *testing.B) {
   255  	table := benchmarkTable(b)
   256  
   257  	node := plan.NewGroupBy(
   258  		[]sql.Expression{
   259  			aggregation.NewMax(
   260  				expression.NewGetField(1, types.Int64, "b", false),
   261  			),
   262  		},
   263  		nil,
   264  		plan.NewResolvedTable(table, nil, nil),
   265  	)
   266  
   267  	expected := []sql.Row{{int64(200)}}
   268  
   269  	bench := func(node sql.Node, expected []sql.Row) func(*testing.B) {
   270  		return func(b *testing.B) {
   271  			require := require.New(b)
   272  
   273  			for i := 0; i < b.N; i++ {
   274  				ctx := sql.NewEmptyContext()
   275  				iter, err := DefaultBuilder.Build(ctx, node, nil)
   276  				require.NoError(err)
   277  
   278  				rows, err := sql.RowIterToRows(ctx, iter)
   279  				require.NoError(err)
   280  				require.ElementsMatch(expected, rows)
   281  			}
   282  		}
   283  	}
   284  
   285  	b.Run("no grouping", bench(node, expected))
   286  
   287  	node = plan.NewGroupBy(
   288  		[]sql.Expression{
   289  			expression.NewGetField(0, types.Int64, "a", false),
   290  			aggregation.NewMax(
   291  				expression.NewGetField(1, types.Int64, "b", false),
   292  			),
   293  		},
   294  		[]sql.Expression{
   295  			expression.NewGetField(0, types.Int64, "a", false),
   296  		},
   297  		plan.NewResolvedTable(table, nil, nil),
   298  	)
   299  
   300  	expected = []sql.Row{}
   301  	for i := int64(0); i < 50; i++ {
   302  		expected = append(expected, sql.NewRow(i, int64(200)))
   303  	}
   304  
   305  	b.Run("grouping", bench(node, expected))
   306  }
   307  
   308  func benchmarkTable(t testing.TB) sql.Table {
   309  	t.Helper()
   310  	require := require.New(t)
   311  
   312  	db := memory.NewDatabase("test")
   313  	table := memory.NewTable(db.BaseDatabase, "test", sql.NewPrimaryKeySchema(sql.Schema{
   314  		{Name: "a", Type: types.Int64},
   315  		{Name: "b", Type: types.Int64},
   316  	}), nil)
   317  
   318  	for i := int64(0); i < 50; i++ {
   319  		for j := int64(200); j > 0; j-- {
   320  			row := sql.NewRow(i, j)
   321  			require.NoError(table.Insert(sql.NewEmptyContext(), row))
   322  		}
   323  	}
   324  
   325  	return table
   326  }
   327  
   328  // NodeToRows converts a node to a slice of rows.
   329  func NodeToRows(ctx *sql.Context, n sql.Node) ([]sql.Row, error) {
   330  	// TODO can't have sql depend on rowexec
   331  	// move execution tests to rowexec
   332  	i, err := DefaultBuilder.Build(ctx, n, nil)
   333  	if err != nil {
   334  		return nil, err
   335  	}
   336  
   337  	return sql.RowIterToRows(ctx, i)
   338  }