vitess.io/vitess@v0.16.2/go/vt/vtgate/evalengine/logical.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package evalengine
    18  
    19  import (
    20  	"vitess.io/vitess/go/mysql/collations"
    21  	"vitess.io/vitess/go/sqltypes"
    22  	"vitess.io/vitess/go/vt/sqlparser"
    23  )
    24  
    25  type (
    26  	LogicalOp interface {
    27  		eval(left, right EvalResult) (boolean, error)
    28  		String()
    29  	}
    30  
    31  	LogicalExpr struct {
    32  		BinaryExpr
    33  		op     func(left, right boolean) boolean
    34  		opname string
    35  	}
    36  	NotExpr struct {
    37  		UnaryExpr
    38  	}
    39  
    40  	OpLogicalAnd struct{}
    41  
    42  	boolean int8
    43  
    44  	// IsExpr represents the IS expression in MySQL.
    45  	// boolean_primary IS [NOT] {TRUE | FALSE | NULL}
    46  	IsExpr struct {
    47  		UnaryExpr
    48  		Op    sqlparser.IsExprOperator
    49  		Check func(*EvalResult) bool
    50  	}
    51  
    52  	WhenThen struct {
    53  		when Expr
    54  		then Expr
    55  	}
    56  
    57  	CaseExpr struct {
    58  		cases []WhenThen
    59  		Else  Expr
    60  	}
    61  )
    62  
    63  const (
    64  	boolFalse boolean = 0
    65  	boolTrue  boolean = 1
    66  	boolNULL  boolean = -1
    67  )
    68  
    69  func makeboolean(b bool) boolean {
    70  	if b {
    71  		return boolTrue
    72  	}
    73  	return boolFalse
    74  }
    75  
    76  func makeboolean2(b, isNull bool) boolean {
    77  	if isNull {
    78  		return boolNULL
    79  	}
    80  	return makeboolean(b)
    81  }
    82  
    83  func (left boolean) not() boolean {
    84  	switch left {
    85  	case boolFalse:
    86  		return boolTrue
    87  	case boolTrue:
    88  		return boolFalse
    89  	default:
    90  		return left
    91  	}
    92  }
    93  
    94  func (left boolean) and(right boolean) boolean {
    95  	// Logical AND.
    96  	// Evaluates to 1 if all operands are nonzero and not NULL, to 0 if one or more operands are 0, otherwise NULL is returned.
    97  	switch {
    98  	case left == boolTrue && right == boolTrue:
    99  		return boolTrue
   100  	case left == boolFalse || right == boolFalse:
   101  		return boolFalse
   102  	default:
   103  		return boolNULL
   104  	}
   105  }
   106  
   107  func (left boolean) or(right boolean) boolean {
   108  	// Logical OR. When both operands are non-NULL, the result is 1 if any operand is nonzero, and 0 otherwise.
   109  	// With a NULL operand, the result is 1 if the other operand is nonzero, and NULL otherwise.
   110  	// If both operands are NULL, the result is NULL.
   111  	switch {
   112  	case left == boolNULL:
   113  		if right == boolTrue {
   114  			return boolTrue
   115  		}
   116  		return boolNULL
   117  
   118  	case right == boolNULL:
   119  		if left == boolTrue {
   120  			return boolTrue
   121  		}
   122  		return boolNULL
   123  
   124  	default:
   125  		if left == boolTrue || right == boolTrue {
   126  			return boolTrue
   127  		}
   128  		return boolFalse
   129  	}
   130  }
   131  
   132  func (left boolean) xor(right boolean) boolean {
   133  	// Logical XOR. Returns NULL if either operand is NULL.
   134  	// For non-NULL operands, evaluates to 1 if an odd number of operands is nonzero, otherwise 0 is returned.
   135  	switch {
   136  	case left == boolNULL || right == boolNULL:
   137  		return boolNULL
   138  	default:
   139  		if left != right {
   140  			return boolTrue
   141  		}
   142  		return boolFalse
   143  	}
   144  }
   145  
   146  func (n *NotExpr) eval(env *ExpressionEnv, out *EvalResult) {
   147  	var inner EvalResult
   148  	inner.init(env, n.Inner)
   149  	out.setBoolean(inner.isTruthy().not())
   150  }
   151  
   152  func (n *NotExpr) typeof(env *ExpressionEnv) (sqltypes.Type, flag) {
   153  	_, flags := n.Inner.typeof(env)
   154  	return sqltypes.Uint64, flags
   155  }
   156  
   157  func (l *LogicalExpr) eval(env *ExpressionEnv, out *EvalResult) {
   158  	var left, right EvalResult
   159  	left.init(env, l.Left)
   160  	right.init(env, l.Right)
   161  	if left.typeof() == sqltypes.Tuple || right.typeof() == sqltypes.Tuple {
   162  		panic("did not typecheck tuples")
   163  	}
   164  	out.setBoolean(l.op(left.isTruthy(), right.isTruthy()))
   165  }
   166  
   167  func (l *LogicalExpr) typeof(env *ExpressionEnv) (sqltypes.Type, flag) {
   168  	_, f1 := l.Left.typeof(env)
   169  	_, f2 := l.Right.typeof(env)
   170  	return sqltypes.Uint64, f1 | f2
   171  }
   172  
   173  func (i *IsExpr) eval(env *ExpressionEnv, result *EvalResult) {
   174  	var in EvalResult
   175  	in.init(env, i.Inner)
   176  	result.setBool(i.Check(&in))
   177  }
   178  
   179  func (i *IsExpr) typeof(env *ExpressionEnv) (sqltypes.Type, flag) {
   180  	return sqltypes.Int64, 0
   181  }
   182  
   183  func (c *CaseExpr) eval(env *ExpressionEnv, result *EvalResult) {
   184  	var tmp EvalResult
   185  	var matched = false
   186  	var ca collationAggregation
   187  	var local = collations.Local()
   188  
   189  	// From what we can tell, MySQL actually evaluates all the branches
   190  	// of a CASE expression, even after a truthy match. I.e. the CASE
   191  	// operator does _not_ short-circuit.
   192  
   193  	for _, whenThen := range c.cases {
   194  		tmp.init(env, whenThen.when)
   195  		truthy := tmp.isTruthy() == boolTrue
   196  
   197  		tmp.init(env, whenThen.then)
   198  		ca.add(local, tmp.collation())
   199  
   200  		if !matched && truthy {
   201  			tmp.resolve()
   202  			*result = tmp
   203  			matched = true
   204  		}
   205  	}
   206  	if c.Else != nil {
   207  		tmp.init(env, c.Else)
   208  		ca.add(local, tmp.collation())
   209  
   210  		if !matched {
   211  			tmp.resolve()
   212  			*result = tmp
   213  			matched = true
   214  		}
   215  	}
   216  
   217  	if !matched {
   218  		result.setNull()
   219  	}
   220  
   221  	t, _ := c.typeof(env)
   222  	result.coerce(t, ca.result().Collation)
   223  }
   224  
   225  func (c *CaseExpr) typeof(env *ExpressionEnv) (sqltypes.Type, flag) {
   226  	var ta typeAggregation
   227  	var resultFlag flag
   228  
   229  	for _, whenthen := range c.cases {
   230  		t, f := whenthen.then.typeof(env)
   231  		ta.add(t, f)
   232  		resultFlag = resultFlag | f
   233  	}
   234  	if c.Else != nil {
   235  		t, f := c.Else.typeof(env)
   236  		ta.add(t, f)
   237  		resultFlag = f
   238  	}
   239  	return ta.result(), resultFlag
   240  }
   241  
   242  func (c *CaseExpr) format(buf *formatter, depth int) {
   243  	buf.WriteString("CASE")
   244  	for _, cs := range c.cases {
   245  		buf.WriteString(" WHEN ")
   246  		cs.when.format(buf, depth)
   247  		buf.WriteString(" THEN ")
   248  		cs.then.format(buf, depth)
   249  	}
   250  	if c.Else != nil {
   251  		buf.WriteString(" ELSE ")
   252  		c.Else.format(buf, depth)
   253  	}
   254  }
   255  
   256  func (c *CaseExpr) constant() bool {
   257  	// TODO we should be able to simplify more cases than constant/simplify allows us to today
   258  	// example: case when true then col end
   259  	if c.Else != nil {
   260  		if !c.Else.constant() {
   261  			return false
   262  		}
   263  	}
   264  
   265  	for _, then := range c.cases {
   266  		if !then.when.constant() || !then.then.constant() {
   267  			return false
   268  		}
   269  	}
   270  
   271  	return true
   272  }
   273  
   274  func (c *CaseExpr) simplify(env *ExpressionEnv) error {
   275  	var err error
   276  	for i := range c.cases {
   277  		whenThen := &c.cases[i]
   278  		whenThen.when, err = simplifyExpr(env, whenThen.when)
   279  		if err != nil {
   280  			return err
   281  		}
   282  		whenThen.then, err = simplifyExpr(env, whenThen.then)
   283  		if err != nil {
   284  			return err
   285  		}
   286  	}
   287  	if c.Else != nil {
   288  		c.Else, err = simplifyExpr(env, c.Else)
   289  	}
   290  	return err
   291  }
   292  
   293  var _ Expr = (*CaseExpr)(nil)