github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/causetstore/causetstore/mockeinsteindb/aggregate.go (about) 1 // Copyright 2020 WHTCORPS INC, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package mockeinsteindb 15 16 import ( 17 "context" 18 "time" 19 20 "github.com/whtcorpsinc/errors" 21 "github.com/whtcorpsinc/milevadb/memex" 22 "github.com/whtcorpsinc/milevadb/memex/aggregation" 23 "github.com/whtcorpsinc/milevadb/types" 24 "github.com/whtcorpsinc/milevadb/soliton/chunk" 25 "github.com/whtcorpsinc/milevadb/soliton/codec" 26 ) 27 28 type aggCtxsMapper map[string][]*aggregation.AggEvaluateContext 29 30 var ( 31 _ interlock = &hashAggInterDirc{} 32 _ interlock = &streamAggInterDirc{} 33 ) 34 35 type hashAggInterDirc struct { 36 evalCtx *evalContext 37 aggExprs []aggregation.Aggregation 38 aggCtxsMap aggCtxsMapper 39 groupByExprs []memex.Expression 40 relatedDefCausOffsets []int 41 event []types.Causet 42 groups map[string]struct{} 43 groupKeys [][]byte 44 groupKeyRows [][][]byte 45 executed bool 46 currGroupIdx int 47 count int64 48 execDetail *execDetail 49 50 src interlock 51 } 52 53 func (e *hashAggInterDirc) InterDircDetails() []*execDetail { 54 var suffix []*execDetail 55 if e.src != nil { 56 suffix = e.src.InterDircDetails() 57 } 58 return append(suffix, e.execDetail) 59 } 60 61 func (e *hashAggInterDirc) SetSrcInterDirc(exec interlock) { 62 e.src = exec 63 } 64 65 func (e *hashAggInterDirc) GetSrcInterDirc() interlock { 66 return e.src 67 } 68 69 func (e *hashAggInterDirc) ResetCounts() { 70 e.src.ResetCounts() 71 } 72 73 func (e *hashAggInterDirc) Counts() []int64 { 74 return e.src.Counts() 75 } 76 77 func (e *hashAggInterDirc) innerNext(ctx context.Context) (bool, error) { 78 values, err := e.src.Next(ctx) 79 if err != nil { 80 return false, errors.Trace(err) 81 } 82 if values == nil { 83 return false, nil 84 } 85 err = e.aggregate(values) 86 if err != nil { 87 return false, errors.Trace(err) 88 } 89 return true, nil 90 } 91 92 func (e *hashAggInterDirc) Cursor() ([]byte, bool) { 93 panic("don't not use interlock streaming API for hash aggregation!") 94 } 95 96 func (e *hashAggInterDirc) Next(ctx context.Context) (value [][]byte, err error) { 97 defer func(begin time.Time) { 98 e.execDetail.uFIDelate(begin, value) 99 }(time.Now()) 100 e.count++ 101 if e.aggCtxsMap == nil { 102 e.aggCtxsMap = make(aggCtxsMapper) 103 } 104 if !e.executed { 105 for { 106 hasMore, err := e.innerNext(ctx) 107 if err != nil { 108 return nil, errors.Trace(err) 109 } 110 if !hasMore { 111 break 112 } 113 } 114 e.executed = true 115 } 116 117 if e.currGroupIdx >= len(e.groups) { 118 return nil, nil 119 } 120 gk := e.groupKeys[e.currGroupIdx] 121 value = make([][]byte, 0, len(e.groupByExprs)+2*len(e.aggExprs)) 122 aggCtxs := e.getContexts(gk) 123 for i, agg := range e.aggExprs { 124 partialResults := agg.GetPartialResult(aggCtxs[i]) 125 for _, result := range partialResults { 126 data, err := codec.EncodeValue(e.evalCtx.sc, nil, result) 127 if err != nil { 128 return nil, errors.Trace(err) 129 } 130 value = append(value, data) 131 } 132 } 133 value = append(value, e.groupKeyRows[e.currGroupIdx]...) 134 e.currGroupIdx++ 135 136 return value, nil 137 } 138 139 func (e *hashAggInterDirc) getGroupKey() ([]byte, [][]byte, error) { 140 length := len(e.groupByExprs) 141 if length == 0 { 142 return nil, nil, nil 143 } 144 bufLen := 0 145 event := make([][]byte, 0, length) 146 for _, item := range e.groupByExprs { 147 v, err := item.Eval(chunk.MutRowFromCausets(e.event).ToRow()) 148 if err != nil { 149 return nil, nil, errors.Trace(err) 150 } 151 b, err := codec.EncodeValue(e.evalCtx.sc, nil, v) 152 if err != nil { 153 return nil, nil, errors.Trace(err) 154 } 155 bufLen += len(b) 156 event = append(event, b) 157 } 158 buf := make([]byte, 0, bufLen) 159 for _, col := range event { 160 buf = append(buf, col...) 161 } 162 return buf, event, nil 163 } 164 165 // aggregate uFIDelates aggregate functions with event. 166 func (e *hashAggInterDirc) aggregate(value [][]byte) error { 167 err := e.evalCtx.decodeRelatedDeferredCausetVals(e.relatedDefCausOffsets, value, e.event) 168 if err != nil { 169 return errors.Trace(err) 170 } 171 // Get group key. 172 gk, gbyKeyRow, err := e.getGroupKey() 173 if err != nil { 174 return errors.Trace(err) 175 } 176 if _, ok := e.groups[string(gk)]; !ok { 177 e.groups[string(gk)] = struct{}{} 178 e.groupKeys = append(e.groupKeys, gk) 179 e.groupKeyRows = append(e.groupKeyRows, gbyKeyRow) 180 } 181 // UFIDelate aggregate memexs. 182 aggCtxs := e.getContexts(gk) 183 for i, agg := range e.aggExprs { 184 err = agg.UFIDelate(aggCtxs[i], e.evalCtx.sc, chunk.MutRowFromCausets(e.event).ToRow()) 185 if err != nil { 186 return errors.Trace(err) 187 } 188 } 189 return nil 190 } 191 192 func (e *hashAggInterDirc) getContexts(groupKey []byte) []*aggregation.AggEvaluateContext { 193 groupKeyString := string(groupKey) 194 aggCtxs, ok := e.aggCtxsMap[groupKeyString] 195 if !ok { 196 aggCtxs = make([]*aggregation.AggEvaluateContext, 0, len(e.aggExprs)) 197 for _, agg := range e.aggExprs { 198 aggCtxs = append(aggCtxs, agg.CreateContext(e.evalCtx.sc)) 199 } 200 e.aggCtxsMap[groupKeyString] = aggCtxs 201 } 202 return aggCtxs 203 } 204 205 type streamAggInterDirc struct { 206 evalCtx *evalContext 207 aggExprs []aggregation.Aggregation 208 aggCtxs []*aggregation.AggEvaluateContext 209 groupByExprs []memex.Expression 210 relatedDefCausOffsets []int 211 event []types.Causet 212 tmpGroupByRow []types.Causet 213 currGroupByRow []types.Causet 214 nextGroupByRow []types.Causet 215 currGroupByValues [][]byte 216 executed bool 217 hasData bool 218 count int64 219 execDetail *execDetail 220 221 src interlock 222 } 223 224 func (e *streamAggInterDirc) InterDircDetails() []*execDetail { 225 var suffix []*execDetail 226 if e.src != nil { 227 suffix = e.src.InterDircDetails() 228 } 229 return append(suffix, e.execDetail) 230 } 231 232 func (e *streamAggInterDirc) SetSrcInterDirc(exec interlock) { 233 e.src = exec 234 } 235 236 func (e *streamAggInterDirc) GetSrcInterDirc() interlock { 237 return e.src 238 } 239 240 func (e *streamAggInterDirc) ResetCounts() { 241 e.src.ResetCounts() 242 } 243 244 func (e *streamAggInterDirc) Counts() []int64 { 245 return e.src.Counts() 246 } 247 248 func (e *streamAggInterDirc) getPartialResult() ([][]byte, error) { 249 value := make([][]byte, 0, len(e.groupByExprs)+2*len(e.aggExprs)) 250 for i, agg := range e.aggExprs { 251 partialResults := agg.GetPartialResult(e.aggCtxs[i]) 252 for _, result := range partialResults { 253 data, err := codec.EncodeValue(e.evalCtx.sc, nil, result) 254 if err != nil { 255 return nil, errors.Trace(err) 256 } 257 value = append(value, data) 258 } 259 // Clear the aggregate context. 260 e.aggCtxs[i] = agg.CreateContext(e.evalCtx.sc) 261 } 262 e.currGroupByValues = e.currGroupByValues[:0] 263 for _, d := range e.currGroupByRow { 264 buf, err := codec.EncodeValue(e.evalCtx.sc, nil, d) 265 if err != nil { 266 return nil, errors.Trace(err) 267 } 268 e.currGroupByValues = append(e.currGroupByValues, buf) 269 } 270 e.currGroupByRow = types.CloneRow(e.nextGroupByRow) 271 return append(value, e.currGroupByValues...), nil 272 } 273 274 func (e *streamAggInterDirc) meetNewGroup(event [][]byte) (bool, error) { 275 if len(e.groupByExprs) == 0 { 276 return false, nil 277 } 278 279 e.tmpGroupByRow = e.tmpGroupByRow[:0] 280 matched, firstGroup := true, false 281 if e.nextGroupByRow == nil { 282 matched, firstGroup = false, true 283 } 284 for i, item := range e.groupByExprs { 285 d, err := item.Eval(chunk.MutRowFromCausets(e.event).ToRow()) 286 if err != nil { 287 return false, errors.Trace(err) 288 } 289 if matched { 290 c, err := d.CompareCauset(e.evalCtx.sc, &e.nextGroupByRow[i]) 291 if err != nil { 292 return false, errors.Trace(err) 293 } 294 matched = c == 0 295 } 296 e.tmpGroupByRow = append(e.tmpGroupByRow, d) 297 } 298 if firstGroup { 299 e.currGroupByRow = types.CloneRow(e.tmpGroupByRow) 300 } 301 if matched { 302 return false, nil 303 } 304 e.nextGroupByRow = e.tmpGroupByRow 305 return !firstGroup, nil 306 } 307 308 func (e *streamAggInterDirc) Cursor() ([]byte, bool) { 309 panic("don't not use interlock streaming API for stream aggregation!") 310 } 311 312 func (e *streamAggInterDirc) Next(ctx context.Context) (retRow [][]byte, err error) { 313 defer func(begin time.Time) { 314 e.execDetail.uFIDelate(begin, retRow) 315 }(time.Now()) 316 e.count++ 317 if e.executed { 318 return nil, nil 319 } 320 321 for { 322 values, err := e.src.Next(ctx) 323 if err != nil { 324 return nil, errors.Trace(err) 325 } 326 if values == nil { 327 e.executed = true 328 if !e.hasData && len(e.groupByExprs) > 0 { 329 return nil, nil 330 } 331 return e.getPartialResult() 332 } 333 334 e.hasData = true 335 err = e.evalCtx.decodeRelatedDeferredCausetVals(e.relatedDefCausOffsets, values, e.event) 336 if err != nil { 337 return nil, errors.Trace(err) 338 } 339 newGroup, err := e.meetNewGroup(values) 340 if err != nil { 341 return nil, errors.Trace(err) 342 } 343 if newGroup { 344 retRow, err = e.getPartialResult() 345 if err != nil { 346 return nil, errors.Trace(err) 347 } 348 } 349 for i, agg := range e.aggExprs { 350 err = agg.UFIDelate(e.aggCtxs[i], e.evalCtx.sc, chunk.MutRowFromCausets(e.event).ToRow()) 351 if err != nil { 352 return nil, errors.Trace(err) 353 } 354 } 355 if newGroup { 356 return retRow, nil 357 } 358 } 359 }