github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/optimizer/plan/predicate_push_down.go (about)

     1  // Copyright 2016 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package plan
    15  
    16  import (
    17  	"github.com/insionng/yougam/libraries/juju/errors"
    18  	"github.com/insionng/yougam/libraries/pingcap/tidb/ast"
    19  )
    20  
    21  func addFilter(p Plan, child Plan, conditions []ast.ExprNode) error {
    22  	filter := &Filter{Conditions: conditions}
    23  	return InsertPlan(p, child, filter)
    24  }
    25  
    26  // columnSubstituor substitutes the columns in filter to expressions in select fields.
    27  // e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k.
    28  type columnSubstitutor struct {
    29  	fields []*ast.ResultField
    30  }
    31  
    32  func (cl *columnSubstitutor) Enter(inNode ast.Node) (node ast.Node, skipChild bool) {
    33  	return inNode, false
    34  }
    35  
    36  func (cl *columnSubstitutor) Leave(inNode ast.Node) (node ast.Node, ok bool) {
    37  	switch v := inNode.(type) {
    38  	case *ast.ColumnNameExpr:
    39  		for _, field := range cl.fields {
    40  			if v.Refer == field {
    41  				return field.Expr, true
    42  			}
    43  		}
    44  	}
    45  	return inNode, true
    46  }
    47  
    48  // PredicatePushDown applies predicate push down to all kinds of plans, except aggregation and union.
    49  func PredicatePushDown(p Plan, predicates []ast.ExprNode) (ret []ast.ExprNode, err error) {
    50  	switch v := p.(type) {
    51  	case *TableScan:
    52  		v.attachCondition(predicates)
    53  		return ret, nil
    54  	case *Filter:
    55  		conditions := v.Conditions
    56  		retConditions, err1 := PredicatePushDown(p.GetChildByIndex(0), append(conditions, predicates...))
    57  		if err1 != nil {
    58  			return nil, errors.Trace(err1)
    59  		}
    60  		if len(retConditions) > 0 {
    61  			v.Conditions = retConditions
    62  		} else {
    63  			if len(p.GetParents()) == 0 {
    64  				return ret, nil
    65  			}
    66  			err1 = RemovePlan(p)
    67  			if err1 != nil {
    68  				return nil, errors.Trace(err1)
    69  			}
    70  		}
    71  		return ret, nil
    72  	case *Join:
    73  		//TODO: add null rejecter
    74  		var leftCond, rightCond []ast.ExprNode
    75  		leftPlan := v.GetChildByIndex(0)
    76  		rightPlan := v.GetChildByIndex(1)
    77  		equalCond, leftPushCond, rightPushCond, otherCond := extractOnCondition(predicates, leftPlan, rightPlan)
    78  		if v.JoinType == LeftOuterJoin {
    79  			rightCond = v.RightConditions
    80  			leftCond = leftPushCond
    81  			ret = append(equalCond, otherCond...)
    82  			ret = append(ret, rightPushCond...)
    83  		} else if v.JoinType == RightOuterJoin {
    84  			leftCond = v.LeftConditions
    85  			rightCond = rightPushCond
    86  			ret = append(equalCond, otherCond...)
    87  			ret = append(ret, leftPushCond...)
    88  		} else {
    89  			leftCond = append(v.LeftConditions, leftPushCond...)
    90  			rightCond = append(v.RightConditions, rightPushCond...)
    91  		}
    92  		leftRet, err1 := PredicatePushDown(leftPlan, leftCond)
    93  		if err1 != nil {
    94  			return nil, errors.Trace(err1)
    95  		}
    96  		rightRet, err2 := PredicatePushDown(rightPlan, rightCond)
    97  		if err2 != nil {
    98  			return nil, errors.Trace(err2)
    99  		}
   100  		if len(leftRet) > 0 {
   101  			err2 = addFilter(p, leftPlan, leftRet)
   102  			if err2 != nil {
   103  				return nil, errors.Trace(err2)
   104  			}
   105  		}
   106  		if len(rightRet) > 0 {
   107  			err2 = addFilter(p, rightPlan, rightRet)
   108  			if err2 != nil {
   109  				return nil, errors.Trace(err2)
   110  			}
   111  		}
   112  		if v.JoinType == InnerJoin {
   113  			v.EqualConditions = append(v.EqualConditions, equalCond...)
   114  			v.OtherConditions = append(v.OtherConditions, otherCond...)
   115  		}
   116  		return ret, nil
   117  	case *SelectFields:
   118  		if len(v.GetChildren()) == 0 {
   119  			return predicates, nil
   120  		}
   121  		cs := &columnSubstitutor{fields: v.Fields()}
   122  		var push []ast.ExprNode
   123  		for _, cond := range predicates {
   124  			ce := &columnsExtractor{}
   125  			ok := true
   126  			cond.Accept(ce)
   127  			for _, col := range ce.result {
   128  				match := false
   129  				for _, field := range v.Fields() {
   130  					if col.Refer == field {
   131  						switch field.Expr.(type) {
   132  						case *ast.ColumnNameExpr:
   133  							match = true
   134  						}
   135  						break
   136  					}
   137  				}
   138  				if !match {
   139  					ok = false
   140  					break
   141  				}
   142  			}
   143  			if ok {
   144  				cond1, _ := cond.Accept(cs)
   145  				cond = cond1.(ast.ExprNode)
   146  				push = append(push, cond)
   147  			} else {
   148  				ret = append(ret, cond)
   149  			}
   150  		}
   151  		restConds, err1 := PredicatePushDown(v.GetChildByIndex(0), push)
   152  		if err1 != nil {
   153  			return nil, errors.Trace(err1)
   154  		}
   155  		if len(restConds) > 0 {
   156  			err1 = addFilter(v, v.GetChildByIndex(0), restConds)
   157  			if err1 != nil {
   158  				return nil, errors.Trace(err1)
   159  			}
   160  		}
   161  		return ret, nil
   162  	case *Sort, *Limit, *Distinct:
   163  		rest, err1 := PredicatePushDown(p.GetChildByIndex(0), predicates)
   164  		if err1 != nil {
   165  			return nil, errors.Trace(err1)
   166  		}
   167  		if len(rest) > 0 {
   168  			err1 = addFilter(p, p.GetChildByIndex(0), rest)
   169  			if err1 != nil {
   170  				return nil, errors.Trace(err1)
   171  			}
   172  		}
   173  		return ret, nil
   174  	default:
   175  		if len(v.GetChildren()) == 0 {
   176  			return predicates, nil
   177  		}
   178  		//TODO: support union and sub queries when abandon result field.
   179  		for _, child := range v.GetChildren() {
   180  			_, err = PredicatePushDown(child, []ast.ExprNode{})
   181  			if err != nil {
   182  				return nil, errors.Trace(err)
   183  			}
   184  		}
   185  		return predicates, nil
   186  	}
   187  }