github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/partition_pruner.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
        // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package memex
    15  
    16  import (
    17  	"github.com/whtcorpsinc/BerolinaSQL/ast"
    18  	"github.com/whtcorpsinc/milevadb/stochastikctx"
    19  	"github.com/whtcorpsinc/milevadb/types"
    20  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    21  	"github.com/whtcorpsinc/milevadb/soliton/disjointset"
    22  )
    23  
    24  type hashPartitionPruner struct {
    25  	unionSet    *disjointset.IntSet // unionSet stores the relations like defCaus_i = defCaus_j
    26  	constantMap []*Constant
    27  	conditions  []Expression
    28  	defCausMapper   map[int64]int
    29  	numDeferredCauset   int
    30  	ctx         stochastikctx.Context
    31  }
    32  
    33  func (p *hashPartitionPruner) getDefCausID(defCaus *DeferredCauset) int {
    34  	return p.defCausMapper[defCaus.UniqueID]
    35  }
    36  
    37  func (p *hashPartitionPruner) insertDefCaus(defCaus *DeferredCauset) {
    38  	_, ok := p.defCausMapper[defCaus.UniqueID]
    39  	if !ok {
    40  		p.numDeferredCauset += 1
    41  		p.defCausMapper[defCaus.UniqueID] = len(p.defCausMapper)
    42  	}
    43  }
    44  
    45  func (p *hashPartitionPruner) reduceDeferredCausetEQ() bool {
    46  	p.unionSet = disjointset.NewIntSet(p.numDeferredCauset)
    47  	for i := range p.conditions {
    48  		if fun, ok := p.conditions[i].(*ScalarFunction); ok && fun.FuncName.L == ast.EQ {
    49  			lDefCaus, lOk := fun.GetArgs()[0].(*DeferredCauset)
    50  			rDefCaus, rOk := fun.GetArgs()[1].(*DeferredCauset)
    51  			if lOk && rOk {
    52  				lID := p.getDefCausID(lDefCaus)
    53  				rID := p.getDefCausID(rDefCaus)
    54  				p.unionSet.Union(lID, rID)
    55  			}
    56  		}
    57  	}
    58  	for i := 0; i < p.numDeferredCauset; i++ {
    59  		father := p.unionSet.FindRoot(i)
    60  		if p.constantMap[i] != nil {
    61  			if p.constantMap[father] != nil {
    62  				// May has conflict here.
    63  				if !p.constantMap[father].Equal(p.ctx, p.constantMap[i]) {
    64  					return true
    65  				}
    66  			} else {
    67  				p.constantMap[father] = p.constantMap[i]
    68  			}
    69  		}
    70  	}
    71  	for i := 0; i < p.numDeferredCauset; i++ {
    72  		father := p.unionSet.FindRoot(i)
    73  		if p.constantMap[father] != nil && p.constantMap[i] == nil {
    74  			p.constantMap[i] = p.constantMap[father]
    75  		}
    76  	}
    77  	return false
    78  }
    79  
    80  func (p *hashPartitionPruner) reduceConstantEQ() bool {
    81  	for _, con := range p.conditions {
    82  		var defCaus *DeferredCauset
    83  		var cond *Constant
    84  		if fn, ok := con.(*ScalarFunction); ok {
    85  			if fn.FuncName.L == ast.IsNull {
    86  				defCaus, ok = fn.GetArgs()[0].(*DeferredCauset)
    87  				if ok {
    88  					cond = NewNull()
    89  				}
    90  			} else {
    91  				defCaus, cond = validEqualCond(p.ctx, con)
    92  			}
    93  		}
    94  		if defCaus != nil {
    95  			id := p.getDefCausID(defCaus)
    96  			if p.constantMap[id] != nil {
    97  				if p.constantMap[id].Equal(p.ctx, cond) {
    98  					continue
    99  				}
   100  				return true
   101  			}
   102  			p.constantMap[id] = cond
   103  		}
   104  	}
   105  	return false
   106  }
   107  
   108  func (p *hashPartitionPruner) tryEvalPartitionExpr(piExpr Expression) (val int64, success bool, isNull bool) {
   109  	switch pi := piExpr.(type) {
   110  	case *ScalarFunction:
   111  		if pi.FuncName.L == ast.Plus || pi.FuncName.L == ast.Minus || pi.FuncName.L == ast.Mul || pi.FuncName.L == ast.Div {
   112  			left, right := pi.GetArgs()[0], pi.GetArgs()[1]
   113  			leftVal, ok, isNull := p.tryEvalPartitionExpr(left)
   114  			if !ok || isNull {
   115  				return 0, ok, isNull
   116  			}
   117  			rightVal, ok, isNull := p.tryEvalPartitionExpr(right)
   118  			if !ok || isNull {
   119  				return 0, ok, isNull
   120  			}
   121  			switch pi.FuncName.L {
   122  			case ast.Plus:
   123  				return rightVal + leftVal, true, false
   124  			case ast.Minus:
   125  				return rightVal - leftVal, true, false
   126  			case ast.Mul:
   127  				return rightVal * leftVal, true, false
   128  			case ast.Div:
   129  				return rightVal / leftVal, true, false
   130  			}
   131  		} else if pi.FuncName.L == ast.Year || pi.FuncName.L == ast.Month || pi.FuncName.L == ast.ToDays {
   132  			defCaus := pi.GetArgs()[0].(*DeferredCauset)
   133  			idx := p.getDefCausID(defCaus)
   134  			val := p.constantMap[idx]
   135  			if val != nil {
   136  				pi.GetArgs()[0] = val
   137  				ret, isNull, err := pi.EvalInt(p.ctx, chunk.Row{})
   138  				if err != nil {
   139  					return 0, false, false
   140  				}
   141  				return ret, true, isNull
   142  			}
   143  			return 0, false, false
   144  		}
   145  	case *Constant:
   146  		val, err := pi.Eval(chunk.Row{})
   147  		if err != nil {
   148  			return 0, false, false
   149  		}
   150  		if val.IsNull() {
   151  			return 0, true, true
   152  		}
   153  		if val.HoTT() == types.HoTTInt64 {
   154  			return val.GetInt64(), true, false
   155  		} else if val.HoTT() == types.HoTTUint64 {
   156  			return int64(val.GetUint64()), true, false
   157  		}
   158  	case *DeferredCauset:
   159  		// Look up map
   160  		idx := p.getDefCausID(pi)
   161  		val := p.constantMap[idx]
   162  		if val != nil {
   163  			return p.tryEvalPartitionExpr(val)
   164  		}
   165  		return 0, false, false
   166  	}
   167  	return 0, false, false
   168  }
   169  
   170  func newHashPartitionPruner() *hashPartitionPruner {
   171  	pruner := &hashPartitionPruner{}
   172  	pruner.defCausMapper = make(map[int64]int)
   173  	pruner.numDeferredCauset = 0
   174  	return pruner
   175  }
   176  
   177  // solve eval the hash partition memex, the first return value represent the result of partition memex. The second
   178  // return value is whether eval success. The third return value represent whether the query conditions is always false.
   179  func (p *hashPartitionPruner) solve(ctx stochastikctx.Context, conds []Expression, piExpr Expression) (val int64, ok bool, isAlwaysFalse bool) {
   180  	p.ctx = ctx
   181  	for _, cond := range conds {
   182  		p.conditions = append(p.conditions, SplitCNFItems(cond)...)
   183  		for _, defCaus := range ExtractDeferredCausets(cond) {
   184  			p.insertDefCaus(defCaus)
   185  		}
   186  	}
   187  	for _, defCaus := range ExtractDeferredCausets(piExpr) {
   188  		p.insertDefCaus(defCaus)
   189  	}
   190  	p.constantMap = make([]*Constant, p.numDeferredCauset)
   191  	isAlwaysFalse = p.reduceConstantEQ()
   192  	if isAlwaysFalse {
   193  		return 0, false, isAlwaysFalse
   194  	}
   195  	isAlwaysFalse = p.reduceDeferredCausetEQ()
   196  	if isAlwaysFalse {
   197  		return 0, false, isAlwaysFalse
   198  	}
   199  	res, ok, isNull := p.tryEvalPartitionExpr(piExpr)
   200  	if isNull && ok {
   201  		return 0, ok, false
   202  	}
   203  	return res, ok, false
   204  }
   205  
   206  // FastLocateHashPartition is used to get hash partition quickly.
   207  func FastLocateHashPartition(ctx stochastikctx.Context, conds []Expression, piExpr Expression) (int64, bool, bool) {
   208  	return newHashPartitionPruner().solve(ctx, conds, piExpr)
   209  }