vitess.io/vitess@v0.16.2/go/vt/vtgate/evalengine/expressions.go (about)

     1  /*
     2  Copyright 2020 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package evalengine
    18  
    19  import (
    20  	"encoding/hex"
    21  	"fmt"
    22  	"math"
    23  	"strconv"
    24  	"unicode/utf8"
    25  
    26  	"vitess.io/vitess/go/mysql/collations"
    27  	"vitess.io/vitess/go/sqltypes"
    28  	querypb "vitess.io/vitess/go/vt/proto/query"
    29  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    30  	"vitess.io/vitess/go/vt/sqlparser"
    31  	"vitess.io/vitess/go/vt/vterrors"
    32  	"vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal"
    33  )
    34  
    35  type (
    36  	// ExpressionEnv contains the environment that the expression
    37  	// evaluates in, such as the current row and bindvars
    38  	ExpressionEnv struct {
    39  		BindVars         map[string]*querypb.BindVariable
    40  		DefaultCollation collations.ID
    41  
    42  		// Row and Fields should line up
    43  		Row    []sqltypes.Value
    44  		Fields []*querypb.Field
    45  	}
    46  
    47  	// Expr is the interface that all evaluating expressions must implement
    48  	Expr interface {
    49  		eval(env *ExpressionEnv, result *EvalResult)
    50  		typeof(env *ExpressionEnv) (sqltypes.Type, flag)
    51  		format(buf *formatter, depth int)
    52  		constant() bool
    53  		simplify(env *ExpressionEnv) error
    54  	}
    55  
    56  	Literal struct {
    57  		Val EvalResult
    58  	}
    59  
    60  	BindVariable struct {
    61  		Key        string
    62  		coll       collations.TypedCollation
    63  		coerceType sqltypes.Type
    64  	}
    65  
    66  	Column struct {
    67  		Offset int
    68  		coll   collations.TypedCollation
    69  	}
    70  
    71  	TupleExpr []Expr
    72  
    73  	CollateExpr struct {
    74  		UnaryExpr
    75  		TypedCollation collations.TypedCollation
    76  	}
    77  
    78  	BinaryExpr struct {
    79  		Left, Right Expr
    80  	}
    81  )
    82  
        // LeftExpr returns the left-hand operand of the binary expression.
    83  func (expr *BinaryExpr) LeftExpr() Expr {
    84  	return expr.Left
    85  }
    86  
        // RightExpr returns the right-hand operand of the binary expression.
    87  func (expr *BinaryExpr) RightExpr() Expr {
    88  	return expr.Right
    89  }
    90  
    91  var _ Expr = (*Literal)(nil)
    92  var _ Expr = (*BindVariable)(nil)
    93  var _ Expr = (*Column)(nil)
    94  var _ Expr = (*ArithmeticExpr)(nil)
    95  var _ Expr = (*ComparisonExpr)(nil)
    96  var _ Expr = (*InExpr)(nil)
    97  var _ Expr = (*IsExpr)(nil)
    98  var _ Expr = (*LikeExpr)(nil)
    99  var _ Expr = (TupleExpr)(nil)
   100  var _ Expr = (*CollateExpr)(nil)
   101  var _ Expr = (*LogicalExpr)(nil)
   102  var _ Expr = (*NotExpr)(nil)
   103  var _ Expr = (*CallExpr)(nil)
   104  var _ Expr = (*WeightStringCallExpr)(nil)
   105  var _ Expr = (*BitwiseExpr)(nil)
   106  var _ Expr = (*BitwiseNotExpr)(nil)
   107  var _ Expr = (*ConvertExpr)(nil)
   108  var _ Expr = (*ConvertUsingExpr)(nil)
   109  
        // evalError wraps an error that is raised via panic during typechecking
        // or evaluation. Evaluate and TypeOf recover panics of this type and
        // return the wrapped error to their caller.
   110  type evalError struct {
   111  	error
   112  }
   113  
   114  func throwEvalError(err error) {
   115  	panic(evalError{err})
   116  }
   117  
   118  func throwCardinalityError(expected int) {
   119  	panic(evalError{
   120  		vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.OperandColumns, "Operand should contain %d column(s)", expected),
   121  	})
   122  }
   123  
   124  func (env *ExpressionEnv) cardinality(expr Expr) int {
   125  	switch expr := expr.(type) {
   126  	case *BindVariable:
   127  		tt, _ := expr.typeof(env)
   128  		if tt == sqltypes.Tuple {
   129  			return len(expr.bvar(env).Values)
   130  		}
   131  		return 1
   132  
   133  	case TupleExpr:
   134  		return len(expr)
   135  
   136  	default:
   137  		return 1
   138  	}
   139  }
   140  
   141  func (env *ExpressionEnv) ensureCardinality(expr Expr, expected int) {
   142  	if env.cardinality(expr) != expected {
   143  		throwCardinalityError(expected)
   144  	}
   145  }
   146  
   147  func (env *ExpressionEnv) subexpr(expr Expr, nth int) (Expr, int) {
   148  	switch expr := expr.(type) {
   149  	case *BindVariable:
   150  		tt, _ := expr.typeof(env)
   151  		if tt == sqltypes.Tuple {
   152  			return nil, 1
   153  		}
   154  	case *Literal:
   155  		if expr.Val.typeof() == sqltypes.Tuple {
   156  			return nil, 1
   157  		}
   158  	case TupleExpr:
   159  		return expr[nth], env.cardinality(expr[nth])
   160  	}
   161  	panic("subexpr called on non-tuple")
   162  }
   163  
   164  func (env *ExpressionEnv) typecheckComparison(expr1 Expr, card1 int, expr2 Expr, card2 int) {
   165  	switch {
   166  	case card1 == 1 && card2 == 1:
   167  		env.typecheck(expr1)
   168  		env.typecheck(expr2)
   169  	case card1 == card2:
   170  		for n := 0; n < card1; n++ {
   171  			left1, leftcard1 := env.subexpr(expr1, n)
   172  			right1, rightcard1 := env.subexpr(expr2, n)
   173  			env.typecheckComparison(left1, leftcard1, right1, rightcard1)
   174  		}
   175  	default:
   176  		env.typecheck(expr1)
   177  		env.typecheck(expr2)
   178  		throwCardinalityError(card1)
   179  	}
   180  }
   181  
   182  func (env *ExpressionEnv) typecheckBinary(left, right Expr) {
   183  	env.typecheck(left)
   184  	env.ensureCardinality(left, 1)
   185  
   186  	env.typecheck(right)
   187  	env.ensureCardinality(right, 1)
   188  }
   189  
        // typecheckUnary typechecks inner and then raises a cardinality error
        // unless inner has a cardinality of exactly 1.
   190  func (env *ExpressionEnv) typecheckUnary(inner Expr) {
   191  	env.typecheck(inner)
   192  	env.ensureCardinality(inner, 1)
   193  }
   194  
   195  func (env *ExpressionEnv) typecheck(expr Expr) {
   196  	if expr == nil {
   197  		return
   198  	}
   199  
   200  	switch expr := expr.(type) {
   201  	case *ConvertExpr:
   202  		env.typecheckUnary(expr.Inner)
   203  	case *ConvertUsingExpr:
   204  		env.typecheckUnary(expr.Inner)
   205  	case *NegateExpr:
   206  		env.typecheckUnary(expr.Inner)
   207  	case *CollateExpr:
   208  		env.typecheckUnary(expr.Inner)
   209  	case *IsExpr:
   210  		env.typecheckUnary(expr.Inner)
   211  	case *BitwiseNotExpr:
   212  		env.typecheckUnary(expr.Inner)
   213  	case *WeightStringCallExpr:
   214  		env.typecheckUnary(expr.String)
   215  	case *ArithmeticExpr:
   216  		env.typecheckBinary(expr.Left, expr.Right)
   217  	case *LogicalExpr:
   218  		env.typecheckBinary(expr.Left, expr.Right)
   219  	case *BitwiseExpr:
   220  		env.typecheckBinary(expr.Left, expr.Right)
   221  	case *LikeExpr:
   222  		env.typecheckBinary(expr.Left, expr.Right)
   223  	case *ComparisonExpr:
   224  		left := env.cardinality(expr.Left)
   225  		right := env.cardinality(expr.Right)
   226  		env.typecheckComparison(expr.Left, left, expr.Right, right)
   227  	case *InExpr:
   228  		env.typecheck(expr.Left)
   229  		left := env.cardinality(expr.Left)
   230  		right := env.cardinality(expr.Right)
   231  
   232  		tt, _ := expr.Right.typeof(env)
   233  		if tt != sqltypes.Tuple {
   234  			throwEvalError(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "rhs of an In operation should be a tuple"))
   235  		}
   236  
   237  		for n := 0; n < right; n++ {
   238  			subexpr, subcard := env.subexpr(expr.Right, n)
   239  			env.typecheck(subexpr)
   240  			if left != subcard {
   241  				throwCardinalityError(left)
   242  			}
   243  		}
   244  	case TupleExpr:
   245  		for _, subexpr := range expr {
   246  			env.typecheck(subexpr)
   247  		}
   248  	case *CallExpr:
   249  		env.typecheck(expr.Arguments)
   250  	case *Literal, *Column, *BindVariable, *CaseExpr: // noop
   251  	default:
   252  		panic(fmt.Sprintf("unhandled cardinality: %T", expr))
   253  	}
   254  }
   255  
   256  func (env *ExpressionEnv) Evaluate(expr Expr) (er EvalResult, err error) {
   257  	if env == nil {
   258  		panic("ExpressionEnv == nil")
   259  	}
   260  	defer func() {
   261  		if r := recover(); r != nil {
   262  			if ee, ok := r.(evalError); ok {
   263  				err = ee.error
   264  			} else {
   265  				panic(r)
   266  			}
   267  		}
   268  	}()
   269  	env.typecheck(expr)
   270  	expr.eval(env, &er)
   271  	return
   272  }
   273  
   274  func (env *ExpressionEnv) TypeOf(expr Expr) (ty sqltypes.Type, err error) {
   275  	defer func() {
   276  		if r := recover(); r != nil {
   277  			if ee, ok := r.(evalError); ok {
   278  				err = ee.error
   279  			} else {
   280  				panic(r)
   281  			}
   282  		}
   283  	}()
   284  	ty, _ = expr.typeof(env)
   285  	return
   286  }
   287  
   288  // EmptyExpressionEnv returns a new ExpressionEnv with no bind vars or row
   289  func EmptyExpressionEnv() *ExpressionEnv {
   290  	return EnvWithBindVars(map[string]*querypb.BindVariable{}, collations.Unknown)
   291  }
   292  
   293  // EnvWithBindVars returns an expression environment with no current row, but with bindvars
   294  func EnvWithBindVars(bindVars map[string]*querypb.BindVariable, coll collations.ID) *ExpressionEnv {
   295  	if coll == collations.Unknown {
   296  		coll = collations.Default()
   297  	}
   298  	return &ExpressionEnv{BindVars: bindVars, DefaultCollation: coll}
   299  }
   300  
   301  // NullExpr is just what you are led to believe: a shared Literal holding
        // the NULL value. Its contents are filled in by init below.
   302  var NullExpr = &Literal{}
   303  
        // init marks NullExpr's value as NULL and gives it the NULL collation.
   304  func init() {
   305  	NullExpr.Val.setNull()
   306  	NullExpr.Val.replaceCollation(collationNull)
   307  }
   308  
   309  // NewLiteralIntegralFromBytes returns a literal expression.
   310  // It tries to return an int64, but if the value is too large, it tries with an uint64
   311  func NewLiteralIntegralFromBytes(val []byte) (*Literal, error) {
   312  	if val[0] == '-' {
   313  		panic("NewLiteralIntegralFromBytes: negative value")
   314  	}
   315  
   316  	uval, err := strconv.ParseUint(string(val), 10, 64)
   317  	if err != nil {
   318  		if numError, ok := err.(*strconv.NumError); ok && numError.Err == strconv.ErrRange {
   319  			return NewLiteralDecimalFromBytes(val)
   320  		}
   321  		return nil, err
   322  	}
   323  	if uval <= math.MaxInt64 {
   324  		return NewLiteralInt(int64(uval)), nil
   325  	}
   326  	return NewLiteralUint(uval), nil
   327  }
   328  
   329  // NewLiteralInt returns a literal expression
   330  func NewLiteralInt(i int64) *Literal {
   331  	lit := &Literal{}
   332  	lit.Val.setInt64(i)
   333  	return lit
   334  }
   335  
   336  // NewLiteralUint returns a literal expression
   337  func NewLiteralUint(i uint64) *Literal {
   338  	lit := &Literal{}
   339  	lit.Val.setUint64(i)
   340  	return lit
   341  }
   342  
   343  // NewLiteralFloat returns a literal expression
   344  func NewLiteralFloat(val float64) *Literal {
   345  	lit := &Literal{}
   346  	lit.Val.setFloat(val)
   347  	return lit
   348  }
   349  
   350  // NewLiteralFloatFromBytes returns a float literal expression from a slice of bytes
   351  func NewLiteralFloatFromBytes(val []byte) (*Literal, error) {
   352  	lit := &Literal{}
   353  	fval, err := strconv.ParseFloat(string(val), 64)
   354  	if err != nil {
   355  		return nil, err
   356  	}
   357  	lit.Val.setFloat(fval)
   358  	return lit, nil
   359  }
   360  
   361  func NewLiteralDecimalFromBytes(val []byte) (*Literal, error) {
   362  	lit := &Literal{}
   363  	dec, err := decimal.NewFromMySQL(val)
   364  	if err != nil {
   365  		return nil, err
   366  	}
   367  	lit.Val.setDecimal(dec, -dec.Exponent())
   368  	return lit, nil
   369  }
   370  
   371  // NewLiteralString returns a literal expression
   372  func NewLiteralString(val []byte, collation collations.TypedCollation) *Literal {
   373  	collation.Repertoire = collations.RepertoireASCII
   374  	for _, b := range val {
   375  		if b >= utf8.RuneSelf {
   376  			collation.Repertoire = collations.RepertoireUnicode
   377  			break
   378  		}
   379  	}
   380  	lit := &Literal{}
   381  	lit.Val.setRaw(sqltypes.VarChar, val, collation)
   382  	return lit
   383  }
   384  
   385  // NewLiteralDateFromBytes returns a literal expression.
   386  func NewLiteralDateFromBytes(val []byte) (*Literal, error) {
   387  	_, err := sqlparser.ParseDate(string(val))
   388  	if err != nil {
   389  		return nil, err
   390  	}
   391  	lit := &Literal{}
   392  	lit.Val.setRaw(querypb.Type_DATE, val, collationNumeric)
   393  	return lit, nil
   394  }
   395  
   396  // NewLiteralTimeFromBytes returns a literal expression.
   397  // it validates the time by parsing it and checking the error.
   398  func NewLiteralTimeFromBytes(val []byte) (*Literal, error) {
   399  	_, err := sqlparser.ParseTime(string(val))
   400  	if err != nil {
   401  		return nil, err
   402  	}
   403  	lit := &Literal{}
   404  	lit.Val.setRaw(querypb.Type_TIME, val, collationNumeric)
   405  	return lit, nil
   406  }
   407  
   408  // NewLiteralDatetimeFromBytes returns a literal expression.
   409  // it validates the datetime by parsing it and checking the error.
   410  func NewLiteralDatetimeFromBytes(val []byte) (*Literal, error) {
   411  	_, err := sqlparser.ParseDateTime(string(val))
   412  	if err != nil {
   413  		return nil, err
   414  	}
   415  	lit := &Literal{}
   416  	lit.Val.setRaw(querypb.Type_DATETIME, val, collationNumeric)
   417  	return lit, nil
   418  }
   419  
   420  func parseHexLiteral(val []byte) ([]byte, error) {
   421  	raw := make([]byte, hex.DecodedLen(len(val)))
   422  	if _, err := hex.Decode(raw, val); err != nil {
   423  		return nil, err
   424  	}
   425  	return raw, nil
   426  }
   427  
   428  func parseHexNumber(val []byte) ([]byte, error) {
   429  	if val[0] != '0' || val[1] != 'x' {
   430  		panic("malformed hex literal from parser")
   431  	}
   432  	if len(val)%2 == 0 {
   433  		return parseHexLiteral(val[2:])
   434  	}
   435  	// If the hex literal doesn't have an even amount of hex digits, we need
   436  	// to pad it with a '0' in the left. Instead of allocating a new slice
   437  	// for padding pad in-place by replacing the 'x' in the original slice with
   438  	// a '0', and clean it up after parsing.
   439  	val[1] = '0'
   440  	defer func() {
   441  		val[1] = 'x'
   442  	}()
   443  	return parseHexLiteral(val[1:])
   444  }
   445  
   446  func NewLiteralBinary(val []byte) *Literal {
   447  	lit := &Literal{}
   448  	lit.Val.setRaw(sqltypes.VarBinary, val, collationBinary)
   449  	return lit
   450  }
   451  
   452  func NewLiteralBinaryFromHex(val []byte) (*Literal, error) {
   453  	raw, err := parseHexLiteral(val)
   454  	if err != nil {
   455  		return nil, err
   456  	}
   457  	lit := &Literal{}
   458  	lit.Val.setBinaryHex(raw)
   459  	return lit, nil
   460  }
   461  
   462  func NewLiteralBinaryFromHexNum(val []byte) (*Literal, error) {
   463  	raw, err := parseHexNumber(val)
   464  	if err != nil {
   465  		return nil, err
   466  	}
   467  	lit := &Literal{}
   468  	lit.Val.setBinaryHex(raw)
   469  	return lit, nil
   470  }
   471  
   472  // NewBindVar returns a bind variable
   473  func NewBindVar(key string, collation collations.TypedCollation) Expr {
   474  	return &BindVariable{
   475  		Key:        key,
   476  		coll:       collation,
   477  		coerceType: -1,
   478  	}
   479  }
   480  
   481  // NewColumn returns a column expression
   482  func NewColumn(offset int, collation collations.TypedCollation) Expr {
   483  	return &Column{
   484  		Offset: offset,
   485  		coll:   collation,
   486  	}
   487  }
   488  
   489  // NewTupleExpr returns a tuple expression
   490  func NewTupleExpr(exprs ...Expr) TupleExpr {
   491  	tupleExpr := make(TupleExpr, 0, len(exprs))
   492  	for _, f := range exprs {
   493  		tupleExpr = append(tupleExpr, f)
   494  	}
   495  	return tupleExpr
   496  }
   497  
   498  // eval implements the Expr interface.
        // It copies the literal's stored value into result; the environment is
        // not consulted.
   499  func (l *Literal) eval(_ *ExpressionEnv, result *EvalResult) {
   500  	*result = l.Val
   501  }
   502  
   503  func (t TupleExpr) eval(env *ExpressionEnv, result *EvalResult) {
   504  	var tup = make([]EvalResult, len(t))
   505  	for i, expr := range t {
   506  		tup[i].init(env, expr)
   507  	}
   508  	result.setTuple(tup)
   509  }
   510  
   511  func (bv *BindVariable) bvar(env *ExpressionEnv) *querypb.BindVariable {
   512  	val, ok := env.BindVars[bv.Key]
   513  	if !ok {
   514  		throwEvalError(vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "query arguments missing for %s", bv.Key))
   515  	}
   516  	return val
   517  }
   518  
   519  // eval implements the Expr interface
   520  func (bv *BindVariable) eval(env *ExpressionEnv, result *EvalResult) {
   521  	bvar := bv.bvar(env)
   522  	typ := bvar.Type
   523  	if bv.coerceType >= 0 {
   524  		typ = bv.coerceType
   525  	}
   526  
   527  	switch typ {
   528  	case sqltypes.Tuple:
   529  		tuple := make([]EvalResult, len(bvar.Values))
   530  		for i, value := range bvar.Values {
   531  			if err := tuple[i].setValue(sqltypes.MakeTrusted(value.Type, value.Value), collations.TypedCollation{}); err != nil {
   532  				throwEvalError(err)
   533  			}
   534  		}
   535  		result.setTuple(tuple)
   536  
   537  	default:
   538  		if err := result.setValue(sqltypes.MakeTrusted(typ, bvar.Value), bv.coll); err != nil {
   539  			throwEvalError(err)
   540  		}
   541  	}
   542  }
   543  
   544  // typeof implements the Expr interface
   545  func (bv *BindVariable) typeof(env *ExpressionEnv) (sqltypes.Type, flag) {
   546  	bvar := bv.bvar(env)
   547  	switch bvar.Type {
   548  	case sqltypes.Null:
   549  		return sqltypes.Null, flagNull | flagNullable
   550  	case sqltypes.HexNum, sqltypes.HexVal:
   551  		return sqltypes.VarBinary, flagHex
   552  	default:
   553  		if bv.coerceType >= 0 {
   554  			return bv.coerceType, 0
   555  		}
   556  		return bvar.Type, 0
   557  	}
   558  }
   559  
   560  // eval implements the Expr interface
   561  func (c *Column) eval(env *ExpressionEnv, result *EvalResult) {
   562  	if err := result.setValue(env.Row[c.Offset], c.coll); err != nil {
   563  		throwEvalError(err)
   564  	}
   565  }
   566  
   567  // typeof implements the Expr interface.
        // It reports the type and flags of the literal's stored value.
   568  func (l *Literal) typeof(*ExpressionEnv) (sqltypes.Type, flag) {
   569  	return l.Val.typeof(), l.Val.flags_
   570  }
   571  
   572  // typeof implements the Expr interface.
        // A tuple expression always reports the Tuple type, flagged as nullable.
   573  func (t TupleExpr) typeof(*ExpressionEnv) (sqltypes.Type, flag) {
   574  	return sqltypes.Tuple, flagNullable
   575  }
   576  
   577  func (c *Column) typeof(env *ExpressionEnv) (sqltypes.Type, flag) {
   578  	// we'll try to do the best possible with the information we have
   579  	if c.Offset < len(env.Row) {
   580  		value := env.Row[c.Offset]
   581  		if value.IsNull() {
   582  			return sqltypes.Null, flagNull | flagNullable
   583  		}
   584  		return value.Type(), flag(0)
   585  	}
   586  
   587  	if c.Offset < len(env.Fields) {
   588  		return env.Fields[c.Offset].Type, flagNullable
   589  	}
   590  
   591  	panic("Column missing both data and field")
   592  }