github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/plan_select_subquery.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/XiaoMi/Gaea/parser/ast"
    21  	"github.com/XiaoMi/Gaea/proxy/router"
    22  )
    23  
    24  // SubqueryColumnNameRewriteVisitor visit ColumnNameExpr in subquery, check if need decorate, and then decorate it.
    25  type SubqueryColumnNameRewriteVisitor struct {
    26  	info *TableAliasStmtInfo
    27  }
    28  
    29  // NewSubqueryColumnNameRewriteVisitor consturctor of SubqueryColumnNameRewriteVisitor
    30  func NewSubqueryColumnNameRewriteVisitor(p *TableAliasStmtInfo) *SubqueryColumnNameRewriteVisitor {
    31  	return &SubqueryColumnNameRewriteVisitor{
    32  		info: p,
    33  	}
    34  }
    35  
    36  // Enter implement ast.Visitor
    37  func (s *SubqueryColumnNameRewriteVisitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
    38  	return n, false
    39  }
    40  
    41  // Leave implement ast.Visitor
    42  func (s *SubqueryColumnNameRewriteVisitor) Leave(n ast.Node) (node ast.Node, ok bool) {
    43  	field, ok := n.(*ast.ColumnNameExpr)
    44  	if !ok {
    45  		return n, true
    46  	}
    47  
    48  	db, table, _ := getColumnInfoFromColumnName(field.Name)
    49  
    50  	rule, _ := s.info.getShardRule(db, table)
    51  	if rule == nil || rule.GetType() == router.GlobalTableRuleType {
    52  		return n, true
    53  	}
    54   
    55  	decorator := CreateColumnNameExprDecorator(field, rule, false, s.info.GetRouteResult())
    56  	return decorator, true
    57  }
    58  
    59  func handleSubquerySelectStmt(p *TableAliasStmtInfo, subquery *ast.SelectStmt) (err error) {
    60  	defer func() {
    61  		if v := recover(); v != nil {
    62  			err = fmt.Errorf("handleSubqueryExpr panic: %v", v)
    63  		}
    64  	}()
    65  
    66  	if err = handleSubqueryTableRefs(p, subquery); err != nil {
    67  		return fmt.Errorf("handle From error: %v", err)
    68  	}
    69  
    70  	// 对所有可能含有ColumnName的Node做装饰.
    71  	columnRewritter := NewSubqueryColumnNameRewriteVisitor(p)
    72  	if subquery.Where != nil {
    73  		subquery.Where.Accept(columnRewritter)
    74  	}
    75  	if subquery.Fields != nil {
    76  		subquery.Fields.Accept(columnRewritter)
    77  	}
    78  	if subquery.GroupBy != nil {
    79  		subquery.GroupBy.Accept(columnRewritter)
    80  	}
    81  	if subquery.Having != nil {
    82  		subquery.Having.Accept(columnRewritter)
    83  	}
    84  	if subquery.OrderBy != nil {
    85  		subquery.OrderBy.Accept(columnRewritter)
    86  	}
    87  
    88  	return nil
    89  }
    90  
    91  // 处理from table和join on部分
    92  // 主要是改写table ExprNode, 并找到路由条件
    93  func handleSubqueryTableRefs(p *TableAliasStmtInfo, stmt *ast.SelectStmt) error {
    94  	tableRefs := stmt.From
    95  	if tableRefs == nil {
    96  		return nil
    97  	}
    98  
    99  	join := tableRefs.TableRefs
   100  	if join == nil {
   101  		return nil
   102  	}
   103  
   104  	if err := handleSubqueryJoin(p, join); err != nil {
   105  		return fmt.Errorf("handleSubqueryTableRefs error: %v", err)
   106  	}
   107  	return nil
   108  }
   109  
   110  func handleSubqueryJoin(p *TableAliasStmtInfo, join *ast.Join) error {
   111  	if err := precheckJoinClause(join); err != nil {
   112  		return fmt.Errorf("precheck Join error: %v", err)
   113  	}
   114  
   115  	// 只允许最多两个表的JOIN
   116  	if join.Left != nil {
   117  		switch left := join.Left.(type) {
   118  		case *ast.TableSource:
   119  			// 改写两个表的node
   120  			if err := rewriteSubqueryTableSource(p, left); err != nil {
   121  				return fmt.Errorf("rewrite left TableSource error: %v", err)
   122  			}
   123  		case *ast.Join:
   124  			if err := handleSubqueryJoin(p, left); err != nil {
   125  				return fmt.Errorf("handle nested left Join error: %v", err)
   126  			}
   127  		default:
   128  			return fmt.Errorf("invalid left Join type: %T", join.Left)
   129  		}
   130  	}
   131  	if join.Right != nil {
   132  		right, ok := join.Right.(*ast.TableSource)
   133  		if !ok {
   134  			return fmt.Errorf("right is not TableSource, type: %T", join.Right)
   135  		}
   136  
   137  		if err := rewriteSubqueryTableSource(p, right); err != nil {
   138  			return fmt.Errorf("rewrite right TableSource error: %v", err)
   139  		}
   140  	}
   141  
   142  	// 只改写表名, 不计算路由
   143  	if join.On != nil {
   144  		rewritter := NewSubqueryColumnNameRewriteVisitor(p)
   145  		join.On.Accept(rewritter)
   146  	}
   147  
   148  	return nil
   149  }
   150  
   151  // gaea规定在子查询的FROM表名中不能再出现子查询
   152  func rewriteSubqueryTableSource(p *TableAliasStmtInfo, tableSource *ast.TableSource) error {
   153  	switch tableSource.Source.(type) {
   154  	case *ast.TableName:
   155  		return rewriteSubqueryTableNameInTableSource(p, tableSource)
   156  	case *ast.SelectStmt:
   157  		return fmt.Errorf("cannot handle subquery in subquery")
   158  	default:
   159  		return fmt.Errorf("field Source cannot handle, type: %T", tableSource.Source)
   160  	}
   161  }
   162  
   163  func rewriteSubqueryTableNameInTableSource(p *TableAliasStmtInfo, tableSource *ast.TableSource) error {
   164  	tableName, ok := tableSource.Source.(*ast.TableName)
   165  	if !ok {
   166  		return fmt.Errorf("field Source is not type of TableName, type: %T", tableSource.Source)
   167  	}
   168  
   169  	// 不记录子查询的表名alias
   170  	rule, need, err := NeedCreateTableNameDecorator(p, tableName, "")
   171  	if err != nil {
   172  		return fmt.Errorf("check NeedCreateTableNameDecorator error: %v", err)
   173  	}
   174  
   175  	if !need {
   176  		return nil
   177  	}
   178  
   179  	// 这是一个分片表或关联表, 创建一个TableName的装饰器, 并替换原有节点
   180  	d, err := CreateTableNameDecorator(tableName, rule, p.GetRouteResult())
   181  	if err != nil {
   182  		return fmt.Errorf("create TableNameDecorator error: %v", err)
   183  	}
   184  	tableSource.Source = d
   185  	return nil
   186  }