github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/window_functions_test.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggregation 16 17 import ( 18 "errors" 19 "io" 20 "testing" 21 22 "github.com/stretchr/testify/require" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression" 26 "github.com/dolthub/go-mysql-server/sql/types" 27 ) 28 29 func TestGroupedAggFuncs(t *testing.T) { 30 tests := []struct { 31 Name string 32 Agg sql.WindowFunction 33 Expected sql.Row 34 }{ 35 { 36 Name: "count star", 37 Agg: NewCountAgg(expression.NewStar()), 38 Expected: sql.Row{int64(4), int64(4), int64(6)}, 39 }, 40 { 41 Name: "count without nulls", 42 Agg: NewCountAgg(expression.NewGetField(1, types.LongText, "x", true)), 43 Expected: sql.Row{int64(4), int64(4), int64(6)}, 44 }, 45 { 46 Name: "count with nulls", 47 Agg: NewCountAgg(expression.NewGetField(0, types.LongText, "x", true)), 48 Expected: sql.Row{int64(3), int64(3), int64(4)}, 49 }, 50 { 51 Name: "max ints", 52 Agg: NewMaxAgg(expression.NewGetField(1, types.LongText, "x", true)), 53 Expected: sql.Row{4, 4, 6}, 54 }, 55 { 56 Name: "max int64", 57 Agg: NewMaxAgg(expression.NewGetField(2, types.LongText, "x", true)), 58 Expected: sql.Row{int64(3), int64(3), int64(5)}, 59 }, 60 { 61 Name: "max w/ nulls", 62 Agg: NewMaxAgg(expression.NewGetField(0, types.LongText, "x", true)), 63 Expected: sql.Row{4, 4, 6}, 64 }, 65 { 66 Name: "max w/ float", 67 Agg: NewMaxAgg(expression.NewGetField(3, types.LongText, "x", true)), 68 Expected: sql.Row{float64(4), float64(4), float64(6)}, 69 }, 70 { 71 Name: "min ints", 72 Agg: NewMinAgg(expression.NewGetField(1, types.LongText, "x", true)), 73 Expected: sql.Row{1, 1, 1}, 74 }, 75 { 76 Name: "min int64", 77 Agg: NewMinAgg(expression.NewGetField(2, types.LongText, "x", true)), 78 Expected: sql.Row{int64(1), int64(1), int64(1)}, 79 }, 80 { 81 Name: "min w/ nulls", 82 Agg: NewMinAgg(expression.NewGetField(0, types.LongText, "x", true)), 83 Expected: sql.Row{1, 1, 1}, 84 }, 85 { 86 Name: "min w/ float", 87 Agg: NewMinAgg(expression.NewGetField(3, types.LongText, "x", true)), 88 Expected: sql.Row{float64(1), float64(1), float64(1)}, 89 }, 90 { 91 Name: "avg nulls", 92 Agg: NewAvgAgg(expression.NewGetField(0, types.LongText, "x", true)), 93 Expected: sql.Row{float64(8) / float64(3), float64(8) / float64(3), float64(14) / float64(4)}, 94 }, 95 { 96 Name: "avg int", 97 Agg: NewAvgAgg(expression.NewGetField(1, types.LongText, "x", true)), 98 Expected: sql.Row{float64(10) / float64(4), float64(10) / float64(4), float64(21) / float64(6)}, 99 }, 100 { 101 Name: "avg int64", 102 Agg: NewAvgAgg(expression.NewGetField(2, types.LongText, "x", true)), 103 Expected: sql.Row{float64(8) / float64(4), float64(8) / float64(4), float64(17) / float64(6)}, 104 }, 105 { 106 Name: "avg float", 107 Agg: NewAvgAgg(expression.NewGetField(3, types.LongText, "x", true)), 108 Expected: sql.Row{float64(10) / float64(4), float64(10) / float64(4), float64(21) / float64(6)}, 109 }, 110 { 111 Name: "sum nulls", 112 Agg: NewSumAgg(expression.NewGetField(0, types.LongText, "x", true)), 113 Expected: sql.Row{float64(8), float64(8), float64(14)}, 114 }, 115 { 116 Name: "sum ints", 117 Agg: NewSumAgg(expression.NewGetField(1, types.LongText, "x", true)), 118 Expected: sql.Row{float64(10), float64(10), float64(21)}, 119 }, 120 { 121 Name: "sum int64", 122 Agg: NewSumAgg(expression.NewGetField(2, types.LongText, "x", true)), 123 Expected: sql.Row{float64(8), float64(8), float64(17)}, 124 }, 125 { 126 Name: "sum float64", 127 Agg: NewSumAgg(expression.NewGetField(3, types.LongText, "x", true)), 128 Expected: sql.Row{float64(10), float64(10), float64(21)}, 129 }, 130 { 131 Name: "first", 132 Agg: NewFirstAgg(expression.NewGetField(0, types.LongText, "x", true)), 133 Expected: sql.Row{1, 1, 1}, 134 }, 135 { 136 Name: "last", 137 Agg: NewLastAgg(expression.NewGetField(0, types.LongText, "x", true)), 138 Expected: sql.Row{4, 4, 6}, 139 }, 140 // list aggregations 141 { 142 Name: "group concat null", 143 Agg: NewGroupConcatAgg(NewGroupConcat("", nil, ",", []sql.Expression{expression.NewGetField(0, types.LongText, "x", true)}, 1042)), 144 Expected: sql.Row{"1,3,4", "1,3,4", "1,2,5,6"}, 145 }, 146 { 147 Name: "group concat int", 148 Agg: NewGroupConcatAgg(NewGroupConcat("", nil, ",", []sql.Expression{expression.NewGetField(1, types.LongText, "x", true)}, 1042)), 149 Expected: sql.Row{"1,2,3,4", "1,2,3,4", "1,2,3,4,5,6"}, 150 }, 151 { 152 Name: "group concat float", 153 Agg: NewGroupConcatAgg(NewGroupConcat("", nil, ",", []sql.Expression{expression.NewGetField(3, types.LongText, "x", true)}, 1042)), 154 Expected: sql.Row{"1,2,3,4", "1,2,3,4", "1,2,3,4,5,6"}, 155 }, 156 { 157 Name: "json array null", 158 Agg: NewJsonArrayAgg(expression.NewGetField(0, types.LongText, "x", true)), 159 Expected: sql.Row{ 160 types.JSONDocument{Val: []interface{}{1, nil, 3, 4}}, 161 types.JSONDocument{Val: []interface{}{1, nil, 3, 4}}, 162 types.JSONDocument{Val: []interface{}{1, 2, nil, nil, 5, 6}}, 163 }, 164 }, 165 { 166 Name: "json array int", 167 Agg: NewJsonArrayAgg(expression.NewGetField(1, types.LongText, "x", true)), 168 Expected: sql.Row{ 169 types.JSONDocument{Val: []interface{}{1, 2, 3, 4}}, 170 types.JSONDocument{Val: []interface{}{1, 2, 3, 4}}, 171 types.JSONDocument{Val: []interface{}{1, 2, 3, 4, 5, 6}}, 172 }, 173 }, 174 { 175 Name: "json array float", 176 Agg: NewJsonArrayAgg(expression.NewGetField(3, types.LongText, "x", true)), 177 Expected: sql.Row{ 178 types.JSONDocument{Val: []interface{}{float64(1), float64(2), float64(3), float64(4)}}, 179 types.JSONDocument{Val: []interface{}{float64(1), float64(2), float64(3), float64(4)}}, 180 types.JSONDocument{Val: []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5), float64(6)}}, 181 }, 182 }, 183 { 184 Name: "json object null", 185 Agg: NewWindowedJSONObjectAgg( 186 NewJSONObjectAgg( 187 expression.NewGetField(1, types.LongText, "x", true), 188 expression.NewGetField(0, types.LongText, "y", true), 189 ).(*JSONObjectAgg), 190 ), 191 Expected: sql.Row{ 192 types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": nil, "3": 3, "4": 4}}, 193 types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": nil, "3": 3, "4": 4}}, 194 types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": 2, "3": nil, "4": nil, "5": 5, "6": 6}}, 195 }, 196 }, 197 { 198 Name: "json object int", 199 Agg: NewWindowedJSONObjectAgg( 200 NewJSONObjectAgg( 201 expression.NewGetField(1, types.LongText, "x", true), 202 expression.NewGetField(0, types.LongText, "x", true), 203 ).(*JSONObjectAgg), 204 ), 205 Expected: sql.Row{ 206 types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": nil, "3": 3, "4": 4}}, 207 types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": nil, "3": 3, "4": 4}}, 208 types.JSONDocument{Val: map[string]interface{}{"1": 1, "2": 2, "3": nil, "4": nil, "5": 5, "6": 6}}, 209 }, 210 }, 211 { 212 Name: "json object float", 213 Agg: NewWindowedJSONObjectAgg( 214 NewJSONObjectAgg( 215 expression.NewGetField(1, types.LongText, "x", true), 216 expression.NewGetField(3, types.LongText, "x", true), 217 ).(*JSONObjectAgg), 218 ), 219 Expected: sql.Row{ 220 types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4)}}, 221 types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4)}}, 222 types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4), "5": float64(5), "6": float64(6)}}, 223 }, 224 }, 225 { 226 Name: "json object float", 227 Agg: NewWindowedJSONObjectAgg( 228 NewJSONObjectAgg( 229 expression.NewGetField(1, types.LongText, "x", true), 230 expression.NewGetField(3, types.LongText, "x", true), 231 ).(*JSONObjectAgg), 232 ), 233 Expected: sql.Row{ 234 types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4)}}, 235 types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4)}}, 236 types.JSONDocument{Val: map[string]interface{}{"1": float64(1), "2": float64(2), "3": float64(3), "4": float64(4), "5": float64(5), "6": float64(6)}}, 237 }, 238 }, 239 } 240 241 buf := []sql.Row{ 242 {1, 1, int64(1), float64(1), 1, 1}, 243 {nil, 2, int64(2), float64(2), 1, 1}, 244 {3, 3, int64(3), float64(3), 1, 2}, 245 {4, 4, int64(2), float64(4), 1, 3}, 246 {1, 1, int64(1), float64(1), 2, 1}, 247 {nil, 2, int64(2), float64(2), 2, 2}, 248 {3, 3, int64(3), float64(3), 2, 2}, 249 {4, 4, int64(2), float64(4), 3}, 250 {1, 1, int64(1), float64(1), 1}, 251 {2, 2, int64(2), float64(2), 2}, 252 {nil, 3, int64(3), float64(3), 2}, 253 {nil, 4, int64(4), float64(4), 3}, 254 {5, 5, int64(5), float64(5), 3}, 255 {6, 6, int64(2), float64(6), 4}, 256 } 257 258 partitions := []sql.WindowInterval{ 259 {Start: 0, End: 4}, 260 {Start: 4, End: 8}, 261 {Start: 8, End: 14}, 262 } 263 264 for _, tt := range tests { 265 t.Run(tt.Name, func(t *testing.T) { 266 ctx := sql.NewEmptyContext() 267 res := make(sql.Row, len(partitions)) 268 for i, p := range partitions { 269 err := tt.Agg.StartPartition(ctx, p, buf) 270 require.NoError(t, err) 271 res[i] = tt.Agg.Compute(ctx, p, buf) 272 } 273 require.Equal(t, tt.Expected, res) 274 }) 275 } 276 } 277 278 func TestWindowedAggFuncs(t *testing.T) { 279 tests := []struct { 280 Name string 281 Agg sql.WindowFunction 282 Expected sql.Row 283 }{ 284 { 285 Name: "lag", 286 Agg: NewLag(expression.NewGetField(1, types.LongText, "x", true), nil, 2), 287 Expected: sql.Row{nil, nil, 1, 2, nil, nil, 1, 2, nil, nil, 1, 2, 3, 4}, 288 }, 289 { 290 Name: "lag w/ default", 291 Agg: NewLag( 292 expression.NewGetField(1, types.LongText, "x", true), 293 expression.NewGetField(1, types.LongText, "x", true), 294 2, 295 ), 296 Expected: sql.Row{1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4}, 297 }, 298 { 299 Name: "lag nil", 300 Agg: NewLag( 301 expression.NewGetField(0, types.LongText, "x", true), 302 nil, 303 1, 304 ), 305 Expected: sql.Row{nil, 1, nil, 3, nil, 1, nil, 3, nil, 1, 2, nil, nil, 5}, 306 }, 307 { 308 Name: "lead", 309 Agg: NewLead(expression.NewGetField(1, types.LongText, "x", true), nil, 2), 310 Expected: sql.Row{3, 4, nil, nil, 3, 4, nil, nil, 3, 4, 5, 6, nil, nil}, 311 }, 312 { 313 Name: "lead w/ default", 314 Agg: NewLead( 315 expression.NewGetField(1, types.LongText, "x", true), 316 expression.NewGetField(1, types.LongText, "x", true), 317 2, 318 ), 319 Expected: sql.Row{3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 6, 5, 6}, 320 }, 321 { 322 Name: "row number", 323 Agg: NewRowNumber(), 324 Expected: sql.Row{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 5, 6}, 325 }, 326 { 327 Name: "percent rank no peers", 328 Agg: NewPercentRank([]sql.Expression{}), 329 Expected: sql.Row{ 330 float64(0), float64(0), float64(0), float64(0), 331 float64(0), float64(0), float64(0), float64(0), 332 float64(0), float64(0), float64(0), float64(0), float64(0), float64(0), 333 }, 334 }, 335 { 336 Name: "percent rank peer groups", 337 Agg: NewPercentRank([]sql.Expression{expression.NewGetField(5, types.LongText, "x", true)}), 338 Expected: sql.Row{ 339 float64(0), float64(0) / float64(3), float64(2) / float64(3), float64(3) / float64(3), 340 float64(0), float64(1) / float64(3), float64(1) / float64(3), float64(3) / float64(3), 341 float64(0), float64(1) / float64(5), float64(1) / float64(5), float64(3) / float64(5), float64(3) / float64(5), float64(1), 342 }, 343 }, 344 } 345 346 buf := []sql.Row{ 347 {1, 1, int64(1), float64(1), 1, 1}, 348 {nil, 2, int64(2), float64(2), 1, 1}, 349 {3, 3, int64(3), float64(3), 1, 2}, 350 {4, 4, int64(2), float64(4), 1, 3}, 351 {1, 1, int64(1), float64(1), 2, 1}, 352 {nil, 2, int64(2), float64(2), 2, 2}, 353 {3, 3, int64(3), float64(3), 2, 2}, 354 {4, 4, int64(2), float64(4), 2, 3}, 355 {1, 1, int64(1), float64(1), 3, 3}, 356 {2, 2, int64(2), float64(2), 3, 4}, 357 {nil, 3, int64(3), float64(3), 3, 4}, 358 {nil, 4, int64(4), float64(4), 3, 5}, 359 {5, 5, int64(5), float64(5), 3, 5}, 360 {6, 6, int64(2), float64(6), 3, 6}, 361 } 362 363 partitions := []sql.WindowInterval{ 364 {Start: 0, End: 4}, 365 {Start: 4, End: 8}, 366 {Start: 8, End: 14}, 367 } 368 369 for _, tt := range tests { 370 t.Run(tt.Name, func(t *testing.T) { 371 ctx := sql.NewEmptyContext() 372 res := make(sql.Row, len(buf)) 373 i := 0 374 for _, p := range partitions { 375 err := tt.Agg.StartPartition(ctx, p, buf) 376 require.NoError(t, err) 377 var framer sql.WindowFramer = NewUnboundedPrecedingToCurrentRowFramer() 378 framer, err = tt.Agg.DefaultFramer().NewFramer(p) 379 require.NoError(t, err) 380 for { 381 interval, err := framer.Next(ctx, buf) 382 if errors.Is(err, io.EOF) { 383 break 384 } 385 res[i] = tt.Agg.Compute(ctx, interval, buf) 386 i++ 387 } 388 } 389 require.Equal(t, tt.Expected, res) 390 }) 391 } 392 393 }