github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/scalar.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"encoding/hex"
    19  	"fmt"
    20  	"strconv"
    21  	"strings"
    22  
    23  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"github.com/dolthub/go-mysql-server/sql/encodings"
    27  	"github.com/dolthub/go-mysql-server/sql/expression"
    28  	"github.com/dolthub/go-mysql-server/sql/expression/function"
    29  	"github.com/dolthub/go-mysql-server/sql/expression/function/json"
    30  	"github.com/dolthub/go-mysql-server/sql/fulltext"
    31  	"github.com/dolthub/go-mysql-server/sql/plan"
    32  	"github.com/dolthub/go-mysql-server/sql/types"
    33  )
    34  
    35  func (b *Builder) buildWhere(inScope *scope, where *ast.Where) {
    36  	if where == nil {
    37  		return
    38  	}
    39  	filter := b.buildScalar(inScope, where.Expr)
    40  	filterNode := plan.NewFilter(filter, inScope.node)
    41  	inScope.node = filterNode
    42  }
    43  
// buildScalar converts a parsed vitess expression into a sql.Expression,
// resolving column references against inScope. Most failures are reported
// through b.handleErr; a few unsupported shapes (marked NOTE(review) below)
// return a nil expression instead.
//
// The result is a named return value so the deferred function can rewrite it
// after the switch has produced it.
func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) {
	// This runs only when there is no bind context or the bind context is
	// resolveOnly. If the result is a binary expression with a BindVar on one
	// side, the BindVar takes the type of a column reference (GetField) found
	// on the other side, and the binary expression is rebuilt with those
	// children.
	defer func() {
		if !(b.bindCtx == nil || b.bindCtx.resolveOnly) {
			return
		}

		if be, ok := ex.(expression.BinaryExpression); ok {
			left := be.Left()
			right := be.Right()
			if leftBindVar, ok := left.(*expression.BindVar); ok {
				if typ, ok := hasColumnType(right); ok {
					leftBindVar.Typ = typ
					left = leftBindVar
				}
			} else if rightBindVar, ok := right.(*expression.BindVar); ok {
				if typ, ok := hasColumnType(left); ok {
					rightBindVar.Typ = typ
					right = rightBindVar
				}
			}
			// NOTE(review): the error returned by WithChildren is discarded here.
			ex, _ = be.WithChildren(left, right)
		}
	}()

	switch v := e.(type) {
	case *ast.Default:
		return expression.WrapExpression(expression.NewDefaultColumn(v.ColName))
	case *ast.SubstrExpr:
		// The string operand comes from v.Name when set, otherwise from v.StrVal.
		// The length argument is optional.
		var name sql.Expression
		if v.Name != nil {
			name = b.buildScalar(inScope, v.Name)
		} else {
			name = b.buildScalar(inScope, v.StrVal)
		}
		start := b.buildScalar(inScope, v.From)

		if v.To == nil {
			return &function.Substring{Str: name, Start: start}
		}
		len := b.buildScalar(inScope, v.To)
		return &function.Substring{Str: name, Start: start, Len: len}
	case *ast.TrimExpr:
		pat := b.buildScalar(inScope, v.Pattern)
		str := b.buildScalar(inScope, v.Str)
		return function.NewTrim(str, pat, v.Dir)
	case *ast.ComparisonExpr:
		return b.buildComparison(inScope, v)
	case *ast.IsExpr:
		return b.buildIsExprToExpression(inScope, v)
	case *ast.NotExpr:
		c := b.buildScalar(inScope, v.Expr)
		return expression.NewNot(c)
	case *ast.SQLVal:
		return b.ConvertVal(v)
	case ast.BoolVal:
		return expression.NewLiteral(bool(v), types.Boolean)
	case *ast.NullVal:
		return expression.NewLiteral(nil, types.Null)
	case *ast.ColName:
		// Column reference, matched case-insensitively. If it does not resolve
		// to a column, try it as a system/user variable, then report the most
		// specific "not found" error available.
		dbName := strings.ToLower(v.Qualifier.Qualifier.String())
		tblName := strings.ToLower(v.Qualifier.Name.String())
		colName := strings.ToLower(v.Name.String())
		c, ok := inScope.resolveColumn(dbName, tblName, colName, true, false)
		if !ok {
			sysVar, scope, ok := b.buildSysVar(v, ast.SetScope_None)
			if ok {
				return sysVar
			}
			var err error
			if scope == ast.SetScope_User {
				err = sql.ErrUnknownUserVariable.New(colName)
			} else if scope == ast.SetScope_Persist || scope == ast.SetScope_PersistOnly {
				err = sql.ErrUnknownUserVariable.New(colName)
			} else if scope == ast.SetScope_Global || scope == ast.SetScope_Session {
				err = sql.ErrUnknownSystemVariable.New(colName)
			} else if tblName != "" && !inScope.hasTable(tblName) {
				err = sql.ErrTableNotFound.New(tblName)
			} else if tblName != "" {
				err = sql.ErrTableColumnNotFound.New(tblName, colName)
			} else {
				err = sql.ErrColumnNotFound.New(v)
			}
			b.handleErr(err)
		}
		// Keep the column name as originally written, not the lower-cased form.
		c = c.withOriginal(v.Name.String())
		return c.scalarGf()
	case *ast.FuncExpr:
		name := v.Name.Lowered()

		// Aggregates without OVER and window functions have dedicated builders.
		// Everything else is looked up in the function catalog.
		if isAggregateFunc(name) && v.Over == nil {
			// TODO this assumes aggregate is in the same scope
			// also need to avoid nested aggregates
			return b.buildAggregateFunc(inScope, name, v)
		} else if isWindowFunc(name) {
			return b.buildWindowFunc(inScope, name, v, (*ast.WindowDef)(v.Over))
		}

		f, err := b.cat.Function(b.ctx, name)
		if err != nil {
			b.handleErr(err)
		}

		args := make([]sql.Expression, len(v.Exprs))
		for i, e := range v.Exprs {
			args[i] = b.selectExprToExpression(inScope, e)
		}

		// json_value's optional third argument names a coercion type. Replace it
		// with a literal of that type's zero value.
		if name == "json_value" {
			if len(args) == 3 {
				args[2] = b.getJsonValueTypeLiteral(args[2])
			}
		}

		rf, err := f.NewInstance(args)
		if err != nil {
			b.handleErr(err)
		}

		// NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw
		// errors for when DISTINCT is used on aggregate functions that don't support DISTINCT.
		// NOTE(review): rf was already built from args above, so replacing args[0]
		// here only reaches rf if NewInstance kept the same slice. Confirm this
		// wrapping takes effect. Returning nil for len(args) != 1 also reports no
		// error.
		if v.Distinct {
			if len(args) != 1 {
				return nil
			}
			args[0] = expression.NewDistinctExpression(args[0])
		}

		// A non-deterministic function makes the nearest enclosing subquery volatile.
		if _, ok := rf.(sql.NonDeterministicExpression); ok && inScope.nearestSubquery() != nil {
			inScope.nearestSubquery().markVolatile()
		}

		return rf

	case *ast.GroupConcatExpr:
		// TODO this is an aggregation
		return b.buildGroupConcat(inScope, v)
	case *ast.ParenExpr:
		return b.buildScalar(inScope, v.Expr)
	case *ast.AndExpr:
		lhs := b.buildScalar(inScope, v.Left)
		rhs := b.buildScalar(inScope, v.Right)
		return expression.NewAnd(lhs, rhs)
	case *ast.OrExpr:
		lhs := b.buildScalar(inScope, v.Left)
		rhs := b.buildScalar(inScope, v.Right)
		return expression.NewOr(lhs, rhs)
	case *ast.XorExpr:
		lhs := b.buildScalar(inScope, v.Left)
		rhs := b.buildScalar(inScope, v.Right)
		return expression.NewXor(lhs, rhs)
	case *ast.ConvertUsingExpr:
		expr := b.buildScalar(inScope, v.Expr)
		charset, err := sql.ParseCharacterSet(v.Type)
		if err != nil {
			b.handleErr(err)
		}
		return expression.NewConvertUsing(expr, charset)
	case *ast.CharExpr:
		args := make([]sql.Expression, len(v.Exprs))
		for i, e := range v.Exprs {
			args[i] = b.selectExprToExpression(inScope, e)
		}

		f, err := function.NewChar(args...)
		if err != nil {
			b.handleErr(err)
		}

		// The collation is derived from the optional type given to CHAR().
		collId, err := sql.ParseCollation(&v.Type, nil, true)
		if err != nil {
			b.handleErr(err)
		}

		charFunc := f.(*function.Char)
		charFunc.Collation = collId
		return charFunc
	case *ast.ConvertExpr:
		// The target type's length and scale arrive as strings and are parsed here.
		var err error
		typeLength := 0
		if v.Type.Length != nil {
			// TODO move to vitess
			typeLength, err = strconv.Atoi(v.Type.Length.String())
			if err != nil {
				b.handleErr(err)
			}
		}

		typeScale := 0
		if v.Type.Scale != nil {
			// TODO move to vitess
			typeScale, err = strconv.Atoi(v.Type.Scale.String())
			if err != nil {
				b.handleErr(err)
			}
		}
		expr := b.buildScalar(inScope, v.Expr)
		ret, err := b.f.buildConvert(expr, v.Type.Type, typeLength, typeScale)
		if err != nil {
			b.handleErr(err)
		}
		return ret
	case ast.InjectedExpr:
		// Children are built here and handed to the injected expression, which
		// must resolve to a sql.Expression.
		resolvedChildren := make([]any, len(v.Children))
		for i, child := range v.Children {
			resolvedChildren[i] = b.buildScalar(inScope, child)
		}
		expr, err := v.Expression.WithResolvedChildren(resolvedChildren)
		if err != nil {
			b.handleErr(err)
		}
		if sqlExpr, ok := expr.(sql.Expression); ok {
			return sqlExpr
		}
		b.handleErr(fmt.Errorf("Injected expression does not resolve to a valid expression"))
		return nil
	case *ast.RangeCond:
		// [NOT] BETWEEN. NOTE(review): any other operator returns nil without an error.
		val := b.buildScalar(inScope, v.Left)
		lower := b.buildScalar(inScope, v.From)
		upper := b.buildScalar(inScope, v.To)

		switch strings.ToLower(v.Operator) {
		case ast.BetweenStr:
			return expression.NewBetween(val, lower, upper)
		case ast.NotBetweenStr:
			return expression.NewNot(expression.NewBetween(val, lower, upper))
		default:
			return nil
		}
	case ast.ValTuple:
		var exprs = make([]sql.Expression, len(v))
		for i, e := range v {
			expr := b.buildScalar(inScope, e)
			exprs[i] = expr
		}
		return expression.NewTuple(exprs...)

	case *ast.BinaryExpr:
		return b.buildBinaryScalar(inScope, v)
	case *ast.UnaryExpr:
		return b.buildUnaryScalar(inScope, v)
	case *ast.Subquery:
		// Subqueries get their own scope. They are correlated if that scope
		// referenced outer columns, and volatile while a trigger is active.
		sqScope := inScope.pushSubquery()
		selectString := ast.String(v.Select)
		selScope := b.buildSelectStmt(sqScope, v.Select)
		// TODO: get the original select statement, not the reconstruction
		sq := plan.NewSubquery(selScope.node, selectString)
		sq = sq.WithCorrelated(sqScope.correlated())
		if b.TriggerCtx().Active {
			sq = sq.WithVolatile()
		}
		return sq
	case *ast.CaseExpr:
		return b.buildCaseExpr(inScope, v)
	case *ast.IntervalExpr:
		e := b.buildScalar(inScope, v.Expr)
		return expression.NewInterval(e, v.Unit)
	case *ast.CollateExpr:
		// handleCollateExpr is meant to handle generic text-returning expressions that should be reinterpreted as a different collation.
		innerExpr := b.buildScalar(inScope, v.Expr)
		//TODO: rename this from Charset to Collation
		collation, err := sql.ParseCollation(nil, &v.Charset, false)
		if err != nil {
			b.handleErr(err)
		}
		// If we're collating a string literal, we check that the charset and collation match now. Other string sources
		// (such as from tables) will have their own charset, which we won't know until after the parsing stage.
		charSet := b.ctx.GetCharacterSet()
		if _, isLiteral := innerExpr.(*expression.Literal); isLiteral && collation.CharacterSet() != charSet {
			b.handleErr(sql.ErrCollationInvalidForCharSet.New(collation.Name(), charSet.Name()))
		}
		return expression.NewCollatedExpression(innerExpr, collation)
	case *ast.ValuesFuncExpr:
		// While an insert is being built, VALUES(col) resolves to the ON DUPLICATE
		// KEY ... VALUES() column. An unqualified name is looked up under the
		// OnDupValuesPrefix table. Otherwise it is built as a call to the
		// catalog's "values" function.
		if b.insertActive {
			if v.Name.Qualifier.Name.String() == "" {
				v.Name.Qualifier.Name = ast.NewTableIdent(OnDupValuesPrefix)
			}
			dbName := strings.ToLower(v.Name.Qualifier.Qualifier.String())
			tblName := strings.ToLower(v.Name.Qualifier.Name.String())
			colName := strings.ToLower(v.Name.Name.String())
			col, ok := inScope.resolveColumn(dbName, tblName, colName, false, false)
			if !ok {
				err := fmt.Errorf("expected ON DUPLICATE KEY ... VALUES() to reference a column, found: %s", v.Name.String())
				b.handleErr(err)
			}
			return col.scalarGf()
		} else {
			col := b.buildScalar(inScope, v.Name)
			fn, err := b.cat.Function(b.ctx, "values")
			if err != nil {
				b.handleErr(err)
			}
			values, err := fn.NewInstance([]sql.Expression{col})
			if err != nil {
				b.handleErr(err)
			}
			return values
		}
	case *ast.ExistsExpr:
		sqScope := inScope.push()
		sqScope.initSubquery()
		selScope := b.buildSelectStmt(sqScope, v.Subquery.Select)
		selectString := ast.String(v.Subquery.Select)
		sq := plan.NewSubquery(selScope.node, selectString)
		sq = sq.WithCorrelated(sqScope.correlated())
		return plan.NewExistsSubquery(sq)
	case *ast.TimestampFuncExpr:
		var (
			unit  sql.Expression
			expr1 sql.Expression
			expr2 sql.Expression
		)

		unit = expression.NewLiteral(v.Unit, types.LongText)
		expr1 = b.buildScalar(inScope, v.Expr1)
		expr2 = b.buildScalar(inScope, v.Expr2)

		// NOTE(review): only timestampdiff is built here. timestampadd and any
		// other name return nil without reporting an error.
		if v.Name == "timestampdiff" {
			return function.NewTimestampDiff(unit, expr1, expr2)
		} else if v.Name == "timestampadd" {
			return nil
		}
		return nil
	case *ast.ExtractFuncExpr:
		// The unit is upper-cased and passed as a LONGTEXT literal.
		var unit sql.Expression = expression.NewLiteral(strings.ToUpper(v.Unit), types.LongText)
		expr := b.buildScalar(inScope, v.Expr)
		return function.NewExtract(unit, expr)
	case *ast.MatchExpr:
		return b.buildMatchAgainst(inScope, v)
	default:
		b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e)))
	}
	return nil
}
   377  
   378  // getJsonValueTypeLiteral converts a type coercion string into a literal
   379  // expression with the zero type of the coercion (see json_value function).
   380  func (b *Builder) getJsonValueTypeLiteral(e sql.Expression) sql.Expression {
   381  	typLit, ok := e.(*expression.Literal)
   382  	if !ok {
   383  		err := fmt.Errorf("invalid json_value coercion type: %s", e)
   384  		b.handleErr(err)
   385  	}
   386  	convStr, _, err := types.LongText.Convert(typLit.Value())
   387  	if err != nil {
   388  		err := fmt.Errorf("invalid json_value coercion type: %s; %s", typLit.Value(), err.Error())
   389  		b.handleErr(err)
   390  	}
   391  	var typ sql.Type
   392  	switch strings.ToLower(convStr.(string)) {
   393  	case "float":
   394  		typ = types.Float32
   395  	case "double", "decimal":
   396  		typ = types.Float64
   397  	case "signed":
   398  		typ = types.Int64
   399  	case "unsigned":
   400  		typ = types.Uint64
   401  	case "char":
   402  		typ = types.Text
   403  	case "json":
   404  		typ = types.JSON
   405  	case "time":
   406  		typ = types.Time
   407  	case "datetime":
   408  		typ = types.Datetime
   409  	case "date":
   410  		typ = types.Date
   411  	case "year":
   412  		typ = types.Year
   413  	default:
   414  		err := fmt.Errorf("invalid type for json_value: %s", convStr)
   415  		b.handleErr(err)
   416  	}
   417  	return expression.NewLiteral(typ.Zero(), typ)
   418  }
   419  
   420  func (b *Builder) buildUnaryScalar(inScope *scope, e *ast.UnaryExpr) sql.Expression {
   421  	switch strings.ToLower(e.Operator) {
   422  	case ast.MinusStr:
   423  		expr := b.buildScalar(inScope, e.Expr)
   424  		return expression.NewUnaryMinus(expr)
   425  	case ast.PlusStr:
   426  		// Unary plus expressions do nothing (do not turn the expression positive). Just return the underlying expressio return b.buildScalar(inScope, e.Expr)
   427  		return b.buildScalar(inScope, e.Expr)
   428  	case ast.BangStr:
   429  		c := b.buildScalar(inScope, e.Expr)
   430  		return expression.NewNot(c)
   431  	case ast.BinaryStr:
   432  		c := b.buildScalar(inScope, e.Expr)
   433  		return expression.NewBinary(c)
   434  	default:
   435  		lowerOperator := strings.TrimSpace(strings.ToLower(e.Operator))
   436  		if strings.HasPrefix(lowerOperator, "_") {
   437  			// This is a character set introducer, so we need to decode the string to our internal encoding (`utf8mb4`)
   438  			charSet, err := sql.ParseCharacterSet(lowerOperator[1:])
   439  			if err != nil {
   440  				b.handleErr(err)
   441  			}
   442  			if charSet.Encoder() == nil {
   443  				err := sql.ErrUnsupportedFeature.New("unsupported character set: " + charSet.Name())
   444  				b.handleErr(err)
   445  			}
   446  
   447  			// Due to how vitess orders expressions, COLLATE is a child rather than a parent, so we need to handle it in a special way
   448  			collation := charSet.DefaultCollation()
   449  			if collateExpr, ok := e.Expr.(*ast.CollateExpr); ok {
   450  				// We extract the expression out of CollateExpr as we're only concerned about the collation string
   451  				e.Expr = collateExpr.Expr
   452  				// TODO: rename this from Charset to Collation
   453  				collation, err = sql.ParseCollation(nil, &collateExpr.Charset, false)
   454  				if err != nil {
   455  					b.handleErr(err)
   456  				}
   457  				if collation.CharacterSet() != charSet {
   458  					err := sql.ErrCollationInvalidForCharSet.New(collation.Name(), charSet.Name())
   459  					b.handleErr(err)
   460  				}
   461  			}
   462  
   463  			// Character set introducers only work on string literals
   464  			expr := b.buildScalar(inScope, e.Expr)
   465  			if _, ok := expr.(*expression.Literal); !ok || !types.IsText(expr.Type()) {
   466  				err := sql.ErrCharSetIntroducer.New()
   467  				b.handleErr(err)
   468  			}
   469  			literal, _ := expr.Eval(b.ctx, nil)
   470  
   471  			// Internally all strings are `utf8mb4`, so we need to decode the string (which applies the introducer)
   472  			if strLiteral, ok := literal.(string); ok {
   473  				decodedLiteral, ok := charSet.Encoder().Decode(encodings.StringToBytes(strLiteral))
   474  				if !ok {
   475  					err := sql.ErrCharSetInvalidString.New(charSet.Name(), strLiteral)
   476  					b.handleErr(err)
   477  				}
   478  				return expression.NewLiteral(encodings.BytesToString(decodedLiteral), types.CreateLongText(collation))
   479  			} else if byteLiteral, ok := literal.([]byte); ok {
   480  				decodedLiteral, ok := charSet.Encoder().Decode(byteLiteral)
   481  				if !ok {
   482  					err := sql.ErrCharSetInvalidString.New(charSet.Name(), strLiteral)
   483  					b.handleErr(err)
   484  				}
   485  				return expression.NewLiteral(decodedLiteral, types.CreateLongText(collation))
   486  			} else {
   487  				// Should not be possible
   488  				err := fmt.Errorf("expression literal returned type `%s` but literal value had type `%T`",
   489  					expr.Type().String(), literal)
   490  				b.handleErr(err)
   491  			}
   492  		}
   493  		err := sql.ErrUnsupportedFeature.New("unary operator: " + e.Operator)
   494  		b.handleErr(err)
   495  	}
   496  	return nil
   497  }
   498  
   499  func (b *Builder) buildBinaryScalar(inScope *scope, be *ast.BinaryExpr) sql.Expression {
   500  	expr, err := b.binaryExprToExpression(inScope, be)
   501  	if err != nil {
   502  		b.handleErr(err)
   503  	}
   504  	return expr
   505  }
   506  
   507  func (b *Builder) buildComparison(inScope *scope, c *ast.ComparisonExpr) sql.Expression {
   508  	left := b.buildScalar(inScope, c.Left)
   509  	right := b.buildScalar(inScope, c.Right)
   510  
   511  	var escape sql.Expression = nil
   512  	if c.Escape != nil {
   513  		escape = b.buildScalar(inScope, c.Escape)
   514  	}
   515  
   516  	switch strings.ToLower(c.Operator) {
   517  	case ast.RegexpStr:
   518  		return expression.NewRegexp(left, right)
   519  	case ast.NotRegexpStr:
   520  		return expression.NewNot(expression.NewRegexp(left, right))
   521  	case ast.EqualStr:
   522  		return expression.NewEquals(left, right)
   523  	case ast.LessThanStr:
   524  		return expression.NewLessThan(left, right)
   525  	case ast.LessEqualStr:
   526  		return expression.NewLessThanOrEqual(left, right)
   527  	case ast.GreaterThanStr:
   528  		return expression.NewGreaterThan(left, right)
   529  	case ast.GreaterEqualStr:
   530  		return expression.NewGreaterThanOrEqual(left, right)
   531  	case ast.NullSafeEqualStr:
   532  		return expression.NewNullSafeEquals(left, right)
   533  	case ast.NotEqualStr:
   534  		return expression.NewNot(
   535  			expression.NewEquals(left, right),
   536  		)
   537  	case ast.InStr:
   538  		switch right.(type) {
   539  		case expression.Tuple:
   540  			return expression.NewInTuple(left, right)
   541  		case *plan.Subquery:
   542  			return plan.NewInSubquery(left, right)
   543  		default:
   544  			err := sql.ErrUnsupportedFeature.New(fmt.Sprintf("IN %T", right))
   545  			b.handleErr(err)
   546  		}
   547  	case ast.NotInStr:
   548  		switch right.(type) {
   549  		case expression.Tuple:
   550  			return expression.NewNotInTuple(left, right)
   551  		case *plan.Subquery:
   552  			return plan.NewNotInSubquery(left, right)
   553  		default:
   554  			err := sql.ErrUnsupportedFeature.New(fmt.Sprintf("NOT IN %T", right))
   555  			b.handleErr(err)
   556  		}
   557  	case ast.LikeStr:
   558  		return expression.NewLike(left, right, escape)
   559  	case ast.NotLikeStr:
   560  		return expression.NewNot(expression.NewLike(left, right, escape))
   561  	default:
   562  		err := sql.ErrUnsupportedFeature.New(c.Operator)
   563  		b.handleErr(err)
   564  	}
   565  	return nil
   566  }
   567  
   568  func hasColumnType(e sql.Expression) (sql.Type, bool) {
   569  	var typ sql.Type
   570  	sql.Inspect(e, func(e sql.Expression) bool {
   571  		if col, ok := e.(*expression.GetField); ok {
   572  			typ = col.Type()
   573  			return false
   574  		}
   575  		return true
   576  	})
   577  	return typ, typ != nil
   578  }
   579  
   580  func (b *Builder) buildCaseExpr(inScope *scope, e *ast.CaseExpr) sql.Expression {
   581  	expr, err := b.caseExprToExpression(inScope, e)
   582  	if err != nil {
   583  		b.handleErr(err)
   584  	}
   585  	return expr
   586  }
   587  
   588  func (b *Builder) buildIsExprToExpression(inScope *scope, c *ast.IsExpr) sql.Expression {
   589  	e := b.buildScalar(inScope, c.Expr)
   590  	switch strings.ToLower(c.Operator) {
   591  	case ast.IsNullStr:
   592  		return expression.NewIsNull(e)
   593  	case ast.IsNotNullStr:
   594  		return expression.NewNot(expression.NewIsNull(e))
   595  	case ast.IsTrueStr:
   596  		return expression.NewIsTrue(e)
   597  	case ast.IsFalseStr:
   598  		return expression.NewIsFalse(e)
   599  	case ast.IsNotTrueStr:
   600  		return expression.NewNot(expression.NewIsTrue(e))
   601  	case ast.IsNotFalseStr:
   602  		return expression.NewNot(expression.NewIsFalse(e))
   603  	default:
   604  		err := sql.ErrUnsupportedSyntax.New(ast.String(c))
   605  		b.handleErr(err)
   606  	}
   607  	return nil
   608  }
   609  
   610  func (b *Builder) binaryExprToExpression(inScope *scope, be *ast.BinaryExpr) (sql.Expression, error) {
   611  	l := b.buildScalar(inScope, be.Left)
   612  	r := b.buildScalar(inScope, be.Right)
   613  
   614  	operator := strings.ToLower(be.Operator)
   615  	switch operator {
   616  	case
   617  		ast.PlusStr,
   618  		ast.MinusStr,
   619  		ast.MultStr,
   620  		ast.DivStr,
   621  		ast.ShiftLeftStr,
   622  		ast.ShiftRightStr,
   623  		ast.BitAndStr,
   624  		ast.BitOrStr,
   625  		ast.BitXorStr,
   626  		ast.IntDivStr,
   627  		ast.ModStr:
   628  
   629  		_, lok := l.(*expression.Interval)
   630  		_, rok := r.(*expression.Interval)
   631  		if lok && be.Operator == "-" {
   632  			return nil, sql.ErrUnsupportedSyntax.New("subtracting from an interval")
   633  		} else if (lok || rok) && be.Operator != "+" && be.Operator != "-" {
   634  			return nil, sql.ErrUnsupportedSyntax.New("only + and - can be used to add or subtract intervals from dates")
   635  		} else if lok && rok {
   636  			return nil, sql.ErrUnsupportedSyntax.New("intervals cannot be added or subtracted from other intervals")
   637  		}
   638  
   639  		switch operator {
   640  		case ast.DivStr:
   641  			return expression.NewDiv(l, r), nil
   642  		case ast.ModStr:
   643  			return expression.NewMod(l, r), nil
   644  		case ast.BitAndStr, ast.BitOrStr, ast.BitXorStr, ast.ShiftRightStr, ast.ShiftLeftStr:
   645  			return expression.NewBitOp(l, r, be.Operator), nil
   646  		case ast.IntDivStr:
   647  			return expression.NewIntDiv(l, r), nil
   648  		case ast.MultStr:
   649  			return expression.NewMult(l, r), nil
   650  		case ast.PlusStr:
   651  			return expression.NewPlus(l, r), nil
   652  		case ast.MinusStr:
   653  			return expression.NewMinus(l, r), nil
   654  		default:
   655  			return nil, sql.ErrUnsupportedSyntax.New("unsupported operator: %s", be.Operator)
   656  		}
   657  
   658  	case ast.JSONExtractOp, ast.JSONUnquoteExtractOp:
   659  		jsonExtract, err := json.NewJSONExtract(l, r)
   660  		if err != nil {
   661  			return nil, err
   662  		}
   663  
   664  		if operator == ast.JSONUnquoteExtractOp {
   665  			return json.NewJSONUnquote(jsonExtract), nil
   666  		}
   667  		return jsonExtract, nil
   668  
   669  	default:
   670  		return nil, sql.ErrUnsupportedFeature.New(be.Operator)
   671  	}
   672  }
   673  
   674  func (b *Builder) caseExprToExpression(inScope *scope, e *ast.CaseExpr) (sql.Expression, error) {
   675  	var expr sql.Expression
   676  
   677  	if e.Expr != nil {
   678  		expr = b.buildScalar(inScope, e.Expr)
   679  	}
   680  
   681  	var branches []expression.CaseBranch
   682  	for _, w := range e.Whens {
   683  		var cond sql.Expression
   684  		cond = b.buildScalar(inScope, w.Cond)
   685  
   686  		var val sql.Expression
   687  		val = b.buildScalar(inScope, w.Val)
   688  
   689  		branches = append(branches, expression.CaseBranch{
   690  			Cond:  cond,
   691  			Value: val,
   692  		})
   693  	}
   694  
   695  	var elseExpr sql.Expression
   696  	if e.Else != nil {
   697  		elseExpr = b.buildScalar(inScope, e.Else)
   698  	}
   699  
   700  	return expression.NewCase(expr, branches, elseExpr), nil
   701  }
   702  
   703  func (b *Builder) intervalExprToExpression(inScope *scope, e *ast.IntervalExpr) *expression.Interval {
   704  	expr := b.buildScalar(inScope, e.Expr)
   705  	return expression.NewInterval(expr, e.Unit)
   706  }
   707  
   708  // Convert an integer, represented by the specified string in the specified
   709  // base, to its smallest representation possible, out of:
   710  // int8, uint8, int16, uint16, int32, uint32, int64 and uint64
   711  func (b *Builder) convertInt(value string, base int) *expression.Literal {
   712  	if i8, err := strconv.ParseInt(value, base, 8); err == nil {
   713  		return expression.NewLiteral(int8(i8), types.Int8)
   714  	}
   715  	if ui8, err := strconv.ParseUint(value, base, 8); err == nil {
   716  		return expression.NewLiteral(uint8(ui8), types.Uint8)
   717  	}
   718  	if i16, err := strconv.ParseInt(value, base, 16); err == nil {
   719  		return expression.NewLiteral(int16(i16), types.Int16)
   720  	}
   721  	if ui16, err := strconv.ParseUint(value, base, 16); err == nil {
   722  		return expression.NewLiteral(uint16(ui16), types.Uint16)
   723  	}
   724  	if i32, err := strconv.ParseInt(value, base, 32); err == nil {
   725  		return expression.NewLiteral(int32(i32), types.Int32)
   726  	}
   727  	if ui32, err := strconv.ParseUint(value, base, 32); err == nil {
   728  		return expression.NewLiteral(uint32(ui32), types.Uint32)
   729  	}
   730  	if i64, err := strconv.ParseInt(value, base, 64); err == nil {
   731  		return expression.NewLiteral(int64(i64), types.Int64)
   732  	}
   733  	if ui64, err := strconv.ParseUint(value, base, 64); err == nil {
   734  		return expression.NewLiteral(uint64(ui64), types.Uint64)
   735  	}
   736  	if decimal, _, err := types.InternalDecimalType.Convert(value); err == nil {
   737  		return expression.NewLiteral(decimal, types.InternalDecimalType)
   738  	}
   739  
   740  	b.handleErr(fmt.Errorf("could not convert %s to any numerical type", value))
   741  	return nil
   742  }
   743  
   744  func (b *Builder) ConvertVal(v *ast.SQLVal) sql.Expression {
   745  	switch v.Type {
   746  	case ast.StrVal:
   747  		return expression.NewLiteral(string(v.Val), types.CreateLongText(b.ctx.GetCollation()))
   748  	case ast.IntVal:
   749  		return b.convertInt(string(v.Val), 10)
   750  	case ast.FloatVal:
   751  		// any float value is parsed as decimal except when the value has scientific notation
   752  		ogVal := strings.ToLower(string(v.Val))
   753  		if strings.Contains(ogVal, "e") {
   754  			val, err := strconv.ParseFloat(string(v.Val), 64)
   755  			if err != nil {
   756  				b.handleErr(err)
   757  			}
   758  			return expression.NewLiteral(val, types.Float64)
   759  		}
   760  
   761  		// using DECIMAL data type avoids precision error of rounded up float64 value
   762  		if ps := strings.Split(string(v.Val), "."); len(ps) == 2 {
   763  			p, s := expression.GetDecimalPrecisionAndScale(ogVal)
   764  			dt, err := types.CreateDecimalType(p, s)
   765  			if err != nil {
   766  				return expression.NewLiteral(string(v.Val), types.CreateLongText(b.ctx.GetCollation()))
   767  			}
   768  			dVal, _, err := dt.Convert(ogVal)
   769  			if err != nil {
   770  				return expression.NewLiteral(string(v.Val), types.CreateLongText(b.ctx.GetCollation()))
   771  			}
   772  			return expression.NewLiteral(dVal, dt)
   773  		} else {
   774  			// if the value is not float type - this should not happen
   775  			return b.convertInt(string(v.Val), 10)
   776  		}
   777  	case ast.HexNum:
   778  		//TODO: binary collation?
   779  		v := strings.ToLower(string(v.Val))
   780  		if strings.HasPrefix(v, "0x") {
   781  			v = v[2:]
   782  		} else if strings.HasPrefix(v, "x") {
   783  			v = strings.Trim(v[1:], "'")
   784  		}
   785  
   786  		// pad string to even length
   787  		if len(v)%2 == 1 {
   788  			v = "0" + v
   789  		}
   790  
   791  		val, err := hex.DecodeString(v)
   792  		if err != nil {
   793  			b.handleErr(err)
   794  		}
   795  		return expression.NewLiteral(val, types.LongBlob)
   796  	case ast.HexVal:
   797  		//TODO: binary collation?
   798  		val, err := v.HexDecode()
   799  		if err != nil {
   800  			b.handleErr(err)
   801  		}
   802  		return expression.NewLiteral(val, types.LongBlob)
   803  	case ast.ValArg:
   804  		name := strings.TrimPrefix(string(v.Val), ":")
   805  		if b.bindCtx != nil {
   806  			if b.bindCtx.resolveOnly {
   807  				return expression.NewBindVar(name)
   808  			}
   809  			replacement := b.normalizeValArg(v)
   810  			return b.buildScalar(&scope{}, replacement)
   811  		}
   812  		return expression.NewBindVar(name)
   813  	case ast.BitVal:
   814  		if len(v.Val) == 0 {
   815  			return expression.NewLiteral(0, types.Uint64)
   816  		}
   817  
   818  		res, err := strconv.ParseUint(string(v.Val), 2, 64)
   819  		if err != nil {
   820  			b.handleErr(err)
   821  		}
   822  
   823  		return expression.NewLiteral(res, types.Uint64)
   824  	}
   825  
   826  	b.handleErr(sql.ErrInvalidSQLValType.New(v.Type))
   827  	return nil
   828  }
   829  
   830  // processMatchAgainst returns a new MatchAgainst expression that has had
   831  // all of its tables filled in. This essentially grabs the appropriate index
   832  // (if it hasn't already been grabbed), and then loads the appropriate
   833  // tables that are referenced by the index. The returned expression contains
   834  // everything needed to calculate relevancy.
   835  //
   836  // A fully resolved MatchAgainst expression is also used by the index
   837  // filter, since we only need to load the tables once. All steps after this
   838  // one can assume that the expression has been fully resolved and is valid.
   839  func (b *Builder) buildMatchAgainst(inScope *scope, v *ast.MatchExpr) *expression.MatchAgainst {
   840  	//TODO: implement proper scope support and remove this check
   841  	if (inScope.groupBy != nil && inScope.groupBy.hasAggs()) || inScope.windowFuncs != nil {
   842  		b.handleErr(fmt.Errorf("aggregate and window functions are not yet supported alongside MATCH expressions"))
   843  	}
   844  	rts := getTablesByName(inScope.node)
   845  	var rt *plan.ResolvedTable
   846  	var matchTable string
   847  	cols := make([]*expression.GetField, len(v.Columns))
   848  	for i, selectExpr := range v.Columns {
   849  		expr := b.selectExprToExpression(inScope, selectExpr)
   850  		gf, ok := expr.(*expression.GetField)
   851  		if !ok {
   852  			err := sql.ErrFullTextMatchAgainstNotColumns.New()
   853  			b.handleErr(err)
   854  		}
   855  		if rt == nil {
   856  			matchTable = strings.ToLower(gf.Table())
   857  			rt, ok = rts[matchTable]
   858  			if !ok {
   859  				// shouldn't be able to resolve expression without table being available
   860  				panic("shouldn't be able to resolve expression without table being available")
   861  			}
   862  		} else if !strings.EqualFold(matchTable, gf.Table()) {
   863  			err := sql.ErrFullTextMatchAgainstSameTable.New()
   864  			b.handleErr(err)
   865  		}
   866  		cols[i] = gf
   867  	}
   868  	matchExpr := b.buildScalar(inScope, v.Expr)
   869  	var searchModifier fulltext.SearchModifier
   870  	var err error
   871  	switch v.Option {
   872  	case ast.NaturalLanguageModeStr, "":
   873  		searchModifier = fulltext.SearchModifier_NaturalLanguage
   874  	case ast.NaturalLanguageModeWithQueryExpansionStr:
   875  		searchModifier = fulltext.SearchModifier_NaturalLangaugeQueryExpansion
   876  		err = fmt.Errorf(`"IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION" is not supported yet`)
   877  	case ast.BooleanModeStr:
   878  		searchModifier = fulltext.SearchModifier_Boolean
   879  		err = fmt.Errorf(`"IN BOOLEAN MODE" is not supported yet`)
   880  	case ast.QueryExpansionStr:
   881  		searchModifier = fulltext.SearchModifier_QueryExpansion
   882  		err = fmt.Errorf(`"WITH QUERY EXPANSION" is not supported yet`)
   883  	default:
   884  		err = sql.ErrUnsupportedFeature.New(v.Option)
   885  	}
   886  	if err != nil {
   887  		b.handleErr(err)
   888  	}
   889  
   890  	innerTbl := rt.UnderlyingTable()
   891  	indexedTbl, ok := innerTbl.(sql.IndexAddressableTable)
   892  	if !ok {
   893  		err := fmt.Errorf("cannot use MATCH ... AGAINST ... on a table that does not declare indexes")
   894  		b.handleErr(err)
   895  	}
   896  
   897  	indexes, err := indexedTbl.GetIndexes(b.ctx)
   898  	if err != nil {
   899  		b.handleErr(err)
   900  	}
   901  	ftIndex := findMatchAgainstIndex(cols, indexes)
   902  	if ftIndex == nil {
   903  		err := sql.ErrNoFullTextIndexFound.New(indexedTbl.Name())
   904  		b.handleErr(err)
   905  	}
   906  
   907  	// Get the key columns
   908  	keyCols, err := ftIndex.FullTextKeyColumns(b.ctx)
   909  	if err != nil {
   910  		b.handleErr(err)
   911  	}
   912  
   913  	genericCols := make([]sql.Expression, len(cols))
   914  	for i, e := range cols {
   915  		genericCols[i] = e
   916  	}
   917  
   918  	// Grab the pseudo-index table names
   919  	tableNames, err := ftIndex.FullTextTableNames(b.ctx)
   920  	if err != nil {
   921  		b.handleErr(err)
   922  	}
   923  
   924  	fullindexTableNames := [5]string{tableNames.Config, tableNames.Position, tableNames.DocCount, tableNames.GlobalCount, tableNames.RowCount}
   925  	idxTables := make([]sql.IndexAddressableTable, 5)
   926  	for i, name := range fullindexTableNames {
   927  		configTbl, ok, err := rt.SqlDatabase.GetTableInsensitive(b.ctx, name)
   928  		if err != nil {
   929  			b.handleErr(err)
   930  		}
   931  		if !ok {
   932  			err := fmt.Errorf("Full-Text index `%s` on table `%s` is linked to table `%s` which could not be found",
   933  				ftIndex.ID(), indexedTbl.Name(), tableNames.Config)
   934  			b.handleErr(err)
   935  		}
   936  		idxTables[i], ok = configTbl.(sql.IndexAddressableTable)
   937  		if !ok {
   938  			err := fmt.Errorf("Full-Text index `%s` on table `%s` requires table `%s` to implement sql.IndexAddressableTable",
   939  				ftIndex.ID(), indexedTbl.Name(), tableNames.Config)
   940  			b.handleErr(err)
   941  		}
   942  	}
   943  
   944  	matchAgainst := expression.NewMatchAgainst(genericCols, matchExpr, searchModifier)
   945  	matchAgainst.SetIndex(ftIndex)
   946  
   947  	return matchAgainst.WithInfo(indexedTbl, idxTables[0], idxTables[1], idxTables[2], idxTables[3], idxTables[4], keyCols)
   948  }
   949  
   950  func findMatchAgainstIndex(cols []*expression.GetField, indexes []sql.Index) fulltext.Index {
   951  	var found fulltext.Index
   952  	for _, idx := range indexes {
   953  		idxExprs := idx.Expressions()
   954  		if !idx.IsFullText() || len(cols) != len(idxExprs) {
   955  			continue
   956  		}
   957  		// check that index expressions match |cols|
   958  		allMatch := true
   959  		for _, gf := range cols {
   960  			var match bool
   961  			for _, idxExpr := range idxExprs {
   962  				if gf.String() == idxExpr {
   963  					match = true
   964  					break
   965  				}
   966  			}
   967  			if !match {
   968  				allMatch = false
   969  				break
   970  			}
   971  		}
   972  		if !allMatch {
   973  			continue
   974  		}
   975  		var ok bool
   976  		found, ok = idx.(fulltext.Index)
   977  		if ok {
   978  			break
   979  		}
   980  	}
   981  	return found
   982  }