github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/merge_result.go (about) 1 // Copyright 2019 The Gaea Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "fmt" 19 "strconv" 20 "strings" 21 22 "github.com/XiaoMi/Gaea/mysql" 23 "github.com/XiaoMi/Gaea/parser/ast" 24 "github.com/XiaoMi/Gaea/util/hack" 25 "github.com/XiaoMi/Gaea/util/math" 26 ) 27 28 // ResultRow is one Row in Result 29 type ResultRow []interface{} 30 31 // GetInt get int value from column 32 // copy from Resultset.GetInt() 33 func (r ResultRow) GetInt(column int) (int64, error) { 34 d := r[column] 35 switch v := d.(type) { 36 case uint64: 37 return int64(v), nil 38 case int64: 39 return v, nil 40 case float64: 41 return int64(v), nil 42 case string: 43 return strconv.ParseInt(v, 10, 64) 44 case []byte: 45 return strconv.ParseInt(string(v), 10, 64) 46 case nil: 47 return 0, nil 48 default: 49 return 0, fmt.Errorf("data type is %T", v) 50 } 51 } 52 53 // GetUint get uint64 value from column 54 func (r ResultRow) GetUint(column int) (uint64, error) { 55 d := r[column] 56 switch v := d.(type) { 57 case uint64: 58 return v, nil 59 case int64: 60 return uint64(v), nil 61 case float64: 62 return uint64(v), nil 63 case string: 64 return strconv.ParseUint(v, 10, 64) 65 case []byte: 66 return strconv.ParseUint(string(v), 10, 64) 67 case nil: 68 return 0, nil 69 default: 70 return 0, fmt.Errorf("data type is %T", v) 71 } 72 } 73 74 // GetFloat get float64 value from column 75 func (r ResultRow) GetFloat(column int) (float64, error) { 76 d := r[column] 77 switch v := d.(type) { 78 case float64: 79 return v, nil 80 case uint64: 81 return float64(v), nil 82 case int64: 83 return float64(v), nil 84 case string: 85 return strconv.ParseFloat(v, 64) 86 case []byte: 87 return strconv.ParseFloat(string(v), 64) 88 case nil: 89 return 0, nil 90 default: 91 return 0, fmt.Errorf("data type is %T", v) 92 } 93 } 94 95 // SetValue set value to column 96 func (r ResultRow) SetValue(column int, value interface{}) { 97 r[column] = value 98 } 99 100 // GetValue get value from column 101 func (r ResultRow) GetValue(column int) interface{} { 102 return r[column] 103 } 104 105 // AggregateFuncMerger is the merger of aggregate function 106 type AggregateFuncMerger interface { 107 // MergeTo 合并结果集, from为待合并行, to为结果聚合行 108 MergeTo(from, to ResultRow) error 109 } 110 111 type aggregateFuncBaseMerger struct { 112 fieldIndex int // 所在列位置 113 } 114 115 // CreateAggregateFunctionMerger create AggregateFunctionMerger by function type 116 // currently support: "count", "sum", "max", "min" 117 func CreateAggregateFunctionMerger(funcType string, fieldIndex int) (AggregateFuncMerger, error) { 118 switch strings.ToLower(funcType) { 119 case "count": 120 ret := new(AggregateFuncCountMerger) 121 ret.fieldIndex = fieldIndex 122 return ret, nil 123 case "sum": 124 ret := new(AggregateFuncSumMerger) 125 ret.fieldIndex = fieldIndex 126 return ret, nil 127 case "max": 128 ret := new(AggregateFuncMaxMerger) 129 ret.fieldIndex = fieldIndex 130 return ret, nil 131 case "min": 132 ret := new(AggregateFuncMinMerger) 133 ret.fieldIndex = fieldIndex 134 return ret, nil 135 default: 136 return nil, fmt.Errorf("aggregate function type is not support: %s", funcType) 137 } 138 } 139 140 // AggregateFuncCountMerger merge COUNT() column in result 141 type AggregateFuncCountMerger struct { 142 aggregateFuncBaseMerger 143 } 144 145 // MergeTo implement AggregateFuncMerger 146 func (a *AggregateFuncCountMerger) MergeTo(from, to ResultRow) error { 147 idx := a.fieldIndex 148 if idx >= len(from) || idx >= len(to) { 149 return fmt.Errorf("field index out of bound: %d", a.fieldIndex) 150 } 151 152 valueToMerge, err := from.GetInt(idx) 153 if err != nil { 154 return fmt.Errorf("get from int value error: %v", err) 155 } 156 originValue, err := to.GetInt(idx) 157 if err != nil { 158 return fmt.Errorf("get to int value error: %v", err) 159 } 160 to.SetValue(idx, originValue+valueToMerge) 161 return nil 162 } 163 164 // AggregateFuncSumMerger merge SUM() column in result 165 type AggregateFuncSumMerger struct { 166 aggregateFuncBaseMerger 167 } 168 169 // MergeTo implement AggregateFuncMerger 170 func (a *AggregateFuncSumMerger) MergeTo(from, to ResultRow) error { 171 idx := a.fieldIndex 172 if idx >= len(from) || idx >= len(to) { 173 return fmt.Errorf("field index out of bound: %d", a.fieldIndex) 174 } 175 176 fromValueI := from.GetValue(idx) 177 178 // nil对应NULL, NULL不参与比较 179 if fromValueI == nil { 180 return nil 181 } 182 183 switch to.GetValue(idx).(type) { 184 case int64: 185 return a.sumToInt64(from, to) 186 case uint64: 187 return a.sumToUint64(from, to) 188 case float64, string, []byte, nil: 189 return a.sumToFloat64(from, to) 190 default: 191 fromValue := from.GetValue(idx) 192 toValue := to.GetValue(idx) 193 return fmt.Errorf("cannot sum value %v (%T) to %v (%T)", fromValue, fromValue, toValue, toValue) 194 } 195 } 196 197 func (a *AggregateFuncSumMerger) sumToInt64(from, to ResultRow) error { 198 idx := a.fieldIndex // does not need to check 199 valueToMerge, err := from.GetInt(idx) 200 if err != nil { 201 return fmt.Errorf("get from int value error: %v", err) 202 } 203 originValue, err := to.GetInt(idx) 204 if err != nil { 205 return fmt.Errorf("get to int value error: %v", err) 206 } 207 to.SetValue(idx, originValue+valueToMerge) 208 return nil 209 } 210 211 func (a *AggregateFuncSumMerger) sumToUint64(from, to ResultRow) error { 212 idx := a.fieldIndex // does not need to check 213 valueToMerge, err := from.GetUint(idx) 214 if err != nil { 215 return fmt.Errorf("get from int value error: %v", err) 216 } 217 originValue, err := to.GetUint(idx) 218 if err != nil { 219 return fmt.Errorf("get to int value error: %v", err) 220 } 221 to.SetValue(idx, originValue+valueToMerge) 222 return nil 223 } 224 225 func (a *AggregateFuncSumMerger) sumToFloat64(from, to ResultRow) error { 226 idx := a.fieldIndex // does not need to check 227 valueToMerge, err := from.GetFloat(idx) 228 if err != nil { 229 return fmt.Errorf("get from int value error: %v", err) 230 } 231 originValue, err := to.GetFloat(idx) 232 if err != nil { 233 return fmt.Errorf("get to int value error: %v", err) 234 } 235 to.SetValue(idx, originValue+valueToMerge) 236 return nil 237 } 238 239 // AggregateFuncMaxMerger merge MAX() column in result 240 type AggregateFuncMaxMerger struct { 241 aggregateFuncBaseMerger 242 } 243 244 // MergeTo implement AggregateFuncMerger 245 func (a *AggregateFuncMaxMerger) MergeTo(from, to ResultRow) error { 246 idx := a.fieldIndex 247 if idx >= len(from) || idx >= len(to) { 248 return fmt.Errorf("field index out of bound: %d", a.fieldIndex) 249 } 250 251 fromValueI := from.GetValue(idx) 252 toValueI := to.GetValue(idx) 253 254 // nil对应NULL, NULL不参与比较 255 if fromValueI == nil { 256 return nil 257 } 258 259 switch toValue := toValueI.(type) { 260 case nil: 261 to.SetValue(idx, fromValueI) 262 return nil 263 case int64: 264 if fromValueI.(int64) > toValue { 265 to.SetValue(idx, fromValueI) 266 } 267 return nil 268 case uint64: 269 if fromValueI.(uint64) > toValue { 270 to.SetValue(idx, fromValueI) 271 } 272 return nil 273 case float64: 274 if fromValueI.(float64) > toValue { 275 to.SetValue(idx, fromValueI) 276 } 277 return nil 278 case string: 279 if fromValueI.(string) > toValue { 280 to.SetValue(idx, fromValueI) 281 } 282 return nil 283 // does not handle []byte 284 default: 285 return fmt.Errorf("cannot compare value %v (%T) to %v (%T)", fromValueI, fromValueI, toValueI, toValueI) 286 } 287 } 288 289 // AggregateFuncMinMerger merge MIN() column in result 290 type AggregateFuncMinMerger struct { 291 aggregateFuncBaseMerger 292 } 293 294 // MergeTo implement AggregateFuncMerger 295 func (a *AggregateFuncMinMerger) MergeTo(from, to ResultRow) error { 296 idx := a.fieldIndex 297 if idx >= len(from) || idx >= len(to) { 298 return fmt.Errorf("field index out of bound: %d", a.fieldIndex) 299 } 300 301 fromValueI := from.GetValue(idx) 302 toValueI := to.GetValue(idx) 303 304 // nil对应NULL, NULL不参与比较 305 if fromValueI == nil { 306 return nil 307 } 308 309 switch toValue := toValueI.(type) { 310 case nil: 311 to.SetValue(idx, fromValueI) 312 return nil 313 case int64: 314 if fromValueI.(int64) < toValue { 315 to.SetValue(idx, fromValueI) 316 } 317 return nil 318 case uint64: 319 if fromValueI.(uint64) < toValue { 320 to.SetValue(idx, fromValueI) 321 } 322 return nil 323 case float64: 324 if fromValueI.(float64) < toValue { 325 to.SetValue(idx, fromValueI) 326 } 327 return nil 328 case string: 329 if fromValueI.(string) < toValue { 330 to.SetValue(idx, fromValueI) 331 } 332 return nil 333 // does not handle []byte 334 default: 335 return fmt.Errorf("cannot compare value %v (%T) to %v (%T)", fromValueI, fromValueI, toValueI, toValueI) 336 } 337 } 338 339 // MergeExecResult merge execution results, like UPDATE, INSERT, DELETE, ... 340 func MergeExecResult(rs []*mysql.Result) (*mysql.Result, error) { 341 r := new(mysql.Result) 342 for _, v := range rs { 343 r.Status |= v.Status 344 r.AffectedRows += v.AffectedRows 345 if r.InsertID == 0 { 346 r.InsertID = v.InsertID 347 } else if v.InsertID != 0 && r.InsertID > v.InsertID { 348 //last insert id is first gen id for multi row inserted 349 //see http://dev.mysql.com/doc/refman/5.6/en/information-functions.html#function_last-insert-id 350 r.InsertID = v.InsertID 351 } 352 } 353 354 return r, nil 355 } 356 357 // MergeSelectResult merge select results 358 func MergeSelectResult(p *SelectPlan, stmt *ast.SelectStmt, rs []*mysql.Result) (*mysql.Result, error) { 359 ret := mergeMultiResultSet(rs) 360 361 if p.distinct { 362 if err := removeDistinctRowInResult(p, ret); err != nil { 363 return nil, err 364 } 365 } 366 367 if stmt.GroupBy != nil { 368 if err := buildSelectGroupByResult(p, ret); err != nil { 369 return nil, err 370 } 371 } else { 372 if err := buildSelectOnlyResult(p, ret); err != nil { 373 return nil, err 374 } 375 } 376 377 if err := sortSelectResult(p, stmt, ret); err != nil { 378 return nil, err 379 } 380 381 if err := limitSelectResult(p, ret); err != nil { 382 return nil, err 383 } 384 385 if err := trimExtraFields(p, ret); err != nil { 386 return nil, fmt.Errorf("trimExtraFields error: %v", err) 387 } 388 389 if err := GenerateSelectResultRowData(ret); err != nil { 390 return nil, fmt.Errorf("generate RowData error: %v", err) 391 } 392 393 return ret, nil 394 } 395 396 // 合并结果集, 返回一个Result 397 func mergeMultiResultSet(rs []*mysql.Result) *mysql.Result { 398 if len(rs) == 1 { 399 return rs[0] 400 } 401 402 // 列信息认为相同, 因此只合并结果 403 for i := 1; i < len(rs); i++ { 404 rs[0].Status |= rs[i].Status 405 rs[0].Values = append(rs[0].Values, rs[i].Values...) 406 rs[0].RowDatas = append(rs[0].RowDatas, rs[i].RowDatas...) 407 } 408 409 return rs[0] 410 } 411 412 func removeDistinctRowInResult(p *SelectPlan, r *mysql.Result) error { 413 distinctKeySet := make(map[string]bool) 414 var rowToRemove []int 415 // 计算除补列之外的原始列数 416 resultFieldLength := len(r.Fields) 417 originColumnCount := p.GetColumnCount() 418 deltaColumnCount := resultFieldLength - originColumnCount 419 colCnt := p.originColumnCount + deltaColumnCount 420 421 // 根据原始列的值进行去重 422 rowCount := len(r.Values) 423 for i := 0; i < rowCount; i++ { 424 keySlice := r.Values[i][0:colCnt] 425 mk, err := generateMapKey(keySlice) 426 if err != nil { 427 return err 428 } 429 430 _, ok := distinctKeySet[mk] 431 if !ok { 432 distinctKeySet[mk] = true 433 } else { 434 rowToRemove = append(rowToRemove, i) 435 } 436 } 437 438 rowToRemoveCnt := len(rowToRemove) 439 if rowToRemoveCnt == 0 { 440 return nil 441 } 442 443 originRows := r.Values 444 r.RowDatas = nil 445 r.Values = originRows[:0] 446 447 var j int 448 for i := 0; i < rowCount; i++ { 449 if j == rowToRemoveCnt { 450 r.Values = append(r.Values, originRows[i:]...) 451 break 452 } 453 if i == rowToRemove[j] { 454 j++ 455 } else { 456 r.Values = append(r.Values, originRows[i]) 457 } 458 } 459 460 return nil 461 } 462 463 // contains mergeGroupByWithoutFunc() and mergeGroupByWithFunc() 464 func buildSelectGroupByResult(p *SelectPlan, r *mysql.Result) error { 465 resultMap := make(map[string]ResultRow) 466 467 resultFieldLength := len(r.Fields) 468 originColumnCount := p.GetColumnCount() 469 deltaColumnCount := resultFieldLength - originColumnCount 470 471 // 根据group by的列进行结果聚合 472 for i, v := range r.Values { 473 keySlice := make([]interface{}, 0) 474 for _, index := range p.GetGroupByColumnInfo() { 475 keySlice = append(keySlice, v[index+deltaColumnCount]) 476 } 477 mk, err := generateMapKey(keySlice) 478 if err != nil { 479 return err 480 } 481 482 // 用找到的第一个结果行作为聚合结果 483 _, ok := resultMap[mk] 484 if !ok { 485 resultMap[mk] = ResultRow(r.Values[i]) 486 continue 487 } 488 489 if len(p.aggregateFuncs) == 0 { 490 continue 491 } 492 493 // 如果存在聚合函数, 则对聚合列进行结果聚合, 非聚合列不处理 494 retToMerge := ResultRow(r.Values[i]) 495 for _, mfunc := range p.aggregateFuncs { 496 if err := mfunc.MergeTo(retToMerge, resultMap[mk]); err != nil { 497 return fmt.Errorf("MergeTo error, func: %v, value: %v, err: %v", mfunc, retToMerge, err) 498 } 499 } 500 } 501 502 err := buildResultFromResultMap(r, resultMap) 503 if err != nil { 504 return fmt.Errorf("buildResultFromResultMap error: %v", err) 505 } 506 507 return nil 508 } 509 510 func buildSelectOnlyResult(p *SelectPlan, rs *mysql.Result) error { 511 r := rs.Resultset 512 // 没有聚合函数, 直接把所有分片结果添加到同一个ResultSet下面 513 if len(p.aggregateFuncs) == 0 { 514 return nil 515 } 516 517 // 存在聚合函数, 需要改写聚合列的值, 然后返回 (应该只有一行记录) 518 isSet := false 519 var currRet ResultRow 520 for i, v := range r.Values { 521 if !isSet { 522 isSet = true 523 currRet = ResultRow(v) 524 continue 525 } 526 527 retToMerge := ResultRow(r.Values[i]) 528 for _, mfunc := range p.aggregateFuncs { 529 if err := mfunc.MergeTo(retToMerge, currRet); err != nil { 530 return fmt.Errorf("MergeTo error, func: %v, value: %v, err: %v", mfunc, retToMerge, err) 531 } 532 } 533 } 534 535 r.Values = r.Values[:0] 536 r.Values = append(r.Values, currRet) 537 r.RowDatas = nil 538 539 return nil 540 } 541 542 // this function modifies the first value of origin results 543 func buildResultFromResultMap(r *mysql.Result, resultMap map[string]ResultRow) error { 544 // no group by result means the result row count is 0, so return the first result 545 if len(resultMap) == 0 { 546 return nil 547 } 548 549 r.Values = nil 550 r.RowDatas = nil 551 for _, v := range resultMap { 552 r.Values = append(r.Values, v) 553 } 554 555 return nil 556 } 557 558 // 去掉补充的列 559 // 与补充列的顺序相反, 先去掉ORDER BY补充的列, 再去掉GROUP BY补充的列 560 func trimExtraFields(p *SelectPlan, r *mysql.Result) error { 561 resultFieldLength := len(r.Fields) 562 originColumnCount := p.GetColumnCount() 563 deltaColumnCount := resultFieldLength - originColumnCount 564 extraFieldStartIndex := deltaColumnCount + p.GetOriginColumnCount() 565 566 if extraFieldStartIndex != -1 { 567 r.Fields = r.Fields[0:extraFieldStartIndex] 568 for i := 0; i < len(r.Values); i++ { 569 r.Values[i] = r.Values[i][0:extraFieldStartIndex] 570 } 571 } 572 573 return nil 574 } 575 576 func sortSelectResult(p *SelectPlan, stmt *ast.SelectStmt, ret *mysql.Result) error { 577 if !p.HasOrderBy() { 578 return nil 579 } 580 581 resultFieldLength := len(ret.Fields) 582 originColumnCount := p.GetColumnCount() 583 deltaColumnCount := resultFieldLength - originColumnCount 584 585 orderByColumns, orderByDirections := p.GetOrderByColumnInfo() 586 var sortKeys []mysql.SortKey 587 for i := 0; i < len(orderByDirections); i++ { 588 sortKey := mysql.SortKey{} 589 sortKey.Column = orderByColumns[i] + deltaColumnCount 590 if orderByDirections[i] { 591 sortKey.Direction = mysql.SortDesc 592 } else { 593 sortKey.Direction = mysql.SortAsc 594 } 595 sortKeys = append(sortKeys, sortKey) 596 } 597 598 return ret.SortWithoutColumnName(sortKeys) 599 } 600 601 // the result from backend is aggregated and offset = 0, count = (originOffset + originCount) 602 func limitSelectResult(p *SelectPlan, ret *mysql.Result) error { 603 if !p.HasLimit() { 604 return nil 605 } 606 607 start, count := p.GetLimitValue() 608 609 rowLen := int64(len(ret.Values)) 610 end := math.MinInt64(start+count, rowLen) 611 612 if start >= rowLen { 613 ret.RowDatas = ret.RowDatas[:0] 614 ret.Values = ret.Values[:0] 615 return nil 616 } 617 618 ret.Values = ret.Values[start:end] 619 return nil 620 } 621 622 // GenerateSelectResultRowData generate raw RowData from values 623 // 根据value反向构造RowData 624 // copy from server.buildResultset() 625 func GenerateSelectResultRowData(r *mysql.Result) error { 626 r.RowDatas = nil 627 for i, vs := range r.Values { 628 if len(vs) != len(r.Fields) { 629 return fmt.Errorf("row %d has %d column not equal %d", i, len(vs), len(r.Fields)) 630 } 631 632 var row []byte 633 for _, value := range vs { 634 // build row values 635 if value == nil { 636 row = append(row, 0xfb) 637 } else { 638 b, err := formatValue(value) 639 if err != nil { 640 return err 641 } 642 row = mysql.AppendLenEncStringBytes(row, b) 643 } 644 } 645 646 r.RowDatas = append(r.RowDatas, row) 647 } 648 649 return nil 650 } 651 652 // copy from server.generateMapKey() 653 func generateMapKey(groupColumns []interface{}) (string, error) { 654 bk := make([]byte, 0, 8) 655 separatorBuf, err := formatValue("+") 656 if err != nil { 657 return "", err 658 } 659 660 for _, v := range groupColumns { 661 b, err := formatValue(v) 662 if err != nil { 663 return "", err 664 } 665 bk = append(bk, b...) 666 bk = append(bk, separatorBuf...) 667 } 668 669 return string(bk), nil 670 } 671 672 // copy from server.formatValue() 673 // formatValue encode value into a string format 674 func formatValue(value interface{}) ([]byte, error) { 675 if value == nil { 676 return hack.Slice("NULL"), nil 677 } 678 switch v := value.(type) { 679 case int8: 680 return strconv.AppendInt(nil, int64(v), 10), nil 681 case int16: 682 return strconv.AppendInt(nil, int64(v), 10), nil 683 case int32: 684 return strconv.AppendInt(nil, int64(v), 10), nil 685 case int64: 686 return strconv.AppendInt(nil, int64(v), 10), nil 687 case int: 688 return strconv.AppendInt(nil, int64(v), 10), nil 689 case uint8: 690 return strconv.AppendUint(nil, uint64(v), 10), nil 691 case uint16: 692 return strconv.AppendUint(nil, uint64(v), 10), nil 693 case uint32: 694 return strconv.AppendUint(nil, uint64(v), 10), nil 695 case uint64: 696 return strconv.AppendUint(nil, uint64(v), 10), nil 697 case uint: 698 return strconv.AppendUint(nil, uint64(v), 10), nil 699 case float32: 700 return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil 701 case float64: 702 return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil 703 case []byte: 704 return v, nil 705 case string: 706 return hack.Slice(v), nil 707 default: 708 return nil, fmt.Errorf("invalid type %T", value) 709 } 710 }