github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/group_by_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "testing" 19 20 "github.com/dolthub/vitess/go/vt/proto/query" 21 "github.com/stretchr/testify/require" 22 23 "github.com/dolthub/go-mysql-server/memory" 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression" 26 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" 27 "github.com/dolthub/go-mysql-server/sql/plan" 28 "github.com/dolthub/go-mysql-server/sql/types" 29 ) 30 31 func TestGroupBySchema(t *testing.T) { 32 require := require.New(t) 33 34 db := memory.NewDatabase("test") 35 child := memory.NewTable(db.BaseDatabase, "test", sql.PrimaryKeySchema{}, nil) 36 agg := []sql.Expression{ 37 expression.NewAlias("c1", expression.NewLiteral("s", types.LongText)), 38 expression.NewAlias("c2", aggregation.NewCount(expression.NewStar())), 39 } 40 gb := plan.NewGroupBy(agg, nil, plan.NewResolvedTable(child, nil, nil)) 41 require.Equal(sql.Schema{ 42 {Name: "c1", Type: types.LongText}, 43 {Name: "c2", Type: types.Int64}, 44 }, gb.Schema()) 45 } 46 47 func TestGroupByResolved(t *testing.T) { 48 require := require.New(t) 49 50 db := memory.NewDatabase("test") 51 child := memory.NewTable(db.BaseDatabase, "test", sql.PrimaryKeySchema{}, nil) 52 agg := []sql.Expression{ 53 expression.NewAlias("c2", aggregation.NewCount(expression.NewStar())), 54 } 55 gb := plan.NewGroupBy(agg, nil, plan.NewResolvedTable(child, nil, nil)) 56 require.True(gb.Resolved()) 57 58 agg = []sql.Expression{ 59 expression.NewStar(), 60 } 61 gb = plan.NewGroupBy(agg, nil, plan.NewResolvedTable(child, nil, nil)) 62 require.False(gb.Resolved()) 63 } 64 65 func TestGroupByRowIter(t *testing.T) { 66 require := require.New(t) 67 68 db := memory.NewDatabase("test") 69 pro := memory.NewDBProvider(db) 70 ctx := newContext(pro) 71 72 childSchema := sql.Schema{ 73 {Name: "col1", Type: types.LongText}, 74 {Name: "col2", Type: types.Int64}, 75 } 76 child := memory.NewTable(db.BaseDatabase, "test", sql.NewPrimaryKeySchema(childSchema), nil) 77 78 rows := []sql.Row{ 79 sql.NewRow("col1_1", int64(1111)), 80 sql.NewRow("col1_1", int64(1111)), 81 sql.NewRow("col1_2", int64(4444)), 82 sql.NewRow("col1_1", int64(1111)), 83 sql.NewRow("col1_2", int64(4444)), 84 } 85 86 for _, r := range rows { 87 require.NoError(child.Insert(ctx, r)) 88 } 89 90 p := plan.NewSort( 91 []sql.SortField{ 92 { 93 Column: expression.NewGetField(0, types.LongText, "col1", true), 94 Order: sql.Ascending, 95 }, { 96 Column: expression.NewGetField(1, types.Int64, "col2", true), 97 Order: sql.Ascending, 98 }, 99 }, 100 plan.NewGroupBy( 101 []sql.Expression{ 102 expression.NewGetField(0, types.LongText, "col1", true), 103 expression.NewGetField(1, types.Int64, "col2", true), 104 }, 105 []sql.Expression{ 106 expression.NewGetField(0, types.LongText, "col1", true), 107 expression.NewGetField(1, types.Int64, "col2", true), 108 }, 109 plan.NewResolvedTable(child, nil, nil), 110 )) 111 112 require.Equal(1, len(p.Children())) 113 114 rows, err := NodeToRows(ctx, p) 115 require.NoError(err) 116 require.Len(rows, 2) 117 118 require.Equal(sql.NewRow("col1_1", int64(1111)), rows[0]) 119 require.Equal(sql.NewRow("col1_2", int64(4444)), rows[1]) 120 } 121 122 func TestGroupByAggregationGrouping(t *testing.T) { 123 require := require.New(t) 124 125 db := memory.NewDatabase("test") 126 pro := memory.NewDBProvider(db) 127 ctx := newContext(pro) 128 129 childSchema := sql.Schema{ 130 {Name: "col1", Type: types.LongText}, 131 {Name: "col2", Type: types.Int64}, 132 } 133 134 child := memory.NewTable(db.BaseDatabase, "test", sql.NewPrimaryKeySchema(childSchema), nil) 135 136 rows := []sql.Row{ 137 sql.NewRow("col1_1", int64(1111)), 138 sql.NewRow("col1_1", int64(1111)), 139 sql.NewRow("col1_2", int64(4444)), 140 sql.NewRow("col1_1", int64(1111)), 141 sql.NewRow("col1_2", int64(4444)), 142 } 143 144 for _, r := range rows { 145 require.NoError(child.Insert(ctx, r)) 146 } 147 148 p := plan.NewGroupBy( 149 []sql.Expression{ 150 aggregation.NewCount(expression.NewGetField(0, types.LongText, "col1", true)), 151 expression.NewIsNull(expression.NewGetField(1, types.Int64, "col2", true)), 152 }, 153 []sql.Expression{ 154 expression.NewGetField(0, types.LongText, "col1", true), 155 expression.NewIsNull(expression.NewGetField(1, types.Int64, "col2", true)), 156 }, 157 plan.NewResolvedTable(child, nil, nil), 158 ) 159 160 rows, err := NodeToRows(ctx, p) 161 require.NoError(err) 162 163 expected := []sql.Row{ 164 {int64(3), false}, 165 {int64(2), false}, 166 } 167 168 require.Equal(expected, rows) 169 } 170 171 func TestGroupByCollations(t *testing.T) { 172 tString := types.MustCreateString(query.Type_VARCHAR, 255, sql.Collation_utf8mb4_0900_ai_ci) 173 tEnum := types.MustCreateEnumType([]string{"col1_1", "col1_2"}, sql.Collation_utf8mb4_0900_ai_ci) 174 tSet := types.MustCreateSetType([]string{"col1_1", "col1_2"}, sql.Collation_utf8mb4_0900_ai_ci) 175 176 var testCases = []struct { 177 Type sql.Type 178 Value func(t *testing.T, v string) any 179 }{ 180 { 181 Type: tString, 182 Value: func(t *testing.T, v string) any { return v }, 183 }, 184 { 185 Type: tEnum, 186 Value: func(t *testing.T, v string) any { 187 conv, _, err := tEnum.Convert(v) 188 require.NoError(t, err) 189 return conv 190 }, 191 }, 192 { 193 Type: tSet, 194 Value: func(t *testing.T, v string) any { 195 conv, _, err := tSet.Convert(v) 196 require.NoError(t, err) 197 return conv 198 }, 199 }, 200 } 201 202 for _, tc := range testCases { 203 t.Run(tc.Type.String(), func(t *testing.T) { 204 require := require.New(t) 205 206 childSchema := sql.Schema{ 207 {Name: "col1", Type: tc.Type}, 208 {Name: "col2", Type: types.Int64}, 209 } 210 211 db := memory.NewDatabase("test") 212 pro := memory.NewDBProvider(db) 213 ctx := newContext(pro) 214 215 child := memory.NewTable(db.BaseDatabase, "test", sql.NewPrimaryKeySchema(childSchema), nil) 216 217 rows := []sql.Row{ 218 sql.NewRow(tc.Value(t, "col1_1"), int64(1111)), 219 sql.NewRow(tc.Value(t, "Col1_1"), int64(1111)), 220 sql.NewRow(tc.Value(t, "col1_2"), int64(4444)), 221 sql.NewRow(tc.Value(t, "col1_1"), int64(1111)), 222 sql.NewRow(tc.Value(t, "Col1_2"), int64(4444)), 223 } 224 225 for _, r := range rows { 226 require.NoError(child.Insert(ctx, r)) 227 } 228 229 p := plan.NewGroupBy( 230 []sql.Expression{ 231 aggregation.NewSum( 232 expression.NewGetFieldWithTable(1, 1, types.Int64, "", "test", "col2", false), 233 ), 234 }, 235 []sql.Expression{ 236 expression.NewGetFieldWithTable(0, 1, tc.Type, "", "test", "col1", false), 237 }, 238 plan.NewResolvedTable(child, nil, nil), 239 ) 240 241 rows, err := NodeToRows(ctx, p) 242 require.NoError(err) 243 244 expected := []sql.Row{ 245 {float64(3333)}, 246 {float64(8888)}, 247 } 248 249 require.Equal(expected, rows) 250 }) 251 } 252 } 253 254 func BenchmarkGroupBy(b *testing.B) { 255 table := benchmarkTable(b) 256 257 node := plan.NewGroupBy( 258 []sql.Expression{ 259 aggregation.NewMax( 260 expression.NewGetField(1, types.Int64, "b", false), 261 ), 262 }, 263 nil, 264 plan.NewResolvedTable(table, nil, nil), 265 ) 266 267 expected := []sql.Row{{int64(200)}} 268 269 bench := func(node sql.Node, expected []sql.Row) func(*testing.B) { 270 return func(b *testing.B) { 271 require := require.New(b) 272 273 for i := 0; i < b.N; i++ { 274 ctx := sql.NewEmptyContext() 275 iter, err := DefaultBuilder.Build(ctx, node, nil) 276 require.NoError(err) 277 278 rows, err := sql.RowIterToRows(ctx, iter) 279 require.NoError(err) 280 require.ElementsMatch(expected, rows) 281 } 282 } 283 } 284 285 b.Run("no grouping", bench(node, expected)) 286 287 node = plan.NewGroupBy( 288 []sql.Expression{ 289 expression.NewGetField(0, types.Int64, "a", false), 290 aggregation.NewMax( 291 expression.NewGetField(1, types.Int64, "b", false), 292 ), 293 }, 294 []sql.Expression{ 295 expression.NewGetField(0, types.Int64, "a", false), 296 }, 297 plan.NewResolvedTable(table, nil, nil), 298 ) 299 300 expected = []sql.Row{} 301 for i := int64(0); i < 50; i++ { 302 expected = append(expected, sql.NewRow(i, int64(200))) 303 } 304 305 b.Run("grouping", bench(node, expected)) 306 } 307 308 func benchmarkTable(t testing.TB) sql.Table { 309 t.Helper() 310 require := require.New(t) 311 312 db := memory.NewDatabase("test") 313 table := memory.NewTable(db.BaseDatabase, "test", sql.NewPrimaryKeySchema(sql.Schema{ 314 {Name: "a", Type: types.Int64}, 315 {Name: "b", Type: types.Int64}, 316 }), nil) 317 318 for i := int64(0); i < 50; i++ { 319 for j := int64(200); j > 0; j-- { 320 row := sql.NewRow(i, j) 321 require.NoError(table.Insert(sql.NewEmptyContext(), row)) 322 } 323 } 324 325 return table 326 } 327 328 // NodeToRows converts a node to a slice of rows. 329 func NodeToRows(ctx *sql.Context, n sql.Node) ([]sql.Row, error) { 330 // TODO can't have sql depend on rowexec 331 // move execution tests to rowexec 332 i, err := DefaultBuilder.Build(ctx, n, nil) 333 if err != nil { 334 return nil, err 335 } 336 337 return sql.RowIterToRows(ctx, i) 338 }