github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/plan_select.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/XiaoMi/Gaea/mysql"
    21  	"github.com/XiaoMi/Gaea/parser/ast"
    22  	"github.com/XiaoMi/Gaea/parser/opcode"
    23  	driver "github.com/XiaoMi/Gaea/parser/tidb-types/parser_driver"
    24  	"github.com/XiaoMi/Gaea/proxy/router"
    25  	"github.com/XiaoMi/Gaea/util"
    26  )
    27  
// SelectPlan is the plan for a SELECT statement.
//
// Building the plan (HandleSelectStmt) rewrites the statement AST in place and
// records what is needed to combine the per-shard result sets afterwards.
type SelectPlan struct {
	basePlan
	*TableAliasStmtInfo

	stmt *ast.SelectStmt // the statement being planned; a shared reference, rewritten in place

	distinct          bool   // whether this is SELECT DISTINCT
	groupByColumn     []int  // result-column index of each GROUP BY item
	orderByColumn     []int  // result-column index of each ORDER BY item
	orderByDirections []bool // direction of each ORDER BY item, true means DESC
	originColumnCount int    // number of select fields before extra GROUP BY / ORDER BY columns were appended
	columnCount       int    // number of select fields after extra columns were appended

	aggregateFuncs map[int]AggregateFuncMerger // key = column index

	offset int64 // LIMIT offset
	count  int64 // LIMIT count, -1 when not set

	// generated SQLs: slice name -> backend database name -> SQL list
	sqls map[string]map[string][]string
}
    49  
    50  // NewSelectPlan constructor of SelectPlan
    51  // db is the session db
    52  func NewSelectPlan(db string, sql string, r *router.Router) *SelectPlan {
    53  	return &SelectPlan{
    54  		TableAliasStmtInfo: NewTableAliasStmtInfo(db, sql, r),
    55  		aggregateFuncs:     make(map[int]AggregateFuncMerger),
    56  		offset:             -1,
    57  		count:              -1,
    58  	}
    59  }
    60  
    61  // ExecuteIn implement Plan
    62  func (s *SelectPlan) ExecuteIn(reqCtx *util.RequestContext, sess Executor) (*mysql.Result, error) {
    63  	sqls := s.GetSQLs()
    64  	if sqls == nil {
    65  		return nil, fmt.Errorf("SQL has not generated")
    66  	}
    67  
    68  	if len(sqls) == 0 {
    69  		r := newEmptyResultset(s, s.GetStmt())
    70  		ret := &mysql.Result{
    71  			Resultset: r,
    72  		}
    73  		return ret, nil
    74  	}
    75  
    76  	rs, err := sess.ExecuteSQLs(reqCtx, sqls)
    77  	if err != nil {
    78  		return nil, fmt.Errorf("execute in SelectPlan error: %v", err)
    79  	}
    80  
    81  	if s.isExecOnSingleNode() {
    82  		return rs[0], nil
    83  	} else {
    84  		r, err := MergeSelectResult(s, s.stmt, rs)
    85  		if err != nil {
    86  			return nil, fmt.Errorf("merge select result error: %v", err)
    87  		}
    88  		return r, nil
    89  	}
    90  }
    91  
    92  // GetStmt SelectStmt
    93  func (s *SelectPlan) GetStmt() *ast.SelectStmt {
    94  	return s.stmt
    95  }
    96  
    97  func (s *SelectPlan) setAggregateFuncMerger(idx int, merger AggregateFuncMerger) error {
    98  	if _, ok := s.aggregateFuncs[idx]; ok {
    99  		return fmt.Errorf("column %d already set", idx)
   100  	}
   101  	s.aggregateFuncs[idx] = merger
   102  	return nil
   103  }
   104  
   105  // HasLimit if the select statement has limit clause, return true
   106  func (s *SelectPlan) HasLimit() bool {
   107  	return s.count != -1
   108  }
   109  
   110  // GetLimitValue get offset, count in limit clause
   111  func (s *SelectPlan) GetLimitValue() (int64, int64) {
   112  	return s.offset, s.count
   113  }
   114  
   115  // HasGroupBy if the select statement has group by clause, return true
   116  func (s *SelectPlan) HasGroupBy() bool {
   117  	return len(s.groupByColumn) != 0
   118  }
   119  
   120  // GetOriginColumnCount get origin column count in statement,
   121  // since group by and order by may add extra columns to FieldList.
   122  func (s *SelectPlan) GetOriginColumnCount() int {
   123  	return s.originColumnCount
   124  }
   125  
   126  // GetColumnCount get column count with extra columns
   127  func (s *SelectPlan) GetColumnCount() int {
   128  	return s.columnCount
   129  }
   130  
   131  // GetGroupByColumnInfo get extra column offset and length for group by
   132  func (s *SelectPlan) GetGroupByColumnInfo() []int {
   133  	return s.groupByColumn
   134  }
   135  
   136  // HasOrderBy if select statement has order by clause, return true
   137  func (s *SelectPlan) HasOrderBy() bool {
   138  	return len(s.orderByDirections) != 0
   139  }
   140  
   141  // GetOrderByColumnInfo get extra column offset and length for order by
   142  func (s *SelectPlan) GetOrderByColumnInfo() ([]int, []bool) {
   143  	return s.orderByColumn, s.orderByDirections
   144  }
   145  
   146  // GetSQLs get generated SQLs
   147  // the first key is slice, the second key is backend database name, the value is sql list.
   148  func (s *SelectPlan) GetSQLs() map[string]map[string][]string {
   149  	return s.sqls
   150  }
   151  
   152  //执行计划是否仅仅涉及一个分片
   153  func (s *SelectPlan) isExecOnSingleNode() bool {
   154  	if len(s.result.indexes) == 1 {
   155  		return true
   156  	} else {
   157  		return false
   158  	}
   159  }
   160  
   161  // HandleSelectStmt build a SelectPlan
   162  // 处理SelectStmt语法树, 改写其中一些节点, 并获取路由信息和结果聚合函数
   163  func HandleSelectStmt(p *SelectPlan, stmt *ast.SelectStmt) error {
   164  	p.stmt = stmt // hold the reference of stmt
   165  
   166  	p.distinct = stmt.Distinct
   167  
   168  	if err := handleTableRefs(p, stmt); err != nil {
   169  		return fmt.Errorf("handle From error: %v", err)
   170  	}
   171  
   172  	// field list的处理必须在group by之前, 因为group by, order by会补列, 而这些补充的列是已经处理过的
   173  	if stmt.Fields != nil {
   174  		if err := handleFieldList(p, stmt); err != nil {
   175  			return fmt.Errorf("handle Fields error: %v", err)
   176  		}
   177  
   178  		// 记录补列前的Fields长度
   179  		p.originColumnCount = len(stmt.Fields.Fields)
   180  	}
   181  
   182  	if err := handleWhere(p, stmt); err != nil {
   183  		return fmt.Errorf("handle Where error: %v", err)
   184  	}
   185  
   186  	if err := postHandleHintDatabaseFunction(p); err != nil {
   187  		return fmt.Errorf("handle Hint error: %v", err)
   188  	}
   189  
   190  
   191  //如果是多节点执行,特殊处理groupby orderby having limit
   192  	if !p.isExecOnSingleNode() {
   193  		// group by的处理必须在table处理之后
   194  		if err := handleGroupBy(p, stmt); err != nil {
   195  			return fmt.Errorf("handle GroupBy error: %v", err)
   196  		}
   197  
   198  		// order by的处理必须在table处理之后
   199  		// 与group by补列的顺序没有要求, 只要保证处理返回结果去掉这些补充列时保持相反的顺序, 这里放在group by之后
   200  		if err := handleOrderBy(p, stmt); err != nil {
   201  			return fmt.Errorf("handle OrderBy error: %v", err)
   202  		}
   203  
   204  		handleExtraFieldList(p, stmt)
   205  
   206  		// 记录补列后的Fields长度, 后面的handler不会补列了
   207  		if stmt.Fields != nil {
   208  			p.columnCount = len(stmt.Fields.Fields)
   209  		}
   210  
   211  		
   212  
   213  		if err := handleHaving(p, stmt); err != nil {
   214  			return fmt.Errorf("handle Having error: %v", err)
   215  		}
   216  
   217  		if err := handleLimit(p, stmt); err != nil {
   218  			return fmt.Errorf("handle Limit error: %v", err)
   219  		}
   220  	}
   221  
   222  	if err := postHandleGlobalTableRouteResultInQuery(p.StmtInfo); err != nil {
   223  		return fmt.Errorf("post handle global table error: %v", err)
   224  	}
   225  
   226  	sqls, err := generateShardingSQLs(p.stmt, p.result, p.router)
   227  	if err != nil {
   228  		return fmt.Errorf("generate select SQL error: %v", err)
   229  	}
   230  
   231  	p.sqls = sqls
   232  
   233  	return nil
   234  }
   235  
   236  // 处理GroupBy, 把GroupBy的列补到FieldList中, 然后把GroupBy去掉
   237  func handleGroupBy(p *SelectPlan, stmt *ast.SelectStmt) error {
   238  	if stmt.GroupBy == nil {
   239  		return nil
   240  	}
   241  
   242  	groupByFields, err := createSelectFieldsFromByItems(p, stmt.GroupBy.Items)
   243  	if err != nil {
   244  		return fmt.Errorf("get group by fields error: %v", err)
   245  	}
   246  
   247  	for i := 0; i < len(groupByFields); i++ {
   248  		p.groupByColumn = append(p.groupByColumn, i+len(stmt.Fields.Fields))
   249  	}
   250  
   251  	// append group by fields
   252  	stmt.Fields.Fields = append(stmt.Fields.Fields, groupByFields...)
   253  
   254  	return nil
   255  }
   256  
   257  func handleOrderBy(p *SelectPlan, stmt *ast.SelectStmt) error {
   258  	if stmt.OrderBy == nil {
   259  		return nil
   260  	}
   261  
   262  	orderByFields, err := createSelectFieldsFromByItems(p, stmt.OrderBy.Items)
   263  	if err != nil {
   264  		return fmt.Errorf("get order by fields error: %v", err)
   265  	}
   266  
   267  	for i := 0; i < len(orderByFields); i++ {
   268  		p.orderByColumn = append(p.orderByColumn, i+len(stmt.Fields.Fields))
   269  	}
   270  
   271  	for _, f := range stmt.OrderBy.Items {
   272  		p.orderByDirections = append(p.orderByDirections, f.Desc)
   273  	}
   274  
   275  	stmt.Fields.Fields = append(stmt.Fields.Fields, orderByFields...)
   276  	return nil
   277  }
   278  
   279  func handleExtraFieldList(p *SelectPlan, stmt *ast.SelectStmt) {
   280  	selectFields := make(map[string]int)
   281  	for i := 0; i < p.originColumnCount; i++ {
   282  		field := stmt.Fields.Fields[i]
   283  		if field.AsName.L != "" {
   284  			selectFields[field.AsName.L] = i
   285  		}
   286  		if field, isColumnExpr := stmt.Fields.Fields[i].Expr.(*ast.ColumnNameExpr); isColumnExpr {
   287  			selectFields[field.Name.Name.L] = i
   288  		}
   289  	}
   290  
   291  	deleteNum := 0
   292  	for i := 0; i < len(p.groupByColumn); i++ {
   293  		p.groupByColumn[i] -= deleteNum
   294  		currColumnIndex := p.originColumnCount + i - deleteNum
   295  		field, isColumnExpr := stmt.Fields.Fields[currColumnIndex].Expr.(*ast.ColumnNameExpr)
   296  		if !isColumnExpr {
   297  			continue
   298  		}
   299  		if index, ok := selectFields[field.Name.Name.L]; !ok {
   300  			continue
   301  		} else {
   302  			stmt.Fields.Fields = append(stmt.Fields.Fields[:currColumnIndex], stmt.Fields.Fields[currColumnIndex+1:]...)
   303  			p.groupByColumn[i] = index
   304  			deleteNum++
   305  		}
   306  	}
   307  
   308  	for i := 0; i < len(p.orderByColumn); i++ {
   309  		p.orderByColumn[i] -= deleteNum
   310  		currColumnIndex := p.originColumnCount + len(p.groupByColumn) + i - deleteNum
   311  		field, isColumnExpr := stmt.Fields.Fields[currColumnIndex].Expr.(*ast.ColumnNameExpr)
   312  		if !isColumnExpr {
   313  			continue
   314  		}
   315  		if index, ok := selectFields[field.Name.Name.L]; !ok {
   316  			continue
   317  		} else {
   318  			stmt.Fields.Fields = append(stmt.Fields.Fields[:currColumnIndex], stmt.Fields.Fields[currColumnIndex+1:]...)
   319  			p.orderByColumn[i] = index
   320  			deleteNum++
   321  		}
   322  	}
   323  }
   324  
   325  func createSelectFieldsFromByItems(p *SelectPlan, items []*ast.ByItem) ([]*ast.SelectField, error) {
   326  	var ret []*ast.SelectField
   327  	for _, item := range items {
   328  		selectField, err := createSelectFieldFromByItem(p, item)
   329  		if err != nil {
   330  			return nil, err
   331  		}
   332  		ret = append(ret, selectField)
   333  	}
   334  	return ret, nil
   335  }
   336  
   337  func createSelectFieldFromByItem(p *SelectPlan, item *ast.ByItem) (*ast.SelectField, error) {
   338  	// 特殊处理DATABASE()这种情况
   339  	if funcExpr, ok := item.Expr.(*ast.FuncCallExpr); ok {
   340  		if funcExpr.FnName.L == "database" {
   341  			ret := &ast.SelectField{
   342  				Expr: item.Expr,
   343  			}
   344  			return ret, nil
   345  		}
   346  		return nil, fmt.Errorf("ByItem.Expr is a FuncCallExpr but not DATABASE()")
   347  	}
   348  
   349  	columnExpr, ok := item.Expr.(*ast.ColumnNameExpr)
   350  	if !ok {
   351  		return nil, fmt.Errorf("ByItem.Expr is not a ColumnNameExpr")
   352  	}
   353  
   354  	rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInField(p.TableAliasStmtInfo, columnExpr)
   355  	if err != nil {
   356  		return nil, err
   357  	}
   358  
   359  	if need {
   360  		decorator := CreateColumnNameExprDecorator(columnExpr, rule, isAlias, p.GetRouteResult())
   361  		item.Expr = decorator
   362  	}
   363  
   364  	ret := &ast.SelectField{
   365  		Expr: item.Expr,
   366  	}
   367  	return ret, nil
   368  }
   369  
   370  // 处理from table和join on部分
   371  // 主要是改写table ExprNode, 并找到路由条件
   372  func handleTableRefs(p *SelectPlan, stmt *ast.SelectStmt) error {
   373  	tableRefs := stmt.From
   374  	if tableRefs == nil {
   375  		return nil
   376  	}
   377  
   378  	join := tableRefs.TableRefs
   379  	if join == nil {
   380  		return nil
   381  	}
   382  
   383  	return handleJoin(p.TableAliasStmtInfo, join)
   384  }
   385  
   386  func handleJoin(p *TableAliasStmtInfo, join *ast.Join) error {
   387  	if err := precheckJoinClause(join); err != nil {
   388  		return fmt.Errorf("precheck Join error: %v", err)
   389  	}
   390  
   391  	// 只允许最多两个表的JOIN
   392  	if join.Left != nil {
   393  		switch left := join.Left.(type) {
   394  		case *ast.TableSource:
   395  			// 改写两个表的node
   396  			if err := rewriteTableSource(p, left); err != nil {
   397  				return fmt.Errorf("rewrite left TableSource error: %v", err)
   398  			}
   399  		case *ast.Join:
   400  			if err := handleJoin(p, left); err != nil {
   401  				return fmt.Errorf("handle nested left Join error: %v", err)
   402  			}
   403  		default:
   404  			return fmt.Errorf("invalid left Join type: %T", join.Left)
   405  		}
   406  	}
   407  	if join.Right != nil {
   408  		right, ok := join.Right.(*ast.TableSource)
   409  		if !ok {
   410  			return fmt.Errorf("right is not TableSource, type: %T", join.Right)
   411  		}
   412  
   413  		if err := rewriteTableSource(p, right); err != nil {
   414  			return fmt.Errorf("rewrite right TableSource error: %v", err)
   415  		}
   416  	}
   417  
   418  	// 改写ON条件
   419  	if join.On != nil {
   420  		err := rewriteOnCondition(p, join.On)
   421  		if err != nil {
   422  			return fmt.Errorf("rewrite on condition error: %v", err)
   423  		}
   424  	}
   425  
   426  	return nil
   427  }
   428  
   429  func handleWhere(p *SelectPlan, stmt *ast.SelectStmt) (err error) {
   430  	if stmt.Where == nil {
   431  		return nil
   432  	}
   433  
   434  	has, result, decorator, err := handleComparisonExpr(p.TableAliasStmtInfo, stmt.Where)
   435  	if err != nil {
   436  		return fmt.Errorf("rewrite Where error: %v", err)
   437  	}
   438  	if has {
   439  		p.GetRouteResult().Inter(result)
   440  	}
   441  	stmt.Where = decorator
   442  	return nil
   443  }
   444  
   445  // 检查TableRefs中存在的不允许在分表中执行的语法
   446  func precheckJoinClause(join *ast.Join) error {
   447  	// 不允许USING的列名中出现DB名和表名, 因为目前Join子句的TableName不方便加装饰器
   448  	for _, c := range join.Using {
   449  		if c.Schema.String() != "" {
   450  			return fmt.Errorf("JOIN does not support USING column with schema")
   451  		}
   452  		if c.Table.String() != "" {
   453  			return fmt.Errorf("JOIN does not support USING column with table")
   454  		}
   455  	}
   456  	return nil
   457  }
   458  
   459  // 改写TableSource节点, 得到一个装饰器
   460  // Source必须为TableName节点或子查询
   461  func rewriteTableSource(p *TableAliasStmtInfo, tableSource *ast.TableSource) error {
   462  	switch ss := tableSource.Source.(type) {
   463  	case *ast.TableName:
   464  		return rewriteTableNameInTableSource(p, tableSource)
   465  	case *ast.SelectStmt:
   466  		if err := handleSubquerySelectStmt(p, ss); err != nil {
   467  			return fmt.Errorf("handleSubquerySelectStmt error: %v", err)
   468  		}
   469  		alias := tableSource.AsName.L
   470  		if alias != "" {
   471  			if _, err := p.RecordSubqueryTableAlias(alias); err != nil {
   472  				return fmt.Errorf("record subquery alias error: %v", err)
   473  			}
   474  		}
   475  		return nil
   476  	default:
   477  		return fmt.Errorf("field Source cannot handle, type: %T", tableSource.Source)
   478  	}
   479  }
   480  
   481  func rewriteTableNameInTableSource(p *TableAliasStmtInfo, tableSource *ast.TableSource) error {
   482  	tableName, ok := tableSource.Source.(*ast.TableName)
   483  	if !ok {
   484  		return fmt.Errorf("field Source is not type of TableName, type: %T", tableSource.Source)
   485  	}
   486  	alias := tableSource.AsName.L
   487  
   488  	rule, need, err := NeedCreateTableNameDecorator(p, tableName, alias)
   489  	if err != nil {
   490  		return fmt.Errorf("check NeedCreateTableNameDecorator error: %v", err)
   491  	}
   492  
   493  	if !need {
   494  		return nil
   495  	}
   496  
   497  	// 这是一个分片表或关联表, 创建一个TableName的装饰器, 并替换原有节点
   498  	d, err := CreateTableNameDecorator(tableName, rule, p.GetRouteResult())
   499  	if err != nil {
   500  		return fmt.Errorf("create TableNameDecorator error: %v", err)
   501  	}
   502  	tableSource.Source = d
   503  	return nil
   504  }
   505  
   506  func rewriteOnCondition(p *TableAliasStmtInfo, on *ast.OnCondition) error {
   507  	has, result, decorator, err := handleComparisonExpr(p, on.Expr)
   508  	if err != nil {
   509  		return fmt.Errorf("rewrite Expr in OnCondition error: %v", err)
   510  	}
   511  	if has {
   512  		p.GetRouteResult().Inter(result)
   513  	}
   514  	on.Expr = decorator
   515  	return nil
   516  }
   517  
   518  // 处理info中的hint
   519  // 目前只有mycat路由方式支持
   520  // hint路由会覆盖遍历语法树时计算出的路由
   521  func postHandleHintDatabaseFunction(p *SelectPlan) error {
   522  	if p.hintPhyDB == "" {
   523  		return nil
   524  	}
   525  
   526  	rule, ok := p.router.GetShardRule(p.result.db, p.result.table)
   527  	if !ok {
   528  		return fmt.Errorf("sharding rule of route result not found, result: %v", p.result)
   529  	}
   530  	mr, ok := rule.(router.MycatRule)
   531  	if !ok {
   532  		return fmt.Errorf("sharding rule is not mycat mode, result: %v", p.result)
   533  	}
   534  
   535  	if !router.IsMycatShardingRule(mr.GetType()) { // TODO: need refactor, why is MycatRule's type not mycat rule?
   536  		return fmt.Errorf("only mycat rule supports database function hint")
   537  	}
   538  
   539  	idx, ok := mr.GetTableIndexByDatabaseName(p.hintPhyDB)
   540  	if !ok {
   541  		return fmt.Errorf("hint db not found: %s", p.hintPhyDB)
   542  	}
   543  
   544  	p.result.indexes = []int{idx}
   545  	return nil
   546  }
   547  
   548  // ColumnNameRewriteVisitor visit ColumnNameExpr, check if need decorate, and then decorate it.
   549  type ColumnNameRewriteVisitor struct {
   550  	info *TableAliasStmtInfo
   551  }
   552  
   553  // NewColumnNameRewriteVisitor constructor of ColumnNameRewriteVisitor
   554  func NewColumnNameRewriteVisitor(p *TableAliasStmtInfo) *ColumnNameRewriteVisitor {
   555  	return &ColumnNameRewriteVisitor{
   556  		info: p,
   557  	}
   558  }
   559  
   560  // Enter implement ast.Visitor
   561  func (s *ColumnNameRewriteVisitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
   562  	return n, false
   563  }
   564  
   565  // Leave implement ast.Visitor
   566  func (s *ColumnNameRewriteVisitor) Leave(n ast.Node) (node ast.Node, ok bool) {
   567  	field, ok := n.(*ast.ColumnNameExpr)
   568  	if !ok {
   569  		return n, true
   570  	}
   571  	rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInField(s.info, field)
   572  	if err != nil {
   573  		panic(fmt.Errorf("check NeedCreateColumnNameExprDecoratorInField in ColumnNameExpr error: %v", err))
   574  	}
   575  	if need {
   576  		decorator := CreateColumnNameExprDecorator(field, rule, isAlias, s.info.GetRouteResult())
   577  		return decorator, true
   578  	}
   579  
   580  	return n, true
   581  }
   582  
   583  func handleFieldList(p *SelectPlan, stmt *ast.SelectStmt) (err error) {
   584  	defer func() {
   585  		if e := recover(); e != nil {
   586  			err = fmt.Errorf("handleFieldList panic: %v", e)
   587  		}
   588  	}()
   589  
   590  	fields := stmt.Fields
   591  	if fields == nil {
   592  		return nil
   593  	}
   594  
   595  	// 先用一个Visitor生成一个替换表名的装饰器
   596  	// 这里如果出错, 只能通过panic返回err
   597  	columnNameRewriter := NewColumnNameRewriteVisitor(p.TableAliasStmtInfo)
   598  	fields.Accept(columnNameRewriter)
   599  
   600  	// 如果最外层是聚合函数, 则生成一个聚合函数装饰器, 并记录对应的列位置
   601  	// 只处理最外层的聚合函数.
   602  	for i, f := range fields.Fields {
   603  		switch field := f.Expr.(type) {
   604  		case *ast.AggregateFuncExpr:
   605  			merger, err := CreateAggregateFunctionMerger(field.F, i)
   606  			if err != nil {
   607  				return fmt.Errorf("create aggregate function merger error, column index: %d, err: %v", i, err)
   608  			}
   609  			if err := p.setAggregateFuncMerger(i, merger); err != nil {
   610  				return fmt.Errorf("set aggregate function merger error, column index: %d, err: %v", i, err)
   611  			}
   612  		default:
   613  			// do nothing
   614  		}
   615  	}
   616  	return nil
   617  }
   618  
   619  func handleHaving(p *SelectPlan, stmt *ast.SelectStmt) (err error) {
   620  	defer func() {
   621  		if e := recover(); e != nil {
   622  			err = fmt.Errorf("handleHaving panic: %v", e)
   623  		}
   624  	}()
   625  
   626  	having := stmt.Having
   627  	if having == nil {
   628  		return nil
   629  	}
   630  
   631  	// 先用一个Visitor生成一个替换表名的装饰器
   632  	// 这里如果出错, 只能通过panic返回err
   633  	columnNameRewriter := NewColumnNameRewriteVisitor(p.TableAliasStmtInfo)
   634  	having.Accept(columnNameRewriter)
   635  	return nil
   636  }
   637  
   638  func handleComparisonExpr(p *TableAliasStmtInfo, comp ast.ExprNode) (bool, []int, ast.ExprNode, error) {
   639  	switch expr := comp.(type) {
   640  	case *ast.BinaryOperationExpr:
   641  		return handleBinaryOperationExpr(p, expr)
   642  	case *ast.PatternInExpr:
   643  		return handlePatternInExpr(p, expr)
   644  	case *ast.BetweenExpr:
   645  		return handleBetweenExpr(p, expr)
   646  	case *ast.ParenthesesExpr:
   647  		has, routeResult, newExpr, err := handleComparisonExpr(p, expr.Expr)
   648  		expr.Expr = newExpr
   649  		return has, routeResult, expr, err
   650  	default:
   651  		// 其他情况只替换表名 (但是不处理根节点是ColumnNameExpr的情况, 理论上也不会出现这种情况)
   652  		columnNameRewriter := NewColumnNameRewriteVisitor(p)
   653  		expr.Accept(columnNameRewriter)
   654  		return false, p.GetRouteResult().GetShardIndexes(), comp, nil
   655  	}
   656  }
   657  
   658  func handlePatternInExpr(p *TableAliasStmtInfo, expr *ast.PatternInExpr) (bool, []int, ast.ExprNode, error) {
   659  	rule, need, isAlias, err := NeedCreatePatternInExprDecorator(p, expr)
   660  	if err != nil {
   661  		return false, nil, nil, fmt.Errorf("check PatternInExpr error: %v", err)
   662  	}
   663  	if !need {
   664  		return false, nil, expr, nil
   665  	}
   666  	decorator, err := CreatePatternInExprDecorator(expr, rule, isAlias, p.GetRouteResult())
   667  	if err != nil {
   668  		return false, nil, nil, fmt.Errorf("create PatternInExprDecorator error: %v", err)
   669  	}
   670  	return true, decorator.GetCurrentRouteResult(), decorator, nil
   671  }
   672  
   673  func handleBetweenExpr(p *TableAliasStmtInfo, expr *ast.BetweenExpr) (bool, []int, ast.ExprNode, error) {
   674  	rule, need, isAlias, err := NeedCreateBetweenExprDecorator(p, expr)
   675  	if err != nil {
   676  		return false, nil, nil, fmt.Errorf("check BetweenExpr error: %v", err)
   677  	}
   678  	if !need {
   679  		return false, nil, expr, nil
   680  	}
   681  
   682  	decorator, err := CreateBetweenExprDecorator(expr, rule, isAlias, p.GetRouteResult())
   683  	if err != nil {
   684  		return false, nil, nil, fmt.Errorf("create CreateBetweenExprDecorator error: %v", err)
   685  	}
   686  
   687  	return true, decorator.GetCurrentRouteResult(), decorator, nil
   688  }
   689  
   690  // return value: hasRoutingResult, RouteResult, Decorator, error
   691  // the Decorator must not be nil. If no modification to the input expr, just return it.
   692  func handleBinaryOperationExpr(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr) (bool, []int, ast.ExprNode, error) {
   693  	_, ok := opcode.Ops[expr.Op]
   694  	if !ok {
   695  		return false, nil, nil, fmt.Errorf("unknown BinaryOperationExpr.Op: %v", expr.Op)
   696  	}
   697  
   698  	switch expr.Op {
   699  	case opcode.LogicAnd, opcode.LogicOr:
   700  		return handleBinaryOperationExprLogic(p, expr)
   701  	case opcode.EQ, opcode.NE, opcode.GT, opcode.GE, opcode.LT, opcode.LE:
   702  		return handleBinaryOperationExprMathCompare(p, expr)
   703  	default:
   704  		return handleBinaryOperationExprOther(p, expr)
   705  	}
   706  }
   707  
   708  // 处理逻辑比较运算
   709  func handleBinaryOperationExprLogic(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr) (bool, []int, ast.ExprNode, error) {
   710  	lHas, lResult, lDecorator, lErr := handleComparisonExpr(p, expr.L)
   711  	if lErr != nil {
   712  		return false, nil, nil, fmt.Errorf("handle BinaryOperationExpr.L error: %v", lErr)
   713  	}
   714  	rHas, rResult, rDecorator, rErr := handleComparisonExpr(p, expr.R)
   715  	if rErr != nil {
   716  		return false, nil, nil, fmt.Errorf("handle BinaryOperationExpr.R error: %v", rErr)
   717  	}
   718  
   719  	if lDecorator != nil {
   720  		expr.L = lDecorator
   721  	}
   722  	if rDecorator != nil {
   723  		expr.R = rDecorator
   724  	}
   725  
   726  	has, result := mergeBinaryOperationRouteResult(expr.Op, lHas, lResult, rHas, rResult)
   727  	return has, result, expr, nil
   728  }
   729  
   730  // 处理算术比较运算
   731  // 如果出现列名, 则必须为列名与列名比较, 列名与值比较, 否则会报错 (比如 id + 2 = 3 就会报错, 因为 id + 2 处理不了)
   732  // 如果是其他情况, 则直接返回 (如 1 = 1 这种)
   733  func handleBinaryOperationExprMathCompare(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr) (bool, []int, ast.ExprNode, error) {
   734  	lType := getExprNodeTypeInBinaryOperation(expr.L)
   735  	rType := getExprNodeTypeInBinaryOperation(expr.R)
   736  
   737  	// handle hint database function: SELECT * from tbl where DATABASE() = db_0 / 'db_0' / `db_0`
   738  	if expr.Op == opcode.EQ {
   739  		if lType == FuncCallExpr {
   740  			hintDB, err := getDatabaseFuncHint(expr.L.(*ast.FuncCallExpr), expr.R)
   741  			if err != nil {
   742  				return false, nil, nil, fmt.Errorf("get database function hint error: %v", err)
   743  			}
   744  			if hintDB != "" {
   745  				p.hintPhyDB = hintDB
   746  				return false, nil, expr, nil
   747  			}
   748  		} else if rType == FuncCallExpr {
   749  			hintDB, err := getDatabaseFuncHint(expr.R.(*ast.FuncCallExpr), expr.L)
   750  			if err != nil {
   751  				return false, nil, nil, fmt.Errorf("get database function hint error: %v", err)
   752  			}
   753  			if hintDB != "" {
   754  				p.hintPhyDB = hintDB
   755  				return false, nil, expr, nil
   756  			}
   757  		}
   758  	}
   759  
   760  	if lType == ColumnNameExpr && rType == ColumnNameExpr {
   761  		return handleBinaryOperationExprCompareLeftColumnRightColumn(p, expr)
   762  	}
   763  
   764  	if lType == ColumnNameExpr {
   765  		if rType == ValueExpr {
   766  			return handleBinaryOperationExprCompareLeftColumnRightValue(p, expr, getFindTableIndexesFunc(expr.Op))
   767  		}
   768  		column := expr.L.(*ast.ColumnNameExpr)
   769  		rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInCondition(p, column)
   770  		if err != nil {
   771  			return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.L: %v", err)
   772  		}
   773  		if !need {
   774  			return false, nil, expr, nil
   775  		}
   776  
   777  		decorator := CreateColumnNameExprDecorator(column, rule, isAlias, p.GetRouteResult())
   778  		expr.L = decorator
   779  		return false, nil, expr, nil
   780  	}
   781  
   782  	if rType == ColumnNameExpr {
   783  		if lType == ValueExpr {
   784  			return handleBinaryOperationExprCompareLeftValueRightColumn(p, expr, getFindTableIndexesFunc(inverseOperator(expr.Op)))
   785  		}
   786  		column := expr.R.(*ast.ColumnNameExpr)
   787  		rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInCondition(p, column)
   788  		if err != nil {
   789  			return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.R: %v", err)
   790  		}
   791  		if !need {
   792  			return false, nil, expr, nil
   793  		}
   794  
   795  		decorator := CreateColumnNameExprDecorator(column, rule, isAlias, p.GetRouteResult())
   796  		expr.R = decorator
   797  		return false, nil, expr, nil
   798  	}
   799  
   800  	return false, nil, expr, nil
   801  }
   802  
   803  // 处理其他情况的运算
   804  // 如果出现分表列, 只创建一个替换表名的装饰器, 不计算路由. 因此返回结果前两个一定是false, nil
   805  func handleBinaryOperationExprOther(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr) (bool, []int, ast.ExprNode, error) {
   806  	if lColumn, ok := expr.L.(*ast.ColumnNameExpr); ok {
   807  		lRule, lNeed, lIsAlias, lErr := NeedCreateColumnNameExprDecoratorInCondition(p, lColumn)
   808  		if lErr != nil {
   809  			return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.L: %v", lErr)
   810  		}
   811  
   812  		if lNeed {
   813  			lDecorator := CreateColumnNameExprDecorator(lColumn, lRule, lIsAlias, p.GetRouteResult())
   814  			expr.L = lDecorator
   815  		}
   816  	}
   817  	if rColumn, ok := expr.R.(*ast.ColumnNameExpr); ok {
   818  		rRule, rNeed, rIsAlias, rErr := NeedCreateColumnNameExprDecoratorInCondition(p, rColumn)
   819  		if rErr != nil {
   820  			return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.R: %v", rErr)
   821  		}
   822  		if rNeed {
   823  			rDecorator := CreateColumnNameExprDecorator(rColumn, rRule, rIsAlias, p.GetRouteResult())
   824  			expr.R = rDecorator
   825  		}
   826  	}
   827  	return false, nil, expr, nil
   828  }
   829  
   830  // 获取mycat路由模式下的hint物理DB名
   831  func getDatabaseFuncHint(f *ast.FuncCallExpr, v ast.ExprNode) (string, error) {
   832  	if f.FnName.L != "database" {
   833  		return "", nil
   834  	}
   835  	switch vv := v.(type) {
   836  	case *ast.ColumnNameExpr:
   837  		return vv.Name.Name.String(), nil
   838  	case *driver.ValueExpr:
   839  		return vv.GetString(), nil
   840  	default:
   841  		return "", fmt.Errorf("invalid value type of database function hint: %T", v)
   842  	}
   843  }
   844  
   845  // 返回一个根据路由信息和路由值获取路由结果的函数
   846  // 左边为列名, 右边为参数
   847  func getFindTableIndexesFunc(op opcode.Op) func(rule router.Rule, columnName string, v interface{}) ([]int, error) {
   848  	findTableIndexesFunc := func(rule router.Rule, columnName string, v interface{}) ([]int, error) {
   849  		// 如果不是分表列, 则需要返回所有分片
   850  		if rule.GetShardingColumn() != columnName {
   851  			return rule.GetSubTableIndexes(), nil
   852  		}
   853  
   854  		// 如果是分表列, 还需要根据运算符判断
   855  		switch op {
   856  		case opcode.EQ:
   857  			index, err := rule.FindTableIndex(v)
   858  			if err != nil {
   859  				return nil, err
   860  			}
   861  			return []int{index}, nil
   862  		case opcode.NE:
   863  			return rule.GetSubTableIndexes(), nil
   864  		case opcode.GT, opcode.GE, opcode.LT, opcode.LE:
   865  			// 如果是range路由, 需要做一些特殊处理
   866  			if rangeShard, ok := rule.GetShard().(router.RangeShard); ok {
   867  				index, err := rule.FindTableIndex(v)
   868  				if err != nil {
   869  					return nil, err
   870  				}
   871  				if op == opcode.LT || op == opcode.LE {
   872  					if op == opcode.LT {
   873  						index = adjustShardIndex(rangeShard, v, index)
   874  					}
   875  					return makeList(rule.GetFirstTableIndex(), index+1), nil
   876  				} else {
   877  					return makeList(index, rule.GetLastTableIndex()+1), nil
   878  				}
   879  			}
   880  
   881  			// 如果不是 (即hash路由), 则返回所有分片
   882  			return rule.GetSubTableIndexes(), nil
   883  		default: // should not going here
   884  			return rule.GetSubTableIndexes(), nil
   885  		}
   886  	}
   887  
   888  	return findTableIndexesFunc
   889  }
   890  
   891  // copy from PlanBuilder.adjustShardIndex()
   892  func adjustShardIndex(s router.RangeShard, value interface{}, index int) int {
   893  	if s.EqualStart(value, index) {
   894  		return index - 1
   895  	}
   896  	return index
   897  }
   898  
   899  func inverseOperator(op opcode.Op) opcode.Op {
   900  	switch op {
   901  	case opcode.GT:
   902  		return opcode.LT
   903  	case opcode.GE:
   904  		return opcode.LE
   905  	case opcode.LT:
   906  		return opcode.GT
   907  	case opcode.LE:
   908  		return opcode.GE
   909  	default:
   910  		return op
   911  	}
   912  }
   913  
   914  func handleBinaryOperationExprCompareLeftColumnRightColumn(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr) (bool, []int, ast.ExprNode, error) {
   915  	lColumn := expr.L.(*ast.ColumnNameExpr)
   916  	lRule, lNeed, lIsAlias, lErr := NeedCreateColumnNameExprDecoratorInCondition(p, lColumn)
   917  	if lErr != nil {
   918  		return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.L: %v", lErr)
   919  	}
   920  	rColumn := expr.R.(*ast.ColumnNameExpr)
   921  	rRule, rNeed, rIsAlias, rErr := NeedCreateColumnNameExprDecoratorInCondition(p, rColumn)
   922  	if rErr != nil {
   923  		return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.R: %v", rErr)
   924  	}
   925  
   926  	if lNeed {
   927  		lDecorator := CreateColumnNameExprDecorator(lColumn, lRule, lIsAlias, p.GetRouteResult())
   928  		expr.L = lDecorator
   929  	}
   930  	if rNeed {
   931  		rDecorator := CreateColumnNameExprDecorator(rColumn, rRule, rIsAlias, p.GetRouteResult())
   932  		expr.R = rDecorator
   933  	}
   934  	return false, nil, expr, nil
   935  }
   936  
   937  func handleBinaryOperationExprCompareLeftColumnRightValue(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr, findTableIndexes func(router.Rule, string, interface{}) ([]int, error)) (bool, []int, ast.ExprNode, error) {
   938  	column := expr.L.(*ast.ColumnNameExpr)
   939  	rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInCondition(p, column)
   940  	if err != nil {
   941  		return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.L: %v", err)
   942  	}
   943  	if !need {
   944  		return false, nil, expr, nil
   945  	}
   946  
   947  	decorator := CreateColumnNameExprDecorator(column, rule, isAlias, p.GetRouteResult())
   948  	expr.L = decorator
   949  
   950  	if rule.GetType() == router.GlobalTableRuleType {
   951  		return false, nil, expr, nil
   952  	}
   953  
   954  	valueExpr := expr.R.(*driver.ValueExpr)
   955  	v, err := util.GetValueExprResult(valueExpr)
   956  	if err != nil {
   957  		return false, nil, nil, fmt.Errorf("get ValueExpr value error: %v", err)
   958  	}
   959  
   960  	tableIndexes, err := findTableIndexes(rule, column.Name.Name.L, v)
   961  	if err != nil {
   962  		return false, nil, nil, fmt.Errorf("find table index error: %v", err)
   963  	}
   964  
   965  	return true, tableIndexes, expr, nil
   966  }
   967  
   968  func handleBinaryOperationExprCompareLeftValueRightColumn(p *TableAliasStmtInfo, expr *ast.BinaryOperationExpr, findTableIndexes func(router.Rule, string, interface{}) ([]int, error)) (bool, []int, ast.ExprNode, error) {
   969  	column := expr.R.(*ast.ColumnNameExpr)
   970  	rule, need, isAlias, err := NeedCreateColumnNameExprDecoratorInCondition(p, column)
   971  	if err != nil {
   972  		return false, nil, nil, fmt.Errorf("check ColumnNameExpr error in BinaryOperationExpr.R: %v", err)
   973  	}
   974  	if !need {
   975  		return false, nil, expr, nil
   976  	}
   977  
   978  	decorator := CreateColumnNameExprDecorator(column, rule, isAlias, p.GetRouteResult())
   979  	expr.R = decorator
   980  
   981  	if rule.GetType() == router.GlobalTableRuleType {
   982  		return false, nil, expr, nil
   983  	}
   984  
   985  	valueExpr := expr.L.(*driver.ValueExpr)
   986  	v, err := util.GetValueExprResult(valueExpr)
   987  	if err != nil {
   988  		return false, nil, nil, fmt.Errorf("get ValueExpr value error: %v", err)
   989  	}
   990  
   991  	tableIndexes, err := findTableIndexes(rule, column.Name.Name.L, v)
   992  	if err != nil {
   993  		return false, nil, nil, fmt.Errorf("find table index error: %v", err)
   994  	}
   995  
   996  	return true, tableIndexes, expr, nil
   997  }
   998  
   999  func mergeBinaryOperationRouteResult(op opcode.Op, lHas bool, lResult []int, rHas bool, rResult []int) (bool, []int) {
  1000  	switch op {
  1001  	case opcode.LogicAnd:
  1002  		if lHas == false && rHas == false {
  1003  			return false, nil
  1004  		}
  1005  		if lHas && rHas {
  1006  			return true, interList(lResult, rResult)
  1007  		}
  1008  		if lHas {
  1009  			return true, lResult
  1010  		}
  1011  		if rHas {
  1012  			return true, rResult
  1013  		}
  1014  	case opcode.LogicOr:
  1015  		if lHas && rHas {
  1016  			return true, unionList(lResult, rResult)
  1017  		}
  1018  		return false, nil
  1019  	}
  1020  	return false, nil
  1021  }
  1022  
  1023  func handleLimit(p *SelectPlan, stmt *ast.SelectStmt) error {
  1024  	need, originOffset, originCount, newLimit := NeedRewriteLimitOrCreateRewrite(stmt)
  1025  	p.offset = originOffset
  1026  	p.count = originCount
  1027  	if need {
  1028  		stmt.Limit = newLimit
  1029  	}
  1030  	return nil
  1031  }
  1032  
  1033  func getTableInfoFromTableName(t *ast.TableName) (string, string) {
  1034  	return t.Schema.O, t.Name.L
  1035  }
  1036  
  1037  func getColumnInfoFromColumnName(t *ast.ColumnName) (string, string, string) {
  1038  	return t.Schema.O, t.Table.L, t.Name.L
  1039  }