github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/bit_ops.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "fmt" 19 "math" 20 "strconv" 21 "strings" 22 "unsafe" 23 24 "github.com/dolthub/vitess/go/vt/sqlparser" 25 26 "github.com/dolthub/go-mysql-server/sql" 27 "github.com/dolthub/go-mysql-server/sql/types" 28 ) 29 30 // BitOp expressions include BIT -AND, -OR and -XOR (&, | and ^) operations 31 // https://dev.mysql.com/doc/refman/8.0/en/bit-functions.html 32 type BitOp struct { 33 BinaryExpressionStub 34 Op string 35 } 36 37 var _ sql.Expression = (*BitOp)(nil) 38 var _ sql.CollationCoercible = (*BitOp)(nil) 39 40 // NewBitOp creates a new BitOp sql.Expression. 41 func NewBitOp(left, right sql.Expression, op string) *BitOp { 42 return &BitOp{BinaryExpressionStub{LeftChild: left, RightChild: right}, op} 43 } 44 45 // NewBitAnd creates a new BitOp & sql.Expression. 46 func NewBitAnd(left, right sql.Expression) *BitOp { 47 return NewBitOp(left, right, sqlparser.BitAndStr) 48 } 49 50 // NewBitOr creates a new BitOp | sql.Expression. 51 func NewBitOr(left, right sql.Expression) *BitOp { 52 return NewBitOp(left, right, sqlparser.BitOrStr) 53 } 54 55 // NewBitXor creates a new BitOp ^ sql.Expression. 56 func NewBitXor(left, right sql.Expression) *BitOp { 57 return NewBitOp(left, right, sqlparser.BitXorStr) 58 } 59 60 // NewShiftLeft creates a new BitOp << sql.Expression. 61 func NewShiftLeft(left, right sql.Expression) *BitOp { 62 return NewBitOp(left, right, sqlparser.ShiftLeftStr) 63 } 64 65 // NewShiftRight creates a new BitOp >> sql.Expression. 66 func NewShiftRight(left, right sql.Expression) *BitOp { 67 return NewBitOp(left, right, sqlparser.ShiftRightStr) 68 } 69 70 func (b *BitOp) String() string { 71 return fmt.Sprintf("(%s %s %s)", b.LeftChild, b.Op, b.RightChild) 72 } 73 74 func (b *BitOp) DebugString() string { 75 return fmt.Sprintf("(%s %s %s)", sql.DebugString(b.LeftChild), b.Op, sql.DebugString(b.RightChild)) 76 } 77 78 // IsNullable implements the sql.Expression interface. 79 func (b *BitOp) IsNullable() bool { 80 return b.BinaryExpressionStub.IsNullable() 81 } 82 83 // Type returns the greatest type for given operation. 84 func (b *BitOp) Type() sql.Type { 85 rTyp := b.RightChild.Type() 86 if types.IsDeferredType(rTyp) { 87 return rTyp 88 } 89 lTyp := b.LeftChild.Type() 90 if types.IsDeferredType(lTyp) { 91 return lTyp 92 } 93 94 if types.IsText(lTyp) || types.IsText(rTyp) { 95 return types.Float64 96 } 97 98 if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) { 99 return types.Uint64 100 } else if types.IsSigned(lTyp) && types.IsSigned(rTyp) { 101 return types.Int64 102 } 103 104 return types.Float64 105 } 106 107 // CollationCoercibility implements the interface sql.CollationCoercible. 108 func (*BitOp) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 109 return sql.Collation_binary, 5 110 } 111 112 // WithChildren implements the Expression interface. 113 func (b *BitOp) WithChildren(children ...sql.Expression) (sql.Expression, error) { 114 if len(children) != 2 { 115 return nil, sql.ErrInvalidChildrenNumber.New(b, len(children), 2) 116 } 117 return NewBitOp(children[0], children[1], b.Op), nil 118 } 119 120 // Eval implements the Expression interface. 121 func (b *BitOp) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 122 lval, rval, err := b.evalLeftRight(ctx, row) 123 if err != nil { 124 return nil, err 125 } 126 127 if lval == nil || rval == nil { 128 return nil, nil 129 } 130 131 lval, rval, err = b.convertLeftRight(ctx, lval, rval) 132 if err != nil { 133 return nil, err 134 } 135 136 switch strings.ToLower(b.Op) { 137 case sqlparser.BitAndStr: 138 return bitAnd(lval, rval) 139 case sqlparser.BitOrStr: 140 return bitOr(lval, rval) 141 case sqlparser.BitXorStr: 142 return bitXor(lval, rval) 143 case sqlparser.ShiftLeftStr: 144 return shiftLeft(lval, rval) 145 case sqlparser.ShiftRightStr: 146 return shiftRight(lval, rval) 147 } 148 149 return nil, errUnableToEval.New(lval, b.Op, rval) 150 } 151 152 func (b *BitOp) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) { 153 var lval, rval interface{} 154 var err error 155 156 // bit ops used with Interval error is caught at parsing the query 157 lval, err = b.LeftChild.Eval(ctx, row) 158 if err != nil { 159 return nil, nil, err 160 } 161 162 rval, err = b.RightChild.Eval(ctx, row) 163 if err != nil { 164 return nil, nil, err 165 } 166 167 return lval, rval, nil 168 } 169 170 func (b *BitOp) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}, error) { 171 typ := b.Type() 172 173 left = convertValueToType(ctx, typ, left, types.IsTime(b.LeftChild.Type())) 174 right = convertValueToType(ctx, typ, right, types.IsTime(b.RightChild.Type())) 175 176 return left, right, nil 177 } 178 179 // convertUintFromInt returns any int64 value converted to uint64 value 180 // including negative numbers. Mysql does not return negative result on 181 // bit arithmetic operations, so all results are returned in uint64 type. 182 func convertUintFromInt(n int64) uint64 { 183 intStr := strconv.FormatUint(*(*uint64)(unsafe.Pointer(&n)), 2) 184 uintVal, err := strconv.ParseUint(intStr, 2, 64) 185 if err != nil { 186 return 0 187 } 188 return uintVal 189 } 190 191 func bitAnd(lval, rval interface{}) (interface{}, error) { 192 if lval == nil || rval == nil { 193 return 0, nil 194 } 195 196 switch l := lval.(type) { 197 case float64: 198 switch r := rval.(type) { 199 case float64: 200 left := convertUintFromInt(int64(math.Round(l))) 201 right := convertUintFromInt(int64(math.Round(r))) 202 return left & right, nil 203 } 204 case uint64: 205 switch r := rval.(type) { 206 case uint64: 207 return l & r, nil 208 } 209 case int64: 210 switch r := rval.(type) { 211 case int64: 212 left := convertUintFromInt(l) 213 right := convertUintFromInt(r) 214 return left & right, nil 215 } 216 } 217 218 return nil, errUnableToCast.New(lval, rval) 219 } 220 221 func bitOr(lval, rval interface{}) (interface{}, error) { 222 if lval == nil && rval == nil { 223 return 0, nil 224 } else if lval == nil { 225 switch r := rval.(type) { 226 case float64: 227 return convertUintFromInt(int64(math.Round(r))), nil 228 case int64: 229 return convertUintFromInt(int64(math.Round(float64(r)))), nil 230 case uint64: 231 return r, nil 232 } 233 } else if rval == nil { 234 switch l := lval.(type) { 235 case float64: 236 return convertUintFromInt(int64(math.Round(l))), nil 237 case int64: 238 return convertUintFromInt(int64(math.Round(float64(l)))), nil 239 case uint64: 240 return l, nil 241 } 242 } 243 244 switch l := lval.(type) { 245 case float64: 246 switch r := rval.(type) { 247 case float64: 248 left := convertUintFromInt(int64(math.Round(l))) 249 right := convertUintFromInt(int64(math.Round(r))) 250 return left | right, nil 251 } 252 case uint64: 253 switch r := rval.(type) { 254 case uint64: 255 return l | r, nil 256 } 257 case int64: 258 switch r := rval.(type) { 259 case int64: 260 left := convertUintFromInt(l) 261 right := convertUintFromInt(r) 262 return left | right, nil 263 } 264 } 265 266 return nil, errUnableToCast.New(lval, rval) 267 } 268 269 func bitXor(lval, rval interface{}) (interface{}, error) { 270 if lval == nil && rval == nil { 271 return 0, nil 272 } else if lval == nil { 273 switch r := rval.(type) { 274 case float64: 275 return convertUintFromInt(int64(math.Round(r))), nil 276 case int64: 277 return convertUintFromInt(int64(math.Round(float64(r)))), nil 278 case uint64: 279 return r, nil 280 } 281 } else if rval == nil { 282 switch l := lval.(type) { 283 case float64: 284 return convertUintFromInt(int64(math.Round(l))), nil 285 case int64: 286 return convertUintFromInt(int64(math.Round(float64(l)))), nil 287 case uint64: 288 return l, nil 289 } 290 } 291 292 switch l := lval.(type) { 293 case float64: 294 switch r := rval.(type) { 295 case float64: 296 left := convertUintFromInt(int64(math.Round(l))) 297 right := convertUintFromInt(int64(math.Round(r))) 298 return left ^ right, nil 299 } 300 case uint64: 301 switch r := rval.(type) { 302 case uint64: 303 return l ^ r, nil 304 } 305 case int64: 306 switch r := rval.(type) { 307 case int64: 308 left := convertUintFromInt(l) 309 right := convertUintFromInt(r) 310 return left ^ right, nil 311 } 312 } 313 314 return nil, errUnableToCast.New(lval, rval) 315 } 316 317 func shiftLeft(lval, rval interface{}) (interface{}, error) { 318 if lval == nil { 319 return 0, nil 320 } 321 if rval == nil { 322 switch l := lval.(type) { 323 case float64: 324 return convertUintFromInt(int64(math.Round(l))), nil 325 case int64: 326 return convertUintFromInt(int64(math.Round(float64(l)))), nil 327 case uint64: 328 return l, nil 329 } 330 } 331 switch l := lval.(type) { 332 case float64: 333 switch r := rval.(type) { 334 case float64: 335 left := convertUintFromInt(int64(math.Round(l))) 336 right := convertUintFromInt(int64(math.Round(r))) 337 return left << right, nil 338 } 339 case uint64: 340 switch r := rval.(type) { 341 case uint64: 342 return l << r, nil 343 } 344 case int64: 345 switch r := rval.(type) { 346 case int64: 347 left := convertUintFromInt(l) 348 right := convertUintFromInt(r) 349 return left << right, nil 350 } 351 } 352 353 return nil, errUnableToCast.New(lval, rval) 354 } 355 356 func shiftRight(lval, rval interface{}) (interface{}, error) { 357 if lval == nil { 358 return 0, nil 359 } 360 if rval == nil { 361 switch l := lval.(type) { 362 case float64: 363 return convertUintFromInt(int64(math.Round(l))), nil 364 case int64: 365 return convertUintFromInt(int64(math.Round(float64(l)))), nil 366 case uint64: 367 return l, nil 368 } 369 } 370 switch l := lval.(type) { 371 case float64: 372 switch r := rval.(type) { 373 case float64: 374 left := convertUintFromInt(int64(math.Round(l))) 375 right := convertUintFromInt(int64(math.Round(r))) 376 return left >> right, nil 377 } 378 case uint64: 379 switch r := rval.(type) { 380 case uint64: 381 return l >> r, nil 382 } 383 case int64: 384 switch r := rval.(type) { 385 case int64: 386 left := convertUintFromInt(l) 387 right := convertUintFromInt(r) 388 return left >> right, nil 389 } 390 } 391 392 return nil, errUnableToCast.New(lval, rval) 393 }