github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/decorator_pattern_in_expr.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"sort"
    21  
    22  	"github.com/XiaoMi/Gaea/parser/ast"
    23  	"github.com/XiaoMi/Gaea/parser/format"
    24  	driver "github.com/XiaoMi/Gaea/parser/tidb-types/parser_driver"
    25  	"github.com/XiaoMi/Gaea/parser/types"
    26  	"github.com/XiaoMi/Gaea/proxy/router"
    27  	"github.com/XiaoMi/Gaea/util"
    28  )
    29  
    30  // type check
    31  var _ ast.ExprNode = &PatternInExprDecorator{}
    32  
// PatternInExprDecorator decorate PatternInExpr
// Recording tableIndexes and indexValueMap here is safe: if an OR condition
// widens the route indexes ([]int), the rewritten SQL merely ends up with no
// values in this IN item, which does not affect SQL correctness or the
// execution result.
type PatternInExprDecorator struct {
	// Expr is the left-hand side of the IN expression.
	Expr ast.ExprNode
	// List is the value list of the IN expression.
	List []ast.ExprNode
	// Not is true for NOT IN.
	Not  bool

	// tableIndexes is the route result of this expression.
	tableIndexes  []int
	// indexValueMap holds, per table index, the values written by Restore.
	indexValueMap map[int][]ast.ExprNode // tableIndex - valueList

	rule   router.Rule
	// result supplies the current table index during Restore.
	result *RouteResult
}
    47  
    48  // NeedCreatePatternInExprDecorator check if PatternInExpr needs decoration
    49  func NeedCreatePatternInExprDecorator(p *TableAliasStmtInfo, n *ast.PatternInExpr) (router.Rule, bool, bool, error) {
    50  	if n.Sel != nil {
    51  		return nil, false, false, fmt.Errorf("TableName does not support Sel in sharding")
    52  	}
    53  
    54  	// 如果不是ColumnNameExpr, 则不做任何路由计算和装饰, 直接返回
    55  	columnNameExpr, ok := n.Expr.(*ast.ColumnNameExpr)
    56  	if !ok {
    57  		return nil, false, false, nil
    58  	}
    59  
    60  	rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInCondition(p, columnNameExpr)
    61  	if err != nil {
    62  		return nil, false, false, fmt.Errorf("check ColumnName error: %v", err)
    63  	}
    64  
    65  	if !need && rule == nil {
    66  		return nil, false, false, nil
    67  	}
    68  
    69  	// ColumnName不需要装饰, 不代表PatternInExpr不需要装饰, 对全局表来说, PatternInExpr也需要装饰
    70  	if rule.GetType() == router.GlobalTableRuleType {
    71  		return rule, true, isAlias, nil
    72  	}
    73  
    74  	return rule, need, isAlias, nil
    75  }
    76  
    77  // CreatePatternInExprDecorator create PatternInExprDecorator
    78  // 必须先检查是否需要装饰
    79  func CreatePatternInExprDecorator(n *ast.PatternInExpr, rule router.Rule, isAlias bool, result *RouteResult) (*PatternInExprDecorator, error) {
    80  	columnNameExpr := n.Expr.(*ast.ColumnNameExpr)
    81  	columnNameExprDecorator := CreateColumnNameExprDecorator(columnNameExpr, rule, isAlias, result)
    82  
    83  	tableIndexes, indexValueMap, err := getPatternInRouteResult(columnNameExpr.Name, n.Not, rule, n.List)
    84  	if err != nil {
    85  		return nil, fmt.Errorf("getPatternInRouteResult error: %v", err)
    86  	}
    87  
    88  	ret := &PatternInExprDecorator{
    89  		Expr:          columnNameExprDecorator,
    90  		List:          n.List,
    91  		Not:           n.Not,
    92  		rule:          rule,
    93  		result:        result,
    94  		tableIndexes:  tableIndexes,
    95  		indexValueMap: indexValueMap,
    96  	}
    97  
    98  	return ret, nil
    99  }
   100  
   101  // 返回路由, 并构建路由索引到值的映射.
   102  // 如果是分片条件, 则构建值到索引的映射.
   103  // 例如, 1,2,3,4分别映射到索引0,2则[]int = [0,2], map=[0:[1,2], 2:[3,4]]
   104  // 如果是全路由, 则每个分片都要返回所有的值.
   105  func getPatternInRouteResult(n *ast.ColumnName, isNotIn bool, rule router.Rule, values []ast.ExprNode) ([]int, map[int][]ast.ExprNode, error) {
   106  	// 如果是全局表, 则返回广播路由
   107  	if rule.GetType() == router.GlobalTableRuleType {
   108  		indexes := rule.GetSubTableIndexes()
   109  		valueMap := getBroadcastValueMap(indexes, values)
   110  		return indexes, valueMap, nil
   111  	}
   112  
   113  	if err := checkValueType(values); err != nil {
   114  		return nil, nil, fmt.Errorf("check value error: %v", err)
   115  	}
   116  
   117  	_, _, column := getColumnInfoFromColumnName(n)
   118  
   119  	if isNotIn {
   120  		indexes := rule.GetSubTableIndexes()
   121  		valueMap := getBroadcastValueMap(indexes, values)
   122  		return indexes, valueMap, nil
   123  	}
   124  	if rule.GetShardingColumn() != column {
   125  		indexes := rule.GetSubTableIndexes()
   126  		valueMap := getBroadcastValueMap(indexes, values)
   127  		return indexes, valueMap, nil
   128  	}
   129  
   130  	var indexes []int
   131  	valueMap := make(map[int][]ast.ExprNode)
   132  	for _, vi := range values {
   133  		v, _ := vi.(*driver.ValueExpr)
   134  		value, err := util.GetValueExprResult(v)
   135  		if err != nil {
   136  			return nil, nil, err
   137  		}
   138  		idx, err := rule.FindTableIndex(value)
   139  		if err != nil {
   140  			return nil, nil, err
   141  		}
   142  		if _, ok := valueMap[idx]; !ok {
   143  			indexes = append(indexes, idx)
   144  		}
   145  		valueMap[idx] = append(valueMap[idx], vi)
   146  	}
   147  	sort.Ints(indexes)
   148  	return indexes, valueMap, nil
   149  }
   150  
   151  // 所有的值类型必须为*driver.ValueExpr
   152  func checkValueType(values []ast.ExprNode) error {
   153  	for i, v := range values {
   154  		if _, ok := v.(*driver.ValueExpr); !ok {
   155  			return fmt.Errorf("value is not ValueExpr, index: %d, type: %T", i, v)
   156  		}
   157  	}
   158  	return nil
   159  }
   160  
   161  func getBroadcastValueMap(subTableIndexes []int, nodes []ast.ExprNode) map[int][]ast.ExprNode {
   162  	ret := make(map[int][]ast.ExprNode)
   163  	for _, idx := range subTableIndexes {
   164  		ret[idx] = append(ret[idx], nodes...)
   165  	}
   166  	return ret
   167  }
   168  
// GetCurrentRouteResult get route result of current decorator.
// It returns the decorator's tableIndexes slice as stored (not a copy).
func (p *PatternInExprDecorator) GetCurrentRouteResult() []int {
	return p.tableIndexes
}
   173  
   174  // Restore implement ast.Node
   175  func (p *PatternInExprDecorator) Restore(ctx *format.RestoreCtx) error {
   176  	tableIndex, err := p.result.GetCurrentTableIndex()
   177  	if err != nil {
   178  		return err
   179  	}
   180  
   181  	if err := p.Expr.Restore(ctx); err != nil {
   182  		return fmt.Errorf("an error occurred while restore PatternInExpr.Expr: %v", err)
   183  	}
   184  	if p.Not {
   185  		ctx.WriteKeyWord(" NOT IN ")
   186  	} else {
   187  		ctx.WriteKeyWord(" IN ")
   188  	}
   189  
   190  	ctx.WritePlain("(")
   191  	for i, expr := range p.indexValueMap[tableIndex] {
   192  		if i != 0 {
   193  			ctx.WritePlain(",")
   194  		}
   195  		if err := expr.Restore(ctx); err != nil {
   196  			return fmt.Errorf("an error occurred while restore PatternInExpr.List[%d], err: %v", i, err)
   197  		}
   198  	}
   199  	ctx.WritePlain(")")
   200  
   201  	return nil
   202  }
   203  
   204  // Accept implement ast.Node
   205  func (p *PatternInExprDecorator) Accept(v ast.Visitor) (node ast.Node, ok bool) {
   206  	return p, ok
   207  }
   208  
// Text implement ast.Node.
// It always returns an empty string.
func (p *PatternInExprDecorator) Text() string {
	return ""
}
   213  
   214  // SetText implement ast.Node
   215  func (p *PatternInExprDecorator) SetText(text string) {
   216  	return
   217  }
   218  
   219  // SetType implement ast.ExprNode
   220  func (p *PatternInExprDecorator) SetType(tp *types.FieldType) {
   221  	return
   222  }
   223  
// GetType implement ast.ExprNode.
// It always returns nil.
func (p *PatternInExprDecorator) GetType() *types.FieldType {
	return nil
}
   228  
   229  // SetFlag implement ast.ExprNode
   230  func (p *PatternInExprDecorator) SetFlag(flag uint64) {
   231  	return
   232  }
   233  
// GetFlag implement ast.ExprNode.
// It always returns 0.
func (p *PatternInExprDecorator) GetFlag() uint64 {
	return 0
}
   238  
   239  // Format implement ast.ExprNode
   240  func (p *PatternInExprDecorator) Format(w io.Writer) {
   241  	return
   242  }