github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/window_partition_test.go (about) 1 // Copyright 2022 DoltHub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggregation 16 17 import ( 18 "context" 19 "io" 20 "testing" 21 22 "github.com/stretchr/testify/require" 23 24 "github.com/dolthub/go-mysql-server/memory" 25 "github.com/dolthub/go-mysql-server/sql" 26 "github.com/dolthub/go-mysql-server/sql/expression" 27 "github.com/dolthub/go-mysql-server/sql/types" 28 ) 29 30 func TestWindowPartitionIter(t *testing.T) { 31 tests := []struct { 32 Name string 33 Iter *WindowPartitionIter 34 Expected []sql.Row 35 }{ 36 { 37 Name: "group by", 38 Iter: NewWindowPartitionIter( 39 &WindowPartition{ 40 PartitionBy: partitionByX, 41 SortBy: sortByW, 42 Aggs: []*Aggregation{ 43 NewAggregation(lastX, NewGroupByFramer()), 44 NewAggregation(firstY, NewGroupByFramer()), 45 NewAggregation(sumZ, NewGroupByFramer()), 46 }, 47 }, 48 ), 49 Expected: []sql.Row{ 50 {"forest", "leaf", float64(27)}, 51 {"desert", "sand", float64(23)}, 52 }, 53 }, 54 { 55 Name: "partition level desc", 56 Iter: NewWindowPartitionIter( 57 &WindowPartition{ 58 PartitionBy: partitionByX, 59 SortBy: sortByWDesc, 60 Aggs: []*Aggregation{ 61 NewAggregation(lastX, NewPartitionFramer()), 62 NewAggregation(firstY, NewPartitionFramer()), 63 NewAggregation(sumZ, NewPartitionFramer()), 64 }, 65 }), 66 Expected: []sql.Row{ 67 {"forest", "wildflower", float64(27)}, 68 {"forest", "wildflower", float64(27)}, 69 {"forest", "wildflower", float64(27)}, 70 {"forest", "wildflower", float64(27)}, 71 {"forest", "wildflower", float64(27)}, 72 {"desert", "mummy", float64(23)}, 73 {"desert", "mummy", float64(23)}, 74 {"desert", "mummy", float64(23)}, 75 {"desert", "mummy", float64(23)}, 76 }, 77 }, 78 } 79 80 for _, tt := range tests { 81 t.Run(tt.Name, func(t *testing.T) { 82 db := memory.NewDatabase("test") 83 pro := memory.NewDBProvider(db) 84 ctx := sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), pro))) 85 86 tt.Iter.child = mustNewRowIter(t, db, ctx) 87 res, err := sql.RowIterToRows(ctx, tt.Iter) 88 require.NoError(t, err) 89 require.Equal(t, tt.Expected, res) 90 }) 91 } 92 } 93 94 func mustNewRowIter(t *testing.T, db *memory.Database, ctx *sql.Context) sql.RowIter { 95 childSchema := sql.NewPrimaryKeySchema(sql.Schema{ 96 {Name: "w", Type: types.Int64, Nullable: true}, 97 {Name: "x", Type: types.Text, Nullable: true}, 98 {Name: "y", Type: types.Text, Nullable: true}, 99 {Name: "z", Type: types.Int32, Nullable: true}, 100 }) 101 table := memory.NewTable(db, "test", childSchema, nil) 102 103 rows := []sql.Row{ 104 {int64(1), "forest", "leaf", int32(4)}, 105 {int64(2), "forest", "bark", int32(4)}, 106 {int64(3), "forest", "canopy", int32(6)}, 107 {int64(4), "forest", "bug", int32(3)}, 108 {int64(5), "forest", "wildflower", int32(10)}, 109 {int64(6), "desert", "sand", int32(4)}, 110 {int64(7), "desert", "cactus", int32(6)}, 111 {int64(8), "desert", "scorpion", int32(8)}, 112 {int64(9), "desert", "mummy", int32(5)}, 113 } 114 115 for _, r := range rows { 116 require.NoError(t, table.Insert(ctx, r)) 117 } 118 119 pIter, err := table.Partitions(ctx) 120 require.NoError(t, err) 121 122 p, err := pIter.Next(ctx) 123 require.NoError(t, err) 124 125 child, err := table.PartitionRows(ctx, p) 126 require.NoError(t, err) 127 return child 128 } 129 130 func TestWindowPartition_MaterializeInput(t *testing.T) { 131 db := memory.NewDatabase("test") 132 pro := memory.NewDBProvider(db) 133 ctx := sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), pro))) 134 135 i := WindowPartitionIter{ 136 w: &WindowPartition{ 137 PartitionBy: partitionByX, 138 SortBy: sortByW, 139 }, 140 child: mustNewRowIter(t, db, ctx), 141 } 142 143 buf, ordering, err := i.materializeInput(ctx) 144 require.NoError(t, err) 145 expBuf := []sql.Row{ 146 {int64(1), "forest", "leaf", int32(4)}, 147 {int64(2), "forest", "bark", int32(4)}, 148 {int64(3), "forest", "canopy", int32(6)}, 149 {int64(4), "forest", "bug", int32(3)}, 150 {int64(5), "forest", "wildflower", int32(10)}, 151 {int64(6), "desert", "sand", int32(4)}, 152 {int64(7), "desert", "cactus", int32(6)}, 153 {int64(8), "desert", "scorpion", int32(8)}, 154 {int64(9), "desert", "mummy", int32(5)}, 155 } 156 require.ElementsMatch(t, expBuf, buf) 157 expOrd := []int{0, 1, 2, 3, 4, 5, 6, 7, 8} 158 require.ElementsMatch(t, expOrd, ordering) 159 } 160 161 func TestWindowPartition_InitializePartitions(t *testing.T) { 162 ctx := sql.NewEmptyContext() 163 i := NewWindowPartitionIter( 164 &WindowPartition{ 165 PartitionBy: partitionByX, 166 }) 167 i.input = []sql.Row{ 168 {int64(1), "forest", "leaf", int32(4)}, 169 {int64(2), "forest", "bark", int32(4)}, 170 {int64(3), "forest", "canopy", int32(6)}, 171 {int64(4), "forest", "bug", int32(3)}, 172 {int64(5), "forest", "wildflower", int32(10)}, 173 {int64(6), "desert", "sand", int32(4)}, 174 {int64(7), "desert", "cactus", int32(6)}, 175 {int64(8), "desert", "scorpion", int32(8)}, 176 {int64(9), "desert", "mummy", int32(5)}, 177 } 178 partitions, err := i.initializePartitions(ctx) 179 require.NoError(t, err) 180 expPartitions := []sql.WindowInterval{ 181 {Start: 0, End: 5}, 182 {Start: 5, End: 9}, 183 } 184 require.ElementsMatch(t, expPartitions, partitions) 185 } 186 187 func TestWindowPartition_MaterializeOutput(t *testing.T) { 188 t.Run("non nil input", func(t *testing.T) { 189 ctx := sql.NewEmptyContext() 190 i := NewWindowPartitionIter( 191 &WindowPartition{ 192 PartitionBy: partitionByX, 193 Aggs: []*Aggregation{ 194 NewAggregation(sumZ, NewGroupByFramer()), 195 }, 196 }) 197 i.input = []sql.Row{ 198 {int64(1), "forest", "leaf", 4}, 199 {int64(2), "forest", "bark", 4}, 200 {int64(3), "forest", "canopy", 6}, 201 {int64(4), "forest", "bug", 3}, 202 {int64(5), "forest", "wildflower", 10}, 203 {int64(6), "desert", "sand", 4}, 204 {int64(7), "desert", "cactus", 6}, 205 {int64(8), "desert", "scorpion", 8}, 206 {int64(9), "desert", "mummy", 5}, 207 } 208 i.partitions = []sql.WindowInterval{{0, 5}, {5, 9}} 209 i.outputOrdering = []int{0, 1, 2, 3, 4, 5, 6, 7, 8} 210 output, err := i.materializeOutput(ctx) 211 require.NoError(t, err) 212 expOutput := []sql.Row{ 213 {float64(27), 0}, 214 {float64(23), 5}, 215 } 216 require.ElementsMatch(t, expOutput, output) 217 }) 218 219 t.Run("nil input with partition by", func(t *testing.T) { 220 ctx := sql.NewEmptyContext() 221 i := NewWindowPartitionIter( 222 &WindowPartition{ 223 PartitionBy: partitionByX, 224 Aggs: []*Aggregation{ 225 NewAggregation(sumZ, NewGroupByFramer()), 226 }, 227 }) 228 i.input = []sql.Row{} 229 i.partitions = []sql.WindowInterval{{0, 0}} 230 i.outputOrdering = []int{} 231 output, err := i.materializeOutput(ctx) 232 require.Equal(t, io.EOF, err) 233 require.ElementsMatch(t, nil, output) 234 }) 235 236 t.Run("nil input no partition by", func(t *testing.T) { 237 ctx := sql.NewEmptyContext() 238 i := NewWindowPartitionIter( 239 &WindowPartition{ 240 PartitionBy: nil, 241 Aggs: []*Aggregation{ 242 NewAggregation(NewCountAgg(expression.NewGetField(0, types.Int64, "z", true)), NewGroupByFramer()), 243 }, 244 }) 245 i.input = []sql.Row{} 246 i.partitions = []sql.WindowInterval{{0, 0}} 247 i.outputOrdering = nil 248 output, err := i.materializeOutput(ctx) 249 require.NoError(t, err) 250 expOutput := []sql.Row{{int64(0), nil}} 251 require.ElementsMatch(t, expOutput, output) 252 }) 253 } 254 255 func TestWindowPartition_SortAndFilterOutput(t *testing.T) { 256 tests := []struct { 257 Name string 258 Output []sql.Row 259 Expected []sql.Row 260 }{ 261 { 262 Name: "no input rows filtered before output, contiguous output indexes", 263 Output: []sql.Row{ 264 {0, 0}, 265 {3, 3}, 266 {2, 2}, 267 {1, 1}, 268 }, 269 Expected: []sql.Row{ 270 {0}, 271 {1}, 272 {2}, 273 {3}, 274 }, 275 }, 276 { 277 Name: "input rows filtered before output, non contiguous output indexes", 278 Output: []sql.Row{ 279 {0, 0}, 280 {3, 3}, 281 {2, 7}, 282 {1, 1}, 283 }, 284 Expected: []sql.Row{ 285 {0}, 286 {1}, 287 {3}, 288 {2}, 289 }, 290 }, 291 { 292 Name: "empty output", 293 Output: []sql.Row{}, 294 Expected: nil, 295 }, 296 } 297 298 for _, tt := range tests { 299 t.Run(tt.Name, func(t *testing.T) { 300 i := NewWindowPartitionIter(&WindowPartition{}) 301 i.output = tt.Output 302 err := i.sortAndFilterOutput() 303 require.NoError(t, err) 304 require.ElementsMatch(t, tt.Expected, i.output) 305 }) 306 } 307 }