
     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
//     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package expression
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"strconv"
    21  	"strings"
    22  	"unsafe"
    24  	""
    26  	""
    27  	""
    28  )
// BitOp is a binary expression that applies a bit operator to its two
// children: bitwise AND, OR and XOR (&, | and ^) as well as the left and
// right shifts (<< and >>) built by NewShiftLeft and NewShiftRight.
type BitOp struct {
	BinaryExpressionStub
	// Op is the operator token; Eval dispatches on it to pick the operation.
	Op string
}
    37  var _ sql.Expression = (*BitOp)(nil)
    38  var _ sql.CollationCoercible = (*BitOp)(nil)
    40  // NewBitOp creates a new BitOp sql.Expression.
    41  func NewBitOp(left, right sql.Expression, op string) *BitOp {
    42  	return &BitOp{BinaryExpressionStub{LeftChild: left, RightChild: right}, op}
    43  }
    45  // NewBitAnd creates a new BitOp & sql.Expression.
    46  func NewBitAnd(left, right sql.Expression) *BitOp {
    47  	return NewBitOp(left, right, sqlparser.BitAndStr)
    48  }
    50  // NewBitOr creates a new BitOp | sql.Expression.
    51  func NewBitOr(left, right sql.Expression) *BitOp {
    52  	return NewBitOp(left, right, sqlparser.BitOrStr)
    53  }
    55  // NewBitXor creates a new BitOp ^ sql.Expression.
    56  func NewBitXor(left, right sql.Expression) *BitOp {
    57  	return NewBitOp(left, right, sqlparser.BitXorStr)
    58  }
    60  // NewShiftLeft creates a new BitOp << sql.Expression.
    61  func NewShiftLeft(left, right sql.Expression) *BitOp {
    62  	return NewBitOp(left, right, sqlparser.ShiftLeftStr)
    63  }
    65  // NewShiftRight creates a new BitOp >> sql.Expression.
    66  func NewShiftRight(left, right sql.Expression) *BitOp {
    67  	return NewBitOp(left, right, sqlparser.ShiftRightStr)
    68  }
    70  func (b *BitOp) String() string {
    71  	return fmt.Sprintf("(%s %s %s)", b.LeftChild, b.Op, b.RightChild)
    72  }
    74  func (b *BitOp) DebugString() string {
    75  	return fmt.Sprintf("(%s %s %s)", sql.DebugString(b.LeftChild), b.Op, sql.DebugString(b.RightChild))
    76  }
    78  // IsNullable implements the sql.Expression interface.
    79  func (b *BitOp) IsNullable() bool {
    80  	return b.BinaryExpressionStub.IsNullable()
    81  }
    83  // Type returns the greatest type for given operation.
    84  func (b *BitOp) Type() sql.Type {
    85  	rTyp := b.RightChild.Type()
    86  	if types.IsDeferredType(rTyp) {
    87  		return rTyp
    88  	}
    89  	lTyp := b.LeftChild.Type()
    90  	if types.IsDeferredType(lTyp) {
    91  		return lTyp
    92  	}
    94  	if types.IsText(lTyp) || types.IsText(rTyp) {
    95  		return types.Float64
    96  	}
    98  	if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) {
    99  		return types.Uint64
   100  	} else if types.IsSigned(lTyp) && types.IsSigned(rTyp) {
   101  		return types.Int64
   102  	}
   104  	return types.Float64
   105  }
   107  // CollationCoercibility implements the interface sql.CollationCoercible.
   108  func (*BitOp) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   109  	return sql.Collation_binary, 5
   110  }
   112  // WithChildren implements the Expression interface.
   113  func (b *BitOp) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   114  	if len(children) != 2 {
   115  		return nil, sql.ErrInvalidChildrenNumber.New(b, len(children), 2)
   116  	}
   117  	return NewBitOp(children[0], children[1], b.Op), nil
   118  }
   120  // Eval implements the Expression interface.
   121  func (b *BitOp) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   122  	lval, rval, err := b.evalLeftRight(ctx, row)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   127  	if lval == nil || rval == nil {
   128  		return nil, nil
   129  	}
   131  	lval, rval, err = b.convertLeftRight(ctx, lval, rval)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   136  	switch strings.ToLower(b.Op) {
   137  	case sqlparser.BitAndStr:
   138  		return bitAnd(lval, rval)
   139  	case sqlparser.BitOrStr:
   140  		return bitOr(lval, rval)
   141  	case sqlparser.BitXorStr:
   142  		return bitXor(lval, rval)
   143  	case sqlparser.ShiftLeftStr:
   144  		return shiftLeft(lval, rval)
   145  	case sqlparser.ShiftRightStr:
   146  		return shiftRight(lval, rval)
   147  	}
   149  	return nil, errUnableToEval.New(lval, b.Op, rval)
   150  }
   152  func (b *BitOp) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) {
   153  	var lval, rval interface{}
   154  	var err error
   156  	// bit ops used with Interval error is caught at parsing the query
   157  	lval, err = b.LeftChild.Eval(ctx, row)
   158  	if err != nil {
   159  		return nil, nil, err
   160  	}
   162  	rval, err = b.RightChild.Eval(ctx, row)
   163  	if err != nil {
   164  		return nil, nil, err
   165  	}
   167  	return lval, rval, nil
   168  }
   170  func (b *BitOp) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}, error) {
   171  	typ := b.Type()
   173  	left = convertValueToType(ctx, typ, left, types.IsTime(b.LeftChild.Type()))
   174  	right = convertValueToType(ctx, typ, right, types.IsTime(b.RightChild.Type()))
   176  	return left, right, nil
   177  }
   179  // convertUintFromInt returns any int64 value converted to uint64 value
   180  // including negative numbers. Mysql does not return negative result on
   181  // bit arithmetic operations, so all results are returned in uint64 type.
   182  func convertUintFromInt(n int64) uint64 {
   183  	intStr := strconv.FormatUint(*(*uint64)(unsafe.Pointer(&n)), 2)
   184  	uintVal, err := strconv.ParseUint(intStr, 2, 64)
   185  	if err != nil {
   186  		return 0
   187  	}
   188  	return uintVal
   189  }
   191  func bitAnd(lval, rval interface{}) (interface{}, error) {
   192  	if lval == nil || rval == nil {
   193  		return 0, nil
   194  	}
   196  	switch l := lval.(type) {
   197  	case float64:
   198  		switch r := rval.(type) {
   199  		case float64:
   200  			left := convertUintFromInt(int64(math.Round(l)))
   201  			right := convertUintFromInt(int64(math.Round(r)))
   202  			return left & right, nil
   203  		}
   204  	case uint64:
   205  		switch r := rval.(type) {
   206  		case uint64:
   207  			return l & r, nil
   208  		}
   209  	case int64:
   210  		switch r := rval.(type) {
   211  		case int64:
   212  			left := convertUintFromInt(l)
   213  			right := convertUintFromInt(r)
   214  			return left & right, nil
   215  		}
   216  	}
   218  	return nil, errUnableToCast.New(lval, rval)
   219  }
   221  func bitOr(lval, rval interface{}) (interface{}, error) {
   222  	if lval == nil && rval == nil {
   223  		return 0, nil
   224  	} else if lval == nil {
   225  		switch r := rval.(type) {
   226  		case float64:
   227  			return convertUintFromInt(int64(math.Round(r))), nil
   228  		case int64:
   229  			return convertUintFromInt(int64(math.Round(float64(r)))), nil
   230  		case uint64:
   231  			return r, nil
   232  		}
   233  	} else if rval == nil {
   234  		switch l := lval.(type) {
   235  		case float64:
   236  			return convertUintFromInt(int64(math.Round(l))), nil
   237  		case int64:
   238  			return convertUintFromInt(int64(math.Round(float64(l)))), nil
   239  		case uint64:
   240  			return l, nil
   241  		}
   242  	}
   244  	switch l := lval.(type) {
   245  	case float64:
   246  		switch r := rval.(type) {
   247  		case float64:
   248  			left := convertUintFromInt(int64(math.Round(l)))
   249  			right := convertUintFromInt(int64(math.Round(r)))
   250  			return left | right, nil
   251  		}
   252  	case uint64:
   253  		switch r := rval.(type) {
   254  		case uint64:
   255  			return l | r, nil
   256  		}
   257  	case int64:
   258  		switch r := rval.(type) {
   259  		case int64:
   260  			left := convertUintFromInt(l)
   261  			right := convertUintFromInt(r)
   262  			return left | right, nil
   263  		}
   264  	}
   266  	return nil, errUnableToCast.New(lval, rval)
   267  }
   269  func bitXor(lval, rval interface{}) (interface{}, error) {
   270  	if lval == nil && rval == nil {
   271  		return 0, nil
   272  	} else if lval == nil {
   273  		switch r := rval.(type) {
   274  		case float64:
   275  			return convertUintFromInt(int64(math.Round(r))), nil
   276  		case int64:
   277  			return convertUintFromInt(int64(math.Round(float64(r)))), nil
   278  		case uint64:
   279  			return r, nil
   280  		}
   281  	} else if rval == nil {
   282  		switch l := lval.(type) {
   283  		case float64:
   284  			return convertUintFromInt(int64(math.Round(l))), nil
   285  		case int64:
   286  			return convertUintFromInt(int64(math.Round(float64(l)))), nil
   287  		case uint64:
   288  			return l, nil
   289  		}
   290  	}
   292  	switch l := lval.(type) {
   293  	case float64:
   294  		switch r := rval.(type) {
   295  		case float64:
   296  			left := convertUintFromInt(int64(math.Round(l)))
   297  			right := convertUintFromInt(int64(math.Round(r)))
   298  			return left ^ right, nil
   299  		}
   300  	case uint64:
   301  		switch r := rval.(type) {
   302  		case uint64:
   303  			return l ^ r, nil
   304  		}
   305  	case int64:
   306  		switch r := rval.(type) {
   307  		case int64:
   308  			left := convertUintFromInt(l)
   309  			right := convertUintFromInt(r)
   310  			return left ^ right, nil
   311  		}
   312  	}
   314  	return nil, errUnableToCast.New(lval, rval)
   315  }
   317  func shiftLeft(lval, rval interface{}) (interface{}, error) {
   318  	if lval == nil {
   319  		return 0, nil
   320  	}
   321  	if rval == nil {
   322  		switch l := lval.(type) {
   323  		case float64:
   324  			return convertUintFromInt(int64(math.Round(l))), nil
   325  		case int64:
   326  			return convertUintFromInt(int64(math.Round(float64(l)))), nil
   327  		case uint64:
   328  			return l, nil
   329  		}
   330  	}
   331  	switch l := lval.(type) {
   332  	case float64:
   333  		switch r := rval.(type) {
   334  		case float64:
   335  			left := convertUintFromInt(int64(math.Round(l)))
   336  			right := convertUintFromInt(int64(math.Round(r)))
   337  			return left << right, nil
   338  		}
   339  	case uint64:
   340  		switch r := rval.(type) {
   341  		case uint64:
   342  			return l << r, nil
   343  		}
   344  	case int64:
   345  		switch r := rval.(type) {
   346  		case int64:
   347  			left := convertUintFromInt(l)
   348  			right := convertUintFromInt(r)
   349  			return left << right, nil
   350  		}
   351  	}
   353  	return nil, errUnableToCast.New(lval, rval)
   354  }
   356  func shiftRight(lval, rval interface{}) (interface{}, error) {
   357  	if lval == nil {
   358  		return 0, nil
   359  	}
   360  	if rval == nil {
   361  		switch l := lval.(type) {
   362  		case float64:
   363  			return convertUintFromInt(int64(math.Round(l))), nil
   364  		case int64:
   365  			return convertUintFromInt(int64(math.Round(float64(l)))), nil
   366  		case uint64:
   367  			return l, nil
   368  		}
   369  	}
   370  	switch l := lval.(type) {
   371  	case float64:
   372  		switch r := rval.(type) {
   373  		case float64:
   374  			left := convertUintFromInt(int64(math.Round(l)))
   375  			right := convertUintFromInt(int64(math.Round(r)))
   376  			return left >> right, nil
   377  		}
   378  	case uint64:
   379  		switch r := rval.(type) {
   380  		case uint64:
   381  			return l >> r, nil
   382  		}
   383  	case int64:
   384  		switch r := rval.(type) {
   385  		case int64:
   386  			left := convertUintFromInt(l)
   387  			right := convertUintFromInt(r)
   388  			return left >> right, nil
   389  		}
   390  	}
   392  	return nil, errUnableToCast.New(lval, rval)
   393  }