github.com/vescale/zgraph@v0.0.0-20230410094002-959c02d50f95/expression/binary_op.go (about) 1 // Copyright 2023 zGraph Authors. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "bytes" 19 "errors" 20 "fmt" 21 "math" 22 23 "github.com/cockroachdb/apd/v3" 24 "github.com/vescale/zgraph/datum" 25 "github.com/vescale/zgraph/parser/opcode" 26 "github.com/vescale/zgraph/stmtctx" 27 "github.com/vescale/zgraph/types" 28 ) 29 30 var _ Expression = &BinaryExpr{} 31 32 type BinaryExpr struct { 33 Op opcode.Op 34 Left Expression 35 Right Expression 36 EvalOp BinaryEvalOp 37 } 38 39 func (expr *BinaryExpr) String() string { 40 return fmt.Sprintf("%s %s %s", expr.Left, expr.Op, expr.Right) 41 } 42 43 func (expr *BinaryExpr) ReturnType() types.T { 44 leftType := expr.Left.ReturnType() 45 rightType := expr.Right.ReturnType() 46 return expr.EvalOp.InferReturnType(leftType, rightType) 47 } 48 49 func (expr *BinaryExpr) Eval(stmtCtx *stmtctx.Context, input datum.Row) (datum.Datum, error) { 50 left, err := expr.Left.Eval(stmtCtx, input) 51 if err != nil { 52 return nil, err 53 } 54 if left == datum.Null && !expr.EvalOp.CallOnNullInput() { 55 return datum.Null, nil 56 } 57 right, err := expr.Right.Eval(stmtCtx, input) 58 if err != nil { 59 return nil, err 60 } 61 if right == datum.Null && !expr.EvalOp.CallOnNullInput() { 62 return datum.Null, nil 63 } 64 return expr.EvalOp.Eval(stmtCtx, left, right) 65 } 66 67 func NewBinaryExpr(op opcode.Op, left, right Expression) (*BinaryExpr, error) { 68 binOp, ok := binOps[op] 69 if !ok { 70 return nil, fmt.Errorf("unsupported binary operator: %s", op) 71 } 72 return &BinaryExpr{ 73 Op: op, 74 Left: left, 75 Right: right, 76 EvalOp: binOp, 77 }, nil 78 } 79 80 type BinaryEvalOp interface { 81 InferReturnType(leftType, rightType types.T) types.T 82 CallOnNullInput() bool 83 Eval(stmtCtx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) 84 } 85 86 var binOps = map[opcode.Op]BinaryEvalOp{ 87 opcode.Plus: makeArithOp(opcode.Plus), 88 opcode.Minus: makeArithOp(opcode.Minus), 89 opcode.Mul: makeArithOp(opcode.Mul), 90 opcode.Div: makeArithOp(opcode.Div), 91 opcode.Mod: makeArithOp(opcode.Mod), 92 opcode.LogicAnd: logicalAndOp{}, 93 opcode.LogicOr: logicalOrOp{}, 94 opcode.EQ: makeCmpOp(opcode.EQ), 95 opcode.NE: makeNegateCmpOp(opcode.EQ), // NE(left, right) is implemented as !EQ(left, right) 96 opcode.LT: makeCmpOp(opcode.LT), 97 opcode.LE: makeFlippedNegateCmpOp(opcode.LT), // LE(left, right) is implemented as !LT(right, left) 98 opcode.GE: makeNegateCmpOp(opcode.LT), // GE(left, right) is implemented as !LT(left, right) 99 opcode.GT: makeFlippedCmpOp(opcode.LT), // GT(left, right) is implemented as LT(right, left) 100 } 101 102 type typePair struct { 103 left types.T 104 right types.T 105 } 106 107 type binEvalFunc func(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) 108 109 func makeBinEvalFuncWithLeftCast(eval binEvalFunc, cast castFunc) binEvalFunc { 110 return func(stmtCtx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 111 left, err := cast(stmtCtx, left) 112 if err != nil { 113 return nil, err 114 } 115 return eval(stmtCtx, left, right) 116 } 117 } 118 119 func makeBinEvalFuncWithRightCast(eval binEvalFunc, cast castFunc) binEvalFunc { 120 return func(stmtCtx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 121 right, err := cast(stmtCtx, right) 122 if err != nil { 123 return nil, err 124 } 125 return eval(stmtCtx, left, right) 126 } 127 } 128 129 var numericOpReturnTypes = map[typePair]types.T{ 130 {types.Int, types.Int}: types.Int, 131 {types.Int, types.Float}: types.Float, 132 {types.Int, types.Decimal}: types.Decimal, 133 {types.Float, types.Int}: types.Float, 134 {types.Float, types.Float}: types.Float, 135 {types.Float, types.Decimal}: types.Decimal, 136 {types.Decimal, types.Int}: types.Decimal, 137 {types.Decimal, types.Float}: types.Decimal, 138 {types.Decimal, types.Decimal}: types.Decimal, 139 } 140 141 var arithOpReturnTypes = func() map[opcode.Op]map[typePair]types.T { 142 result := make(map[opcode.Op]map[typePair]types.T) 143 for _, op := range []opcode.Op{opcode.Plus, opcode.Minus, opcode.Mul, opcode.Div, opcode.Mod} { 144 result[op] = numericOpReturnTypes 145 } 146 for _, op := range []opcode.Op{opcode.Plus, opcode.Minus} { 147 result[op][typePair{types.Date, types.Interval}] = types.Date 148 result[op][typePair{types.Time, types.Interval}] = types.Time 149 result[op][typePair{types.TimeTZ, types.Interval}] = types.TimeTZ 150 result[op][typePair{types.Timestamp, types.Interval}] = types.Timestamp 151 result[op][typePair{types.TimestampTZ, types.Interval}] = types.TimestampTZ 152 } 153 return result 154 }() 155 156 var arithOpEvalFuncs = map[opcode.Op]map[typePair]binEvalFunc{ 157 opcode.Plus: { 158 {types.Int, types.Int}: plusInt, 159 {types.Int, types.Float}: makeBinEvalFuncWithLeftCast(plusFloat, castIntAsFloat), 160 {types.Int, types.Decimal}: makeBinEvalFuncWithLeftCast(plusDecimal, castIntAsDecimal), 161 {types.Float, types.Int}: makeBinEvalFuncWithRightCast(plusFloat, castIntAsFloat), 162 {types.Float, types.Float}: plusFloat, 163 {types.Float, types.Decimal}: makeBinEvalFuncWithLeftCast(plusDecimal, castFloatAsDecimal), 164 {types.Decimal, types.Int}: makeBinEvalFuncWithRightCast(plusDecimal, castIntAsDecimal), 165 {types.Decimal, types.Float}: makeBinEvalFuncWithRightCast(plusDecimal, castFloatAsDecimal), 166 {types.Decimal, types.Decimal}: plusDecimal, 167 {types.Date, types.Interval}: plusDateInterval, 168 {types.Time, types.Interval}: plusTimeInterval, 169 {types.TimeTZ, types.Interval}: plusTimeTZInterval, 170 {types.Timestamp, types.Interval}: plusTimestampInterval, 171 {types.TimestampTZ, types.Interval}: plusTimestampTZInterval, 172 }, 173 opcode.Minus: { 174 {types.Int, types.Int}: minusInt, 175 {types.Int, types.Float}: makeBinEvalFuncWithLeftCast(minusFloat, castIntAsFloat), 176 {types.Int, types.Decimal}: makeBinEvalFuncWithLeftCast(minusDecimal, castIntAsDecimal), 177 {types.Float, types.Int}: makeBinEvalFuncWithRightCast(minusFloat, castIntAsFloat), 178 {types.Float, types.Float}: minusFloat, 179 {types.Float, types.Decimal}: makeBinEvalFuncWithLeftCast(minusDecimal, castFloatAsDecimal), 180 {types.Decimal, types.Int}: makeBinEvalFuncWithRightCast(minusDecimal, castIntAsDecimal), 181 {types.Decimal, types.Float}: makeBinEvalFuncWithRightCast(minusDecimal, castFloatAsDecimal), 182 {types.Decimal, types.Decimal}: minusDecimal, 183 {types.Date, types.Interval}: minusDateInterval, 184 {types.Time, types.Interval}: minusTimeInterval, 185 {types.TimeTZ, types.Interval}: minusTimeTZInterval, 186 {types.Timestamp, types.Interval}: minusTimestampInterval, 187 {types.TimestampTZ, types.Interval}: minusTimestampTZInterval, 188 }, 189 opcode.Mul: { 190 {types.Int, types.Int}: mulInt, 191 {types.Int, types.Float}: makeBinEvalFuncWithLeftCast(mulFloat, castIntAsFloat), 192 {types.Int, types.Decimal}: makeBinEvalFuncWithLeftCast(mulDecimal, castIntAsDecimal), 193 {types.Float, types.Int}: makeBinEvalFuncWithRightCast(mulFloat, castIntAsFloat), 194 {types.Float, types.Float}: mulFloat, 195 {types.Float, types.Decimal}: makeBinEvalFuncWithLeftCast(mulDecimal, castFloatAsDecimal), 196 {types.Decimal, types.Int}: makeBinEvalFuncWithRightCast(mulDecimal, castIntAsDecimal), 197 {types.Decimal, types.Float}: makeBinEvalFuncWithRightCast(mulDecimal, castFloatAsDecimal), 198 {types.Decimal, types.Decimal}: mulDecimal, 199 }, 200 opcode.Div: { 201 {types.Int, types.Int}: divInt, 202 {types.Int, types.Float}: makeBinEvalFuncWithLeftCast(divFloat, castIntAsFloat), 203 {types.Int, types.Decimal}: makeBinEvalFuncWithLeftCast(divDecimal, castIntAsDecimal), 204 {types.Float, types.Int}: makeBinEvalFuncWithRightCast(divFloat, castIntAsFloat), 205 {types.Float, types.Float}: divFloat, 206 {types.Float, types.Decimal}: makeBinEvalFuncWithLeftCast(divDecimal, castFloatAsDecimal), 207 {types.Decimal, types.Int}: makeBinEvalFuncWithRightCast(divDecimal, castIntAsDecimal), 208 {types.Decimal, types.Float}: makeBinEvalFuncWithRightCast(divDecimal, castFloatAsDecimal), 209 {types.Decimal, types.Decimal}: divDecimal, 210 }, 211 opcode.Mod: { 212 {types.Int, types.Int}: modInt, 213 {types.Int, types.Float}: makeBinEvalFuncWithLeftCast(modFloat, castIntAsFloat), 214 {types.Int, types.Decimal}: makeBinEvalFuncWithLeftCast(modDecimal, castIntAsDecimal), 215 {types.Float, types.Int}: makeBinEvalFuncWithRightCast(modFloat, castIntAsFloat), 216 {types.Float, types.Float}: modFloat, 217 {types.Float, types.Decimal}: makeBinEvalFuncWithLeftCast(modDecimal, castFloatAsDecimal), 218 {types.Decimal, types.Int}: makeBinEvalFuncWithRightCast(modDecimal, castIntAsDecimal), 219 {types.Decimal, types.Float}: makeBinEvalFuncWithRightCast(modDecimal, castFloatAsDecimal), 220 {types.Decimal, types.Decimal}: modDecimal, 221 }, 222 } 223 224 type arithOp struct { 225 op opcode.Op 226 returnTypes map[typePair]types.T 227 evalFuncs map[typePair]binEvalFunc 228 } 229 230 func (op arithOp) InferReturnType(leftType, rightType types.T) types.T { 231 return op.returnTypes[typePair{leftType, rightType}] 232 } 233 234 func (op arithOp) CallOnNullInput() bool { 235 return false 236 } 237 238 func (op arithOp) Eval(ctx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 239 evalFunc, ok := op.evalFuncs[typePair{left.Type(), right.Type()}] 240 if !ok { 241 return nil, fmt.Errorf("cannot evaluate %s on %s and %s", op.op, left.Type(), right.Type()) 242 } 243 return evalFunc(ctx, left, right) 244 } 245 246 func makeArithOp(op opcode.Op) arithOp { 247 returnTypes := arithOpReturnTypes[op] 248 evalFuncs := arithOpEvalFuncs[op] 249 return arithOp{ 250 op: op, 251 returnTypes: returnTypes, 252 evalFuncs: evalFuncs, 253 } 254 } 255 256 func plusInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 257 l := datum.AsInt(left) 258 r := datum.AsInt(right) 259 return datum.NewInt(l + r), nil 260 } 261 262 func plusFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 263 l := datum.AsFloat(left) 264 r := datum.AsFloat(right) 265 return datum.NewFloat(l + r), nil 266 } 267 268 func plusDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 269 l := datum.AsDecimal(left) 270 r := datum.AsDecimal(right) 271 d := &apd.Decimal{} 272 _, err := apd.BaseContext.Add(d, l, r) 273 if err != nil { 274 return nil, err 275 } 276 return datum.NewDecimal(d), nil 277 } 278 279 func plusDateInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 280 return nil, errors.New("plusDateInterval unimplemented") 281 } 282 283 func plusTimeInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 284 return nil, errors.New("plusTimeInterval unimplemented") 285 } 286 287 func plusTimeTZInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 288 return nil, errors.New("plusTimeTZInterval unimplemented") 289 } 290 291 func plusTimestampInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 292 return nil, errors.New("plusTimestampInterval unimplemented") 293 } 294 295 func plusTimestampTZInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 296 return nil, errors.New("plusTimestampTZInterval unimplemented") 297 } 298 299 func minusInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 300 l := datum.AsInt(left) 301 r := datum.AsInt(right) 302 return datum.NewInt(l - r), nil 303 } 304 305 func minusFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 306 l := datum.AsFloat(left) 307 r := datum.AsFloat(right) 308 return datum.NewFloat(l - r), nil 309 } 310 311 func minusDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 312 l := datum.AsDecimal(left) 313 r := datum.AsDecimal(right) 314 d := &apd.Decimal{} 315 _, err := apd.BaseContext.Sub(d, l, r) 316 if err != nil { 317 return nil, err 318 } 319 return datum.NewDecimal(d), nil 320 } 321 322 func minusDateInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 323 return nil, errors.New("minusDateInterval unimplemented") 324 } 325 326 func minusTimeInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 327 return nil, errors.New("minusTimeInterval unimplemented") 328 } 329 330 func minusTimeTZInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 331 return nil, errors.New("minusTimeTZInterval unimplemented") 332 } 333 334 func minusTimestampInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 335 return nil, errors.New("minusTimestampInterval unimplemented") 336 } 337 338 func minusTimestampTZInterval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 339 return nil, errors.New("minusTimestampTZInterval unimplemented") 340 } 341 342 func mulInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 343 l := datum.AsInt(left) 344 r := datum.AsInt(right) 345 return datum.NewInt(l * r), nil 346 } 347 348 func mulFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 349 l := datum.AsFloat(left) 350 r := datum.AsFloat(right) 351 return datum.NewFloat(l * r), nil 352 } 353 354 func mulDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 355 l := datum.AsDecimal(left) 356 r := datum.AsDecimal(right) 357 d := &apd.Decimal{} 358 _, err := apd.BaseContext.Mul(d, l, r) 359 if err != nil { 360 return nil, err 361 } 362 return datum.NewDecimal(d), nil 363 } 364 365 func divInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 366 l := datum.AsInt(left) 367 r := datum.AsInt(right) 368 if r == 0 { 369 return nil, errors.New("division by zero") 370 } 371 return datum.NewInt(l / r), nil 372 } 373 374 func divFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 375 l := datum.AsFloat(left) 376 r := datum.AsFloat(right) 377 if r == 0 { 378 return nil, errors.New("division by zero") 379 } 380 return datum.NewFloat(l / r), nil 381 } 382 383 func divDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 384 l := datum.AsDecimal(left) 385 r := datum.AsDecimal(right) 386 if r.IsZero() { 387 return nil, errors.New("division by zero") 388 } 389 d := &apd.Decimal{} 390 _, err := apd.BaseContext.Quo(d, l, r) 391 if err != nil { 392 return nil, err 393 } 394 return datum.NewDecimal(d), nil 395 } 396 397 func modInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 398 l := datum.AsInt(left) 399 r := datum.AsInt(right) 400 if r == 0 { 401 return nil, errors.New("division by zero") 402 } 403 return datum.NewInt(l % r), nil 404 } 405 406 func modFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 407 l := datum.AsFloat(left) 408 r := datum.AsFloat(right) 409 if r == 0 { 410 return nil, errors.New("division by zero") 411 } 412 return datum.NewFloat(math.Mod(l, r)), nil 413 } 414 415 func modDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 416 l := datum.AsDecimal(left) 417 r := datum.AsDecimal(right) 418 if r.IsZero() { 419 return nil, errors.New("division by zero") 420 } 421 d := &apd.Decimal{} 422 _, err := apd.BaseContext.Rem(d, l, r) 423 if err != nil { 424 return nil, err 425 } 426 return datum.NewDecimal(d), nil 427 } 428 429 type logicalAndOp struct{} 430 431 func (logicalAndOp) InferReturnType(_, _ types.T) types.T { 432 return types.Bool 433 } 434 435 func (logicalAndOp) CallOnNullInput() bool { 436 return true 437 } 438 439 func (logicalAndOp) Eval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 440 leftBool, lerr := datum.TryAsBool(left) 441 rightBool, rerr := datum.TryAsBool(right) 442 if left == datum.Null { 443 if rerr == nil && !rightBool { 444 return datum.NewBool(false), nil 445 } 446 return datum.Null, nil 447 } 448 if right == datum.Null { 449 if lerr == nil && !leftBool { 450 return datum.NewBool(false), nil 451 } 452 return datum.Null, nil 453 } 454 return datum.NewBool(leftBool && rightBool), nil 455 } 456 457 type logicalOrOp struct{} 458 459 func (logicalOrOp) InferReturnType(_, _ types.T) types.T { 460 return types.Bool 461 } 462 463 func (logicalOrOp) CallOnNullInput() bool { 464 return true 465 } 466 467 func (logicalOrOp) Eval(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 468 leftBool, lerr := datum.TryAsBool(left) 469 rightBool, rerr := datum.TryAsBool(right) 470 if left == datum.Null { 471 if rerr == nil && rightBool { 472 return datum.NewBool(true), nil 473 } 474 return datum.Null, nil 475 } 476 if right == datum.Null { 477 if lerr == nil && leftBool { 478 return datum.NewBool(true), nil 479 } 480 return datum.Null, nil 481 } 482 return datum.NewBool(leftBool || rightBool), nil 483 } 484 485 var cmpOpEvalFuncs = map[opcode.Op]map[typePair]binEvalFunc{ 486 opcode.EQ: { 487 {types.Bool, types.Bool}: cmpEqBool, 488 {types.Int, types.Int}: cmpEqInt, 489 {types.Int, types.Float}: makeBinEvalFuncWithLeftCast(cmpEqFloat, castIntAsFloat), 490 {types.Int, types.Decimal}: makeBinEvalFuncWithLeftCast(cmpEqDecimal, castIntAsDecimal), 491 {types.Float, types.Int}: makeBinEvalFuncWithRightCast(cmpEqFloat, castIntAsFloat), 492 {types.Float, types.Float}: cmpEqFloat, 493 {types.Float, types.Decimal}: makeBinEvalFuncWithLeftCast(cmpEqDecimal, castFloatAsDecimal), 494 {types.Decimal, types.Int}: makeBinEvalFuncWithRightCast(cmpEqDecimal, castIntAsDecimal), 495 {types.Decimal, types.Float}: makeBinEvalFuncWithRightCast(cmpEqDecimal, castFloatAsDecimal), 496 {types.Decimal, types.Decimal}: cmpEqDecimal, 497 {types.String, types.String}: cmpEqString, 498 {types.String, types.Bytes}: makeBinEvalFuncWithRightCast(cmpEqString, castBytesAsString), 499 {types.Bytes, types.String}: makeBinEvalFuncWithLeftCast(cmpEqString, castBytesAsString), 500 {types.Bytes, types.Bytes}: cmpEqBytes, 501 {types.Date, types.Date}: cmpEqDate, 502 {types.Time, types.Time}: cmpEqTime, 503 {types.Time, types.TimeTZ}: makeBinEvalFuncWithLeftCast(cmpEqTimeTZ, castTimeAsTimeTZ), 504 {types.TimeTZ, types.Time}: makeBinEvalFuncWithRightCast(cmpEqTimeTZ, castTimeAsTimeTZ), 505 {types.TimeTZ, types.TimeTZ}: cmpEqTimeTZ, 506 {types.Timestamp, types.Timestamp}: cmpEqTimestamp, 507 {types.Timestamp, types.TimestampTZ}: makeBinEvalFuncWithLeftCast(cmpEqTimestampTZ, castTimestampAsTimestampTZ), 508 {types.TimestampTZ, types.Timestamp}: makeBinEvalFuncWithRightCast(cmpEqTimestampTZ, castTimestampAsTimestampTZ), 509 {types.TimestampTZ, types.TimestampTZ}: cmpEqTimestampTZ, 510 {types.Vertex, types.Vertex}: cmpEqVertex, 511 {types.Edge, types.Edge}: cmpEqEdge, 512 }, 513 opcode.LT: { 514 {types.Bool, types.Bool}: cmpLtBool, 515 {types.Int, types.Int}: cmpLtInt, 516 {types.Int, types.Float}: makeBinEvalFuncWithLeftCast(cmpLtFloat, castIntAsFloat), 517 {types.Int, types.Decimal}: makeBinEvalFuncWithLeftCast(cmpLtDecimal, castIntAsDecimal), 518 {types.Float, types.Int}: makeBinEvalFuncWithRightCast(cmpLtFloat, castIntAsFloat), 519 {types.Float, types.Float}: cmpLtFloat, 520 {types.Float, types.Decimal}: makeBinEvalFuncWithLeftCast(cmpLtDecimal, castFloatAsDecimal), 521 {types.Decimal, types.Int}: makeBinEvalFuncWithRightCast(cmpLtDecimal, castIntAsDecimal), 522 {types.Decimal, types.Float}: makeBinEvalFuncWithRightCast(cmpLtDecimal, castFloatAsDecimal), 523 {types.Decimal, types.Decimal}: cmpLtDecimal, 524 {types.String, types.String}: cmpLtString, 525 {types.String, types.Bytes}: makeBinEvalFuncWithRightCast(cmpLtString, castBytesAsString), 526 {types.Bytes, types.String}: makeBinEvalFuncWithLeftCast(cmpLtString, castBytesAsString), 527 {types.Bytes, types.Bytes}: cmpLtBytes, 528 {types.Date, types.Date}: cmpLtDate, 529 {types.Time, types.Time}: cmpLtTime, 530 {types.Time, types.TimeTZ}: makeBinEvalFuncWithLeftCast(cmpLtTimeTZ, castTimeAsTimeTZ), 531 {types.TimeTZ, types.Time}: makeBinEvalFuncWithRightCast(cmpLtTimeTZ, castTimeAsTimeTZ), 532 {types.TimeTZ, types.TimeTZ}: cmpLtTimeTZ, 533 {types.Timestamp, types.Timestamp}: cmpLtTimestamp, 534 {types.Timestamp, types.TimestampTZ}: makeBinEvalFuncWithLeftCast(cmpLtTimestampTZ, castTimestampAsTimestampTZ), 535 {types.TimestampTZ, types.Timestamp}: makeBinEvalFuncWithRightCast(cmpLtTimestampTZ, castTimestampAsTimestampTZ), 536 {types.TimestampTZ, types.TimestampTZ}: cmpLtTimestampTZ, 537 }, 538 } 539 540 type cmpOp struct { 541 op opcode.Op 542 evalFuncs map[typePair]binEvalFunc 543 } 544 545 func (op cmpOp) InferReturnType(_, _ types.T) types.T { 546 return types.Bool 547 } 548 549 func (op cmpOp) CallOnNullInput() bool { 550 return false 551 } 552 553 func (op cmpOp) Eval(ctx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 554 evalFunc, ok := op.evalFuncs[typePair{left.Type(), right.Type()}] 555 if !ok { 556 return nil, fmt.Errorf("cannot evaluate %s on %s and %s", op.op, left.Type(), right.Type()) 557 } 558 return evalFunc(ctx, left, right) 559 } 560 561 func makeCmpOp(op opcode.Op) cmpOp { 562 evalFuncs := cmpOpEvalFuncs[op] 563 return cmpOp{ 564 op: op, 565 evalFuncs: evalFuncs, 566 } 567 } 568 569 type flippedCmpOp struct { 570 cmpOp 571 } 572 573 func (op flippedCmpOp) Eval(ctx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 574 return op.cmpOp.Eval(ctx, right, left) 575 } 576 577 func makeFlippedCmpOp(op opcode.Op) flippedCmpOp { 578 return flippedCmpOp{makeCmpOp(op)} 579 } 580 581 type negateCmpOp struct { 582 cmpOp 583 } 584 585 func (op negateCmpOp) Eval(ctx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 586 res, err := op.cmpOp.Eval(ctx, left, right) 587 if err != nil { 588 return nil, err 589 } 590 return datum.NewBool(!datum.AsBool(res)), nil 591 } 592 593 func makeNegateCmpOp(op opcode.Op) negateCmpOp { 594 return negateCmpOp{makeCmpOp(op)} 595 } 596 597 type flippedNegateCmpOp struct { 598 cmpOp 599 } 600 601 func (op flippedNegateCmpOp) Eval(ctx *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 602 res, err := op.cmpOp.Eval(ctx, right, left) 603 if err != nil { 604 return nil, err 605 } 606 return datum.NewBool(!datum.AsBool(res)), nil 607 } 608 609 func makeFlippedNegateCmpOp(op opcode.Op) flippedNegateCmpOp { 610 return flippedNegateCmpOp{makeCmpOp(op)} 611 } 612 613 func cmpEqBool(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 614 return datum.NewBool(datum.AsBool(left) == datum.AsBool(right)), nil 615 } 616 617 func cmpLtBool(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 618 // left < right is true if left is false and right is true. 619 return datum.NewBool(!datum.AsBool(left) && datum.AsBool(right)), nil 620 } 621 622 func cmpEqInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 623 return datum.NewBool(datum.AsInt(left) == datum.AsInt(right)), nil 624 } 625 626 func cmpLtInt(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 627 return datum.NewBool(datum.AsInt(left) < datum.AsInt(right)), nil 628 } 629 630 func cmpEqFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 631 return datum.NewBool(datum.AsFloat(left) == datum.AsFloat(right)), nil 632 } 633 634 func cmpLtFloat(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 635 return datum.NewBool(datum.AsFloat(left) < datum.AsFloat(right)), nil 636 } 637 638 func cmpEqDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 639 l := datum.AsDecimal(left) 640 r := datum.AsDecimal(right) 641 return datum.NewBool(l.Cmp(r) == 0), nil 642 } 643 644 func cmpLtDecimal(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 645 l := datum.AsDecimal(left) 646 r := datum.AsDecimal(right) 647 return datum.NewBool(l.Cmp(r) < 0), nil 648 } 649 650 func cmpEqString(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 651 return datum.NewBool(datum.AsString(left) == datum.AsString(right)), nil 652 } 653 654 func cmpLtString(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 655 return datum.NewBool(datum.AsString(left) < datum.AsString(right)), nil 656 } 657 658 func cmpEqBytes(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 659 return datum.NewBool(bytes.Equal(datum.AsBytes(left), datum.AsBytes(right))), nil 660 } 661 662 func cmpLtBytes(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 663 return datum.NewBool(bytes.Compare(datum.AsBytes(left), datum.AsBytes(right)) < 0), nil 664 } 665 666 func cmpEqDate(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 667 return nil, fmt.Errorf("cmpEqDate not implemented") 668 } 669 670 func cmpLtDate(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 671 return nil, fmt.Errorf("cmpLtDate not implemented") 672 } 673 674 func cmpEqTime(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 675 return nil, fmt.Errorf("cmpEqTime not implemented") 676 } 677 678 func cmpLtTime(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 679 return nil, fmt.Errorf("cmpLtTime not implemented") 680 } 681 682 func cmpEqTimeTZ(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 683 return nil, fmt.Errorf("cmpEqTimeTZ not implemented") 684 } 685 686 func cmpLtTimeTZ(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 687 return nil, fmt.Errorf("cmpLtTimeTZ not implemented") 688 } 689 690 func cmpEqTimestamp(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 691 return nil, fmt.Errorf("cmpEqTimestamp not implemented") 692 } 693 694 func cmpLtTimestamp(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 695 return nil, fmt.Errorf("cmpLtTimestamp not implemented") 696 } 697 698 func cmpEqTimestampTZ(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 699 return nil, fmt.Errorf("cmpEqTimestampTZ not implemented") 700 } 701 702 func cmpLtTimestampTZ(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 703 return nil, fmt.Errorf("cmpLtTimestampTZ not implemented") 704 } 705 706 func cmpEqVertex(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 707 return nil, fmt.Errorf("cmpEqVertex not implemented") 708 } 709 710 func cmpEqEdge(_ *stmtctx.Context, left, right datum.Datum) (datum.Datum, error) { 711 return nil, fmt.Errorf("cmpEqEdge not implemented") 712 }