github.com/vescale/zgraph@v0.0.0-20230410094002-959c02d50f95/expression/binary_op.go (about)

     1  // Copyright 2023 zGraph Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"bytes"
    19  	"errors"
    20  	"fmt"
    21  	"math"
    22  
    23  	"github.com/cockroachdb/apd/v3"
    24  	"github.com/vescale/zgraph/datum"
    25  	"github.com/vescale/zgraph/parser/opcode"
    26  	"github.com/vescale/zgraph/stmtctx"
    27  	"github.com/vescale/zgraph/types"
    28  )
    29  
    30  var _ Expression = &BinaryExpr{}
    31  
    32  type BinaryExpr struct {
    33  	Op     opcode.Op
    34  	Left   Expression
    35  	Right  Expression
    36  	EvalOp BinaryEvalOp
    37  }
    38  
    39  func (expr *BinaryExpr) String() string {
    40  	return fmt.Sprintf("%s %s %s", expr.Left, expr.Op, expr.Right)
    41  }
    42  
    43  func (expr *BinaryExpr) ReturnType() types.T {
    44  	leftType := expr.Left.ReturnType()
    45  	rightType := expr.Right.ReturnType()
    46  	return expr.EvalOp.InferReturnType(leftType, rightType)
    47  }
    48  
    49  func (expr *BinaryExpr) Eval(stmtCtx *stmtctx.Context, input datum.Row) (datum.Datum, error) {
    50  	left, err := expr.Left.Eval(stmtCtx, input)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  	if left == datum.Null && !expr.EvalOp.CallOnNullInput() {
    55  		return datum.Null, nil
    56  	}
    57  	right, err := expr.Right.Eval(stmtCtx, input)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	if right == datum.Null && !expr.EvalOp.CallOnNullInput() {
    62  		return datum.Null, nil
    63  	}
    64  	return expr.EvalOp.Eval(stmtCtx, left, right)
    65  }
    66  
    67  func NewBinaryExpr(op opcode.Op, left, right Expression) (*BinaryExpr, error) {
    68  	binOp, ok := binOps[op]
    69  	if !ok {
    70  		return nil, fmt.Errorf("unsupported binary operator: %s", op)
    71  	}
    72  	return &BinaryExpr{
    73  		Op:     op,
    74  		Left:   left,
    75  		Right:  right,
    76  		EvalOp: binOp,
    77  	}, nil
    78  }
    79  
    80  type BinaryEvalOp interface {
    81  	InferReturnType(leftType, rightType types.T) types.T
    82  	CallOnNullInput() bool
    83  	Eval(stmtCtx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error)
    84  }
    85  
    86  var binOps = map[opcode.Op]BinaryEvalOp{
    87  	opcode.Plus:     makeArithOp(opcode.Plus),
    88  	opcode.Minus:    makeArithOp(opcode.Minus),
    89  	opcode.Mul:      makeArithOp(opcode.Mul),
    90  	opcode.Div:      makeArithOp(opcode.Div),
    91  	opcode.Mod:      makeArithOp(opcode.Mod),
    92  	opcode.LogicAnd: logicalAndOp{},
    93  	opcode.LogicOr:  logicalOrOp{},
    94  	opcode.EQ:       makeCmpOp(opcode.EQ),
    95  	opcode.NE:       makeNegateCmpOp(opcode.EQ), // NE(left, right) is implemented as !EQ(left, right)
    96  	opcode.LT:       makeCmpOp(opcode.LT),
    97  	opcode.LE:       makeFlippedNegateCmpOp(opcode.LT), // LE(left, right) is implemented as !LT(right, left)
    98  	opcode.GE:       makeNegateCmpOp(opcode.LT),        // GE(left, right) is implemented as !LT(left, right)
    99  	opcode.GT:       makeFlippedCmpOp(opcode.LT),       // GT(left, right) is implemented as LT(right, left)
   100  }
   101  
   102  type typePair struct {
   103  	left  types.T
   104  	right types.T
   105  }
   106  
   107  type binEvalFunc func(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error)
   108  
   109  func makeBinEvalFuncWithLeftCast(eval binEvalFunc, cast castFunc) binEvalFunc {
   110  	return func(stmtCtx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   111  		left, err := cast(stmtCtx, left)
   112  		if err != nil {
   113  			return nil, err
   114  		}
   115  		return eval(stmtCtx, left, right)
   116  	}
   117  }
   118  
   119  func makeBinEvalFuncWithRightCast(eval binEvalFunc, cast castFunc) binEvalFunc {
   120  	return func(stmtCtx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   121  		right, err := cast(stmtCtx, right)
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  		return eval(stmtCtx, left, right)
   126  	}
   127  }
   128  
   129  var numericOpReturnTypes = map[typePair]types.T{
   130  	{types.Int, types.Int}:         types.Int,
   131  	{types.Int, types.Float}:       types.Float,
   132  	{types.Int, types.Decimal}:     types.Decimal,
   133  	{types.Float, types.Int}:       types.Float,
   134  	{types.Float, types.Float}:     types.Float,
   135  	{types.Float, types.Decimal}:   types.Decimal,
   136  	{types.Decimal, types.Int}:     types.Decimal,
   137  	{types.Decimal, types.Float}:   types.Decimal,
   138  	{types.Decimal, types.Decimal}: types.Decimal,
   139  }
   140  
   141  var arithOpReturnTypes = func() map[opcode.Op]map[typePair]types.T {
   142  	result := make(map[opcode.Op]map[typePair]types.T)
   143  	for _, op := range []opcode.Op{opcode.Plus, opcode.Minus, opcode.Mul, opcode.Div, opcode.Mod} {
   144  		result[op] = numericOpReturnTypes
   145  	}
   146  	for _, op := range []opcode.Op{opcode.Plus, opcode.Minus} {
   147  		result[op][typePair{types.Date, types.Interval}] = types.Date
   148  		result[op][typePair{types.Time, types.Interval}] = types.Time
   149  		result[op][typePair{types.TimeTZ, types.Interval}] = types.TimeTZ
   150  		result[op][typePair{types.Timestamp, types.Interval}] = types.Timestamp
   151  		result[op][typePair{types.TimestampTZ, types.Interval}] = types.TimestampTZ
   152  	}
   153  	return result
   154  }()
   155  
   156  var arithOpEvalFuncs = map[opcode.Op]map[typePair]binEvalFunc{
   157  	opcode.Plus: {
   158  		{types.Int, types.Int}:              plusInt,
   159  		{types.Int, types.Float}:            makeBinEvalFuncWithLeftCast(plusFloat, castIntAsFloat),
   160  		{types.Int, types.Decimal}:          makeBinEvalFuncWithLeftCast(plusDecimal, castIntAsDecimal),
   161  		{types.Float, types.Int}:            makeBinEvalFuncWithRightCast(plusFloat, castIntAsFloat),
   162  		{types.Float, types.Float}:          plusFloat,
   163  		{types.Float, types.Decimal}:        makeBinEvalFuncWithLeftCast(plusDecimal, castFloatAsDecimal),
   164  		{types.Decimal, types.Int}:          makeBinEvalFuncWithRightCast(plusDecimal, castIntAsDecimal),
   165  		{types.Decimal, types.Float}:        makeBinEvalFuncWithRightCast(plusDecimal, castFloatAsDecimal),
   166  		{types.Decimal, types.Decimal}:      plusDecimal,
   167  		{types.Date, types.Interval}:        plusDateInterval,
   168  		{types.Time, types.Interval}:        plusTimeInterval,
   169  		{types.TimeTZ, types.Interval}:      plusTimeTZInterval,
   170  		{types.Timestamp, types.Interval}:   plusTimestampInterval,
   171  		{types.TimestampTZ, types.Interval}: plusTimestampTZInterval,
   172  	},
   173  	opcode.Minus: {
   174  		{types.Int, types.Int}:              minusInt,
   175  		{types.Int, types.Float}:            makeBinEvalFuncWithLeftCast(minusFloat, castIntAsFloat),
   176  		{types.Int, types.Decimal}:          makeBinEvalFuncWithLeftCast(minusDecimal, castIntAsDecimal),
   177  		{types.Float, types.Int}:            makeBinEvalFuncWithRightCast(minusFloat, castIntAsFloat),
   178  		{types.Float, types.Float}:          minusFloat,
   179  		{types.Float, types.Decimal}:        makeBinEvalFuncWithLeftCast(minusDecimal, castFloatAsDecimal),
   180  		{types.Decimal, types.Int}:          makeBinEvalFuncWithRightCast(minusDecimal, castIntAsDecimal),
   181  		{types.Decimal, types.Float}:        makeBinEvalFuncWithRightCast(minusDecimal, castFloatAsDecimal),
   182  		{types.Decimal, types.Decimal}:      minusDecimal,
   183  		{types.Date, types.Interval}:        minusDateInterval,
   184  		{types.Time, types.Interval}:        minusTimeInterval,
   185  		{types.TimeTZ, types.Interval}:      minusTimeTZInterval,
   186  		{types.Timestamp, types.Interval}:   minusTimestampInterval,
   187  		{types.TimestampTZ, types.Interval}: minusTimestampTZInterval,
   188  	},
   189  	opcode.Mul: {
   190  		{types.Int, types.Int}:         mulInt,
   191  		{types.Int, types.Float}:       makeBinEvalFuncWithLeftCast(mulFloat, castIntAsFloat),
   192  		{types.Int, types.Decimal}:     makeBinEvalFuncWithLeftCast(mulDecimal, castIntAsDecimal),
   193  		{types.Float, types.Int}:       makeBinEvalFuncWithRightCast(mulFloat, castIntAsFloat),
   194  		{types.Float, types.Float}:     mulFloat,
   195  		{types.Float, types.Decimal}:   makeBinEvalFuncWithLeftCast(mulDecimal, castFloatAsDecimal),
   196  		{types.Decimal, types.Int}:     makeBinEvalFuncWithRightCast(mulDecimal, castIntAsDecimal),
   197  		{types.Decimal, types.Float}:   makeBinEvalFuncWithRightCast(mulDecimal, castFloatAsDecimal),
   198  		{types.Decimal, types.Decimal}: mulDecimal,
   199  	},
   200  	opcode.Div: {
   201  		{types.Int, types.Int}:         divInt,
   202  		{types.Int, types.Float}:       makeBinEvalFuncWithLeftCast(divFloat, castIntAsFloat),
   203  		{types.Int, types.Decimal}:     makeBinEvalFuncWithLeftCast(divDecimal, castIntAsDecimal),
   204  		{types.Float, types.Int}:       makeBinEvalFuncWithRightCast(divFloat, castIntAsFloat),
   205  		{types.Float, types.Float}:     divFloat,
   206  		{types.Float, types.Decimal}:   makeBinEvalFuncWithLeftCast(divDecimal, castFloatAsDecimal),
   207  		{types.Decimal, types.Int}:     makeBinEvalFuncWithRightCast(divDecimal, castIntAsDecimal),
   208  		{types.Decimal, types.Float}:   makeBinEvalFuncWithRightCast(divDecimal, castFloatAsDecimal),
   209  		{types.Decimal, types.Decimal}: divDecimal,
   210  	},
   211  	opcode.Mod: {
   212  		{types.Int, types.Int}:         modInt,
   213  		{types.Int, types.Float}:       makeBinEvalFuncWithLeftCast(modFloat, castIntAsFloat),
   214  		{types.Int, types.Decimal}:     makeBinEvalFuncWithLeftCast(modDecimal, castIntAsDecimal),
   215  		{types.Float, types.Int}:       makeBinEvalFuncWithRightCast(modFloat, castIntAsFloat),
   216  		{types.Float, types.Float}:     modFloat,
   217  		{types.Float, types.Decimal}:   makeBinEvalFuncWithLeftCast(modDecimal, castFloatAsDecimal),
   218  		{types.Decimal, types.Int}:     makeBinEvalFuncWithRightCast(modDecimal, castIntAsDecimal),
   219  		{types.Decimal, types.Float}:   makeBinEvalFuncWithRightCast(modDecimal, castFloatAsDecimal),
   220  		{types.Decimal, types.Decimal}: modDecimal,
   221  	},
   222  }
   223  
   224  type arithOp struct {
   225  	op          opcode.Op
   226  	returnTypes map[typePair]types.T
   227  	evalFuncs   map[typePair]binEvalFunc
   228  }
   229  
   230  func (op arithOp) InferReturnType(leftType, rightType types.T) types.T {
   231  	return op.returnTypes[typePair{leftType, rightType}]
   232  }
   233  
   234  func (op arithOp) CallOnNullInput() bool {
   235  	return false
   236  }
   237  
   238  func (op arithOp) Eval(ctx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   239  	evalFunc, ok := op.evalFuncs[typePair{left.Type(), right.Type()}]
   240  	if !ok {
   241  		return nil, fmt.Errorf("cannot evaluate %s on %s and %s", op.op, left.Type(), right.Type())
   242  	}
   243  	return evalFunc(ctx, left, right)
   244  }
   245  
   246  func makeArithOp(op opcode.Op) arithOp {
   247  	returnTypes := arithOpReturnTypes[op]
   248  	evalFuncs := arithOpEvalFuncs[op]
   249  	return arithOp{
   250  		op:          op,
   251  		returnTypes: returnTypes,
   252  		evalFuncs:   evalFuncs,
   253  	}
   254  }
   255  
   256  func plusInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   257  	l := datum.AsInt(left)
   258  	r := datum.AsInt(right)
   259  	return datum.NewInt(l + r), nil
   260  }
   261  
   262  func plusFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   263  	l := datum.AsFloat(left)
   264  	r := datum.AsFloat(right)
   265  	return datum.NewFloat(l + r), nil
   266  }
   267  
   268  func plusDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   269  	l := datum.AsDecimal(left)
   270  	r := datum.AsDecimal(right)
   271  	d := &apd.Decimal{}
   272  	_, err := apd.BaseContext.Add(d, l, r)
   273  	if err != nil {
   274  		return nil, err
   275  	}
   276  	return datum.NewDecimal(d), nil
   277  }
   278  
   279  func plusDateInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   280  	return nil, errors.New("plusDateInterval unimplemented")
   281  }
   282  
   283  func plusTimeInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   284  	return nil, errors.New("plusTimeInterval unimplemented")
   285  }
   286  
   287  func plusTimeTZInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   288  	return nil, errors.New("plusTimeTZInterval unimplemented")
   289  }
   290  
   291  func plusTimestampInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   292  	return nil, errors.New("plusTimestampInterval unimplemented")
   293  }
   294  
   295  func plusTimestampTZInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   296  	return nil, errors.New("plusTimestampTZInterval unimplemented")
   297  }
   298  
   299  func minusInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   300  	l := datum.AsInt(left)
   301  	r := datum.AsInt(right)
   302  	return datum.NewInt(l - r), nil
   303  }
   304  
   305  func minusFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   306  	l := datum.AsFloat(left)
   307  	r := datum.AsFloat(right)
   308  	return datum.NewFloat(l - r), nil
   309  }
   310  
   311  func minusDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   312  	l := datum.AsDecimal(left)
   313  	r := datum.AsDecimal(right)
   314  	d := &apd.Decimal{}
   315  	_, err := apd.BaseContext.Sub(d, l, r)
   316  	if err != nil {
   317  		return nil, err
   318  	}
   319  	return datum.NewDecimal(d), nil
   320  }
   321  
   322  func minusDateInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   323  	return nil, errors.New("minusDateInterval unimplemented")
   324  }
   325  
   326  func minusTimeInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   327  	return nil, errors.New("minusTimeInterval unimplemented")
   328  }
   329  
   330  func minusTimeTZInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   331  	return nil, errors.New("minusTimeTZInterval unimplemented")
   332  }
   333  
   334  func minusTimestampInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   335  	return nil, errors.New("minusTimestampInterval unimplemented")
   336  }
   337  
   338  func minusTimestampTZInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   339  	return nil, errors.New("minusTimestampTZInterval unimplemented")
   340  }
   341  
   342  func mulInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   343  	l := datum.AsInt(left)
   344  	r := datum.AsInt(right)
   345  	return datum.NewInt(l * r), nil
   346  }
   347  
   348  func mulFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   349  	l := datum.AsFloat(left)
   350  	r := datum.AsFloat(right)
   351  	return datum.NewFloat(l * r), nil
   352  }
   353  
   354  func mulDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   355  	l := datum.AsDecimal(left)
   356  	r := datum.AsDecimal(right)
   357  	d := &apd.Decimal{}
   358  	_, err := apd.BaseContext.Mul(d, l, r)
   359  	if err != nil {
   360  		return nil, err
   361  	}
   362  	return datum.NewDecimal(d), nil
   363  }
   364  
   365  func divInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   366  	l := datum.AsInt(left)
   367  	r := datum.AsInt(right)
   368  	if r == 0 {
   369  		return nil, errors.New("division by zero")
   370  	}
   371  	return datum.NewInt(l / r), nil
   372  }
   373  
   374  func divFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   375  	l := datum.AsFloat(left)
   376  	r := datum.AsFloat(right)
   377  	if r == 0 {
   378  		return nil, errors.New("division by zero")
   379  	}
   380  	return datum.NewFloat(l / r), nil
   381  }
   382  
   383  func divDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   384  	l := datum.AsDecimal(left)
   385  	r := datum.AsDecimal(right)
   386  	if r.IsZero() {
   387  		return nil, errors.New("division by zero")
   388  	}
   389  	d := &apd.Decimal{}
   390  	_, err := apd.BaseContext.Quo(d, l, r)
   391  	if err != nil {
   392  		return nil, err
   393  	}
   394  	return datum.NewDecimal(d), nil
   395  }
   396  
   397  func modInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   398  	l := datum.AsInt(left)
   399  	r := datum.AsInt(right)
   400  	if r == 0 {
   401  		return nil, errors.New("division by zero")
   402  	}
   403  	return datum.NewInt(l % r), nil
   404  }
   405  
   406  func modFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   407  	l := datum.AsFloat(left)
   408  	r := datum.AsFloat(right)
   409  	if r == 0 {
   410  		return nil, errors.New("division by zero")
   411  	}
   412  	return datum.NewFloat(math.Mod(l, r)), nil
   413  }
   414  
   415  func modDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   416  	l := datum.AsDecimal(left)
   417  	r := datum.AsDecimal(right)
   418  	if r.IsZero() {
   419  		return nil, errors.New("division by zero")
   420  	}
   421  	d := &apd.Decimal{}
   422  	_, err := apd.BaseContext.Rem(d, l, r)
   423  	if err != nil {
   424  		return nil, err
   425  	}
   426  	return datum.NewDecimal(d), nil
   427  }
   428  
   429  type logicalAndOp struct{}
   430  
   431  func (logicalAndOp) InferReturnType(_, _ types.T) types.T {
   432  	return types.Bool
   433  }
   434  
   435  func (logicalAndOp) CallOnNullInput() bool {
   436  	return true
   437  }
   438  
   439  func (logicalAndOp) Eval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   440  	leftBool, lerr := datum.TryAsBool(left)
   441  	rightBool, rerr := datum.TryAsBool(right)
   442  	if left == datum.Null {
   443  		if rerr == nil && !rightBool {
   444  			return datum.NewBool(false), nil
   445  		}
   446  		return datum.Null, nil
   447  	}
   448  	if right == datum.Null {
   449  		if lerr == nil && !leftBool {
   450  			return datum.NewBool(false), nil
   451  		}
   452  		return datum.Null, nil
   453  	}
   454  	return datum.NewBool(leftBool && rightBool), nil
   455  }
   456  
   457  type logicalOrOp struct{}
   458  
   459  func (logicalOrOp) InferReturnType(_, _ types.T) types.T {
   460  	return types.Bool
   461  }
   462  
   463  func (logicalOrOp) CallOnNullInput() bool {
   464  	return true
   465  }
   466  
   467  func (logicalOrOp) Eval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   468  	leftBool, lerr := datum.TryAsBool(left)
   469  	rightBool, rerr := datum.TryAsBool(right)
   470  	if left == datum.Null {
   471  		if rerr == nil && rightBool {
   472  			return datum.NewBool(true), nil
   473  		}
   474  		return datum.Null, nil
   475  	}
   476  	if right == datum.Null {
   477  		if lerr == nil && leftBool {
   478  			return datum.NewBool(true), nil
   479  		}
   480  		return datum.Null, nil
   481  	}
   482  	return datum.NewBool(leftBool || rightBool), nil
   483  }
   484  
   485  var cmpOpEvalFuncs = map[opcode.Op]map[typePair]binEvalFunc{
   486  	opcode.EQ: {
   487  		{types.Bool, types.Bool}:               cmpEqBool,
   488  		{types.Int, types.Int}:                 cmpEqInt,
   489  		{types.Int, types.Float}:               makeBinEvalFuncWithLeftCast(cmpEqFloat, castIntAsFloat),
   490  		{types.Int, types.Decimal}:             makeBinEvalFuncWithLeftCast(cmpEqDecimal, castIntAsDecimal),
   491  		{types.Float, types.Int}:               makeBinEvalFuncWithRightCast(cmpEqFloat, castIntAsFloat),
   492  		{types.Float, types.Float}:             cmpEqFloat,
   493  		{types.Float, types.Decimal}:           makeBinEvalFuncWithLeftCast(cmpEqDecimal, castFloatAsDecimal),
   494  		{types.Decimal, types.Int}:             makeBinEvalFuncWithRightCast(cmpEqDecimal, castIntAsDecimal),
   495  		{types.Decimal, types.Float}:           makeBinEvalFuncWithRightCast(cmpEqDecimal, castFloatAsDecimal),
   496  		{types.Decimal, types.Decimal}:         cmpEqDecimal,
   497  		{types.String, types.String}:           cmpEqString,
   498  		{types.String, types.Bytes}:            makeBinEvalFuncWithRightCast(cmpEqString, castBytesAsString),
   499  		{types.Bytes, types.String}:            makeBinEvalFuncWithLeftCast(cmpEqString, castBytesAsString),
   500  		{types.Bytes, types.Bytes}:             cmpEqBytes,
   501  		{types.Date, types.Date}:               cmpEqDate,
   502  		{types.Time, types.Time}:               cmpEqTime,
   503  		{types.Time, types.TimeTZ}:             makeBinEvalFuncWithLeftCast(cmpEqTimeTZ, castTimeAsTimeTZ),
   504  		{types.TimeTZ, types.Time}:             makeBinEvalFuncWithRightCast(cmpEqTimeTZ, castTimeAsTimeTZ),
   505  		{types.TimeTZ, types.TimeTZ}:           cmpEqTimeTZ,
   506  		{types.Timestamp, types.Timestamp}:     cmpEqTimestamp,
   507  		{types.Timestamp, types.TimestampTZ}:   makeBinEvalFuncWithLeftCast(cmpEqTimestampTZ, castTimestampAsTimestampTZ),
   508  		{types.TimestampTZ, types.Timestamp}:   makeBinEvalFuncWithRightCast(cmpEqTimestampTZ, castTimestampAsTimestampTZ),
   509  		{types.TimestampTZ, types.TimestampTZ}: cmpEqTimestampTZ,
   510  		{types.Vertex, types.Vertex}:           cmpEqVertex,
   511  		{types.Edge, types.Edge}:               cmpEqEdge,
   512  	},
   513  	opcode.LT: {
   514  		{types.Bool, types.Bool}:               cmpLtBool,
   515  		{types.Int, types.Int}:                 cmpLtInt,
   516  		{types.Int, types.Float}:               makeBinEvalFuncWithLeftCast(cmpLtFloat, castIntAsFloat),
   517  		{types.Int, types.Decimal}:             makeBinEvalFuncWithLeftCast(cmpLtDecimal, castIntAsDecimal),
   518  		{types.Float, types.Int}:               makeBinEvalFuncWithRightCast(cmpLtFloat, castIntAsFloat),
   519  		{types.Float, types.Float}:             cmpLtFloat,
   520  		{types.Float, types.Decimal}:           makeBinEvalFuncWithLeftCast(cmpLtDecimal, castFloatAsDecimal),
   521  		{types.Decimal, types.Int}:             makeBinEvalFuncWithRightCast(cmpLtDecimal, castIntAsDecimal),
   522  		{types.Decimal, types.Float}:           makeBinEvalFuncWithRightCast(cmpLtDecimal, castFloatAsDecimal),
   523  		{types.Decimal, types.Decimal}:         cmpLtDecimal,
   524  		{types.String, types.String}:           cmpLtString,
   525  		{types.String, types.Bytes}:            makeBinEvalFuncWithRightCast(cmpLtString, castBytesAsString),
   526  		{types.Bytes, types.String}:            makeBinEvalFuncWithLeftCast(cmpLtString, castBytesAsString),
   527  		{types.Bytes, types.Bytes}:             cmpLtBytes,
   528  		{types.Date, types.Date}:               cmpLtDate,
   529  		{types.Time, types.Time}:               cmpLtTime,
   530  		{types.Time, types.TimeTZ}:             makeBinEvalFuncWithLeftCast(cmpLtTimeTZ, castTimeAsTimeTZ),
   531  		{types.TimeTZ, types.Time}:             makeBinEvalFuncWithRightCast(cmpLtTimeTZ, castTimeAsTimeTZ),
   532  		{types.TimeTZ, types.TimeTZ}:           cmpLtTimeTZ,
   533  		{types.Timestamp, types.Timestamp}:     cmpLtTimestamp,
   534  		{types.Timestamp, types.TimestampTZ}:   makeBinEvalFuncWithLeftCast(cmpLtTimestampTZ, castTimestampAsTimestampTZ),
   535  		{types.TimestampTZ, types.Timestamp}:   makeBinEvalFuncWithRightCast(cmpLtTimestampTZ, castTimestampAsTimestampTZ),
   536  		{types.TimestampTZ, types.TimestampTZ}: cmpLtTimestampTZ,
   537  	},
   538  }
   539  
   540  type cmpOp struct {
   541  	op        opcode.Op
   542  	evalFuncs map[typePair]binEvalFunc
   543  }
   544  
   545  func (op cmpOp) InferReturnType(_, _ types.T) types.T {
   546  	return types.Bool
   547  }
   548  
   549  func (op cmpOp) CallOnNullInput() bool {
   550  	return false
   551  }
   552  
   553  func (op cmpOp) Eval(ctx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   554  	evalFunc, ok := op.evalFuncs[typePair{left.Type(), right.Type()}]
   555  	if !ok {
   556  		return nil, fmt.Errorf("cannot evaluate %s on %s and %s", op.op, left.Type(), right.Type())
   557  	}
   558  	return evalFunc(ctx, left, right)
   559  }
   560  
   561  func makeCmpOp(op opcode.Op) cmpOp {
   562  	evalFuncs := cmpOpEvalFuncs[op]
   563  	return cmpOp{
   564  		op:        op,
   565  		evalFuncs: evalFuncs,
   566  	}
   567  }
   568  
   569  type flippedCmpOp struct {
   570  	cmpOp
   571  }
   572  
   573  func (op flippedCmpOp) Eval(ctx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   574  	return op.cmpOp.Eval(ctx, right, left)
   575  }
   576  
   577  func makeFlippedCmpOp(op opcode.Op) flippedCmpOp {
   578  	return flippedCmpOp{makeCmpOp(op)}
   579  }
   580  
   581  type negateCmpOp struct {
   582  	cmpOp
   583  }
   584  
   585  func (op negateCmpOp) Eval(ctx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   586  	res, err := op.cmpOp.Eval(ctx, left, right)
   587  	if err != nil {
   588  		return nil, err
   589  	}
   590  	return datum.NewBool(!datum.AsBool(res)), nil
   591  }
   592  
   593  func makeNegateCmpOp(op opcode.Op) negateCmpOp {
   594  	return negateCmpOp{makeCmpOp(op)}
   595  }
   596  
   597  type flippedNegateCmpOp struct {
   598  	cmpOp
   599  }
   600  
   601  func (op flippedNegateCmpOp) Eval(ctx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   602  	res, err := op.cmpOp.Eval(ctx, right, left)
   603  	if err != nil {
   604  		return nil, err
   605  	}
   606  	return datum.NewBool(!datum.AsBool(res)), nil
   607  }
   608  
   609  func makeFlippedNegateCmpOp(op opcode.Op) flippedNegateCmpOp {
   610  	return flippedNegateCmpOp{makeCmpOp(op)}
   611  }
   612  
   613  func cmpEqBool(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   614  	return datum.NewBool(datum.AsBool(left) == datum.AsBool(right)), nil
   615  }
   616  
   617  func cmpLtBool(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   618  	// left < right is true if left is false and right is true.
   619  	return datum.NewBool(!datum.AsBool(left) && datum.AsBool(right)), nil
   620  }
   621  
   622  func cmpEqInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   623  	return datum.NewBool(datum.AsInt(left) == datum.AsInt(right)), nil
   624  }
   625  
   626  func cmpLtInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   627  	return datum.NewBool(datum.AsInt(left) < datum.AsInt(right)), nil
   628  }
   629  
   630  func cmpEqFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   631  	return datum.NewBool(datum.AsFloat(left) == datum.AsFloat(right)), nil
   632  }
   633  
   634  func cmpLtFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   635  	return datum.NewBool(datum.AsFloat(left) < datum.AsFloat(right)), nil
   636  }
   637  
   638  func cmpEqDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   639  	l := datum.AsDecimal(left)
   640  	r := datum.AsDecimal(right)
   641  	return datum.NewBool(l.Cmp(r) == 0), nil
   642  }
   643  
   644  func cmpLtDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   645  	l := datum.AsDecimal(left)
   646  	r := datum.AsDecimal(right)
   647  	return datum.NewBool(l.Cmp(r) < 0), nil
   648  }
   649  
   650  func cmpEqString(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   651  	return datum.NewBool(datum.AsString(left) == datum.AsString(right)), nil
   652  }
   653  
   654  func cmpLtString(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   655  	return datum.NewBool(datum.AsString(left) < datum.AsString(right)), nil
   656  }
   657  
   658  func cmpEqBytes(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   659  	return datum.NewBool(bytes.Equal(datum.AsBytes(left), datum.AsBytes(right))), nil
   660  }
   661  
   662  func cmpLtBytes(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   663  	return datum.NewBool(bytes.Compare(datum.AsBytes(left), datum.AsBytes(right)) < 0), nil
   664  }
   665  
   666  func cmpEqDate(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   667  	return nil, fmt.Errorf("cmpEqDate not implemented")
   668  }
   669  
   670  func cmpLtDate(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   671  	return nil, fmt.Errorf("cmpLtDate not implemented")
   672  }
   673  
   674  func cmpEqTime(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   675  	return nil, fmt.Errorf("cmpEqTime not implemented")
   676  }
   677  
   678  func cmpLtTime(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   679  	return nil, fmt.Errorf("cmpLtTime not implemented")
   680  }
   681  
   682  func cmpEqTimeTZ(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   683  	return nil, fmt.Errorf("cmpEqTimeTZ not implemented")
   684  }
   685  
   686  func cmpLtTimeTZ(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   687  	return nil, fmt.Errorf("cmpLtTimeTZ not implemented")
   688  }
   689  
   690  func cmpEqTimestamp(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   691  	return nil, fmt.Errorf("cmpEqTimestamp not implemented")
   692  }
   693  
   694  func cmpLtTimestamp(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   695  	return nil, fmt.Errorf("cmpLtTimestamp not implemented")
   696  }
   697  
   698  func cmpEqTimestampTZ(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   699  	return nil, fmt.Errorf("cmpEqTimestampTZ not implemented")
   700  }
   701  
   702  func cmpLtTimestampTZ(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   703  	return nil, fmt.Errorf("cmpLtTimestampTZ not implemented")
   704  }
   705  
   706  func cmpEqVertex(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   707  	return nil, fmt.Errorf("cmpEqVertex not implemented")
   708  }
   709  
   710  func cmpEqEdge(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) {
   711  	return nil, fmt.Errorf("cmpEqEdge not implemented")
   712  }