github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/agg.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "errors" 19 "fmt" 20 "io" 21 22 "github.com/cespare/xxhash/v2" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" 26 "github.com/dolthub/go-mysql-server/sql/types" 27 ) 28 29 type groupByIter struct { 30 selectedExprs []sql.Expression 31 child sql.RowIter 32 ctx *sql.Context 33 buf []sql.AggregationBuffer 34 done bool 35 } 36 37 func newGroupByIter(selectedExprs []sql.Expression, child sql.RowIter) *groupByIter { 38 return &groupByIter{ 39 selectedExprs: selectedExprs, 40 child: child, 41 buf: make([]sql.AggregationBuffer, len(selectedExprs)), 42 } 43 } 44 45 func (i *groupByIter) Next(ctx *sql.Context) (sql.Row, error) { 46 if i.done { 47 return nil, io.EOF 48 } 49 50 // special case for any_value 51 var err error 52 onlyAnyValue := true 53 for j, a := range i.selectedExprs { 54 i.buf[j], err = newAggregationBuffer(a) 55 if err != nil { 56 return nil, err 57 } 58 if agg, ok := a.(sql.Aggregation); ok { 59 if _, ok = agg.(*aggregation.AnyValue); !ok { 60 onlyAnyValue = false 61 } 62 } 63 } 64 65 // if no aggregate functions other than any_value, it's just a normal select 66 if onlyAnyValue { 67 row, err := i.child.Next(ctx) 68 if err != nil { 69 i.done = true 70 return nil, err 71 } 72 73 if err := updateBuffers(ctx, i.buf, row); err != nil { 74 return nil, err 75 } 76 return evalBuffers(ctx, i.buf) 77 } 78 i.done = true 79 80 for { 81 row, err := i.child.Next(ctx) 82 if err != nil { 83 if err == io.EOF { 84 break 85 } 86 return nil, err 87 } 88 89 if err := updateBuffers(ctx, i.buf, row); err != nil { 90 return nil, err 91 } 92 } 93 94 row, err := evalBuffers(ctx, i.buf) 95 if err != nil { 96 return nil, err 97 } 98 return row, nil 99 } 100 101 func (i *groupByIter) Close(ctx *sql.Context) error { 102 i.Dispose() 103 i.buf = nil 104 return i.child.Close(ctx) 105 } 106 107 func (i *groupByIter) Dispose() { 108 for _, b := range i.buf { 109 b.Dispose() 110 } 111 } 112 113 type groupByGroupingIter struct { 114 selectedExprs []sql.Expression 115 groupByExprs []sql.Expression 116 aggregations sql.KeyValueCache 117 keys []uint64 118 pos int 119 child sql.RowIter 120 dispose sql.DisposeFunc 121 } 122 123 func newGroupByGroupingIter( 124 ctx *sql.Context, 125 selectedExprs, groupByExprs []sql.Expression, 126 child sql.RowIter, 127 ) *groupByGroupingIter { 128 return &groupByGroupingIter{ 129 selectedExprs: selectedExprs, 130 groupByExprs: groupByExprs, 131 child: child, 132 } 133 } 134 135 func (i *groupByGroupingIter) Next(ctx *sql.Context) (sql.Row, error) { 136 if i.aggregations == nil { 137 i.aggregations, i.dispose = ctx.Memory.NewHistoryCache() 138 if err := i.compute(ctx); err != nil { 139 return nil, err 140 } 141 } 142 143 if i.pos >= len(i.keys) { 144 return nil, io.EOF 145 } 146 147 buffers, err := i.get(i.keys[i.pos]) 148 if err != nil { 149 return nil, err 150 } 151 i.pos++ 152 153 row, err := evalBuffers(ctx, buffers) 154 if err != nil { 155 return nil, err 156 } 157 158 return row, nil 159 } 160 161 func (i *groupByGroupingIter) compute(ctx *sql.Context) error { 162 for { 163 row, err := i.child.Next(ctx) 164 if err != nil { 165 if err == io.EOF { 166 break 167 } 168 return err 169 } 170 171 key, err := groupingKey(ctx, i.groupByExprs, row) 172 if err != nil { 173 return err 174 } 175 176 b, err := i.get(key) 177 if errors.Is(err, sql.ErrKeyNotFound) { 178 b = make([]sql.AggregationBuffer, len(i.selectedExprs)) 179 for j, a := range i.selectedExprs { 180 b[j], err = newAggregationBuffer(a) 181 if err != nil { 182 return err 183 } 184 } 185 186 if err := i.aggregations.Put(key, b); err != nil { 187 return err 188 } 189 190 i.keys = append(i.keys, key) 191 } else if err != nil { 192 return err 193 } 194 195 err = updateBuffers(ctx, b, row) 196 if err != nil { 197 return err 198 } 199 } 200 201 return nil 202 } 203 204 func (i *groupByGroupingIter) get(key uint64) ([]sql.AggregationBuffer, error) { 205 v, err := i.aggregations.Get(key) 206 if err != nil { 207 return nil, err 208 } 209 if v == nil { 210 return nil, nil 211 } 212 return v.([]sql.AggregationBuffer), err 213 } 214 215 func (i *groupByGroupingIter) put(key uint64, val []sql.AggregationBuffer) error { 216 return i.aggregations.Put(key, val) 217 } 218 219 func (i *groupByGroupingIter) Close(ctx *sql.Context) error { 220 i.Dispose() 221 i.aggregations = nil 222 if i.dispose != nil { 223 i.dispose() 224 i.dispose = nil 225 } 226 227 return i.child.Close(ctx) 228 } 229 230 func (i *groupByGroupingIter) Dispose() { 231 for _, k := range i.keys { 232 bs, _ := i.get(k) 233 if bs != nil { 234 for _, b := range bs { 235 b.Dispose() 236 } 237 } 238 } 239 } 240 241 func groupingKey( 242 ctx *sql.Context, 243 exprs []sql.Expression, 244 row sql.Row, 245 ) (uint64, error) { 246 hash := xxhash.New() 247 for i, expr := range exprs { 248 v, err := expr.Eval(ctx, row) 249 if err != nil { 250 return 0, err 251 } 252 253 if i > 0 { 254 // separate each expression in the grouping key with a nil byte 255 if _, err = hash.Write([]byte{0}); err != nil { 256 return 0, err 257 } 258 } 259 260 t, isStringType := expr.Type().(sql.StringType) 261 if isStringType && v != nil { 262 v, err = types.ConvertToString(v, t) 263 if err == nil { 264 err = t.Collation().WriteWeightString(hash, v.(string)) 265 } 266 } else { 267 _, err = fmt.Fprintf(hash, "%v", v) 268 } 269 if err != nil { 270 return 0, err 271 } 272 } 273 274 return hash.Sum64(), nil 275 } 276 277 func newAggregationBuffer(expr sql.Expression) (sql.AggregationBuffer, error) { 278 switch n := expr.(type) { 279 case sql.Aggregation: 280 return n.NewBuffer() 281 default: 282 // The semantics for a non-aggregation expression in a group by node is First. 283 // When ONLY_FULL_GROUP_BY is enabled, this is an error, but it's allowed otherwise. 284 return aggregation.NewFirst(expr).NewBuffer() 285 } 286 } 287 288 func updateBuffers( 289 ctx *sql.Context, 290 buffers []sql.AggregationBuffer, 291 row sql.Row, 292 ) error { 293 for _, b := range buffers { 294 if err := b.Update(ctx, row); err != nil { 295 return err 296 } 297 } 298 299 return nil 300 } 301 302 func evalBuffers( 303 ctx *sql.Context, 304 buffers []sql.AggregationBuffer, 305 ) (sql.Row, error) { 306 var row = make(sql.Row, len(buffers)) 307 308 var err error 309 for i, b := range buffers { 310 row[i], err = b.Eval(ctx) 311 if err != nil { 312 return nil, err 313 } 314 } 315 316 return row, nil 317 }