vitess.io/vitess@v0.16.2/go/vt/vtgate/evalengine/evalengine.go (about) 1 /* 2 Copyright 2020 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package evalengine 18 19 import ( 20 "math" 21 "time" 22 23 "vitess.io/vitess/go/mysql/collations" 24 "vitess.io/vitess/go/sqltypes" 25 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 26 "vitess.io/vitess/go/vt/sqlparser" 27 "vitess.io/vitess/go/vt/vterrors" 28 "vitess.io/vitess/go/vt/vtgate/evalengine/internal/decimal" 29 ) 30 31 // Cast converts a Value to the target type. 32 func Cast(v sqltypes.Value, typ sqltypes.Type) (sqltypes.Value, error) { 33 if v.Type() == typ || v.IsNull() { 34 return v, nil 35 } 36 vBytes, err := v.ToBytes() 37 if err != nil { 38 return v, err 39 } 40 if sqltypes.IsSigned(typ) && v.IsSigned() { 41 return sqltypes.MakeTrusted(typ, vBytes), nil 42 } 43 if sqltypes.IsUnsigned(typ) && v.IsUnsigned() { 44 return sqltypes.MakeTrusted(typ, vBytes), nil 45 } 46 if (sqltypes.IsFloat(typ) || typ == sqltypes.Decimal) && (v.IsIntegral() || v.IsFloat() || v.Type() == sqltypes.Decimal) { 47 return sqltypes.MakeTrusted(typ, vBytes), nil 48 } 49 if sqltypes.IsQuoted(typ) && (v.IsIntegral() || v.IsFloat() || v.Type() == sqltypes.Decimal || v.IsQuoted()) { 50 return sqltypes.MakeTrusted(typ, vBytes), nil 51 } 52 53 // Explicitly disallow Expression. 54 if v.Type() == sqltypes.Expression { 55 return sqltypes.NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v cannot be cast to %v", v, typ) 56 } 57 58 // If the above fast-paths were not possible, 59 // go through full validation. 60 return sqltypes.NewValue(typ, vBytes) 61 } 62 63 // ToUint64 converts Value to uint64. 64 func ToUint64(v sqltypes.Value) (uint64, error) { 65 var num EvalResult 66 if err := num.setValueIntegralNumeric(v); err != nil { 67 return 0, err 68 } 69 switch num.typeof() { 70 case sqltypes.Int64: 71 if num.uint64() > math.MaxInt64 { 72 return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "negative number cannot be converted to unsigned: %d", num.int64()) 73 } 74 return num.uint64(), nil 75 case sqltypes.Uint64: 76 return num.uint64(), nil 77 } 78 panic("unreachable") 79 } 80 81 // ToInt64 converts Value to int64. 82 func ToInt64(v sqltypes.Value) (int64, error) { 83 var num EvalResult 84 if err := num.setValueIntegralNumeric(v); err != nil { 85 return 0, err 86 } 87 switch num.typeof() { 88 case sqltypes.Int64: 89 return num.int64(), nil 90 case sqltypes.Uint64: 91 ival := num.int64() 92 if ival < 0 { 93 return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsigned number overflows int64 value: %d", num.uint64()) 94 } 95 return ival, nil 96 } 97 panic("unreachable") 98 } 99 100 // ToFloat64 converts Value to float64. 101 func ToFloat64(v sqltypes.Value) (float64, error) { 102 var num EvalResult 103 if err := num.setValue(v, collationNumeric); err != nil { 104 return 0, err 105 } 106 num.makeFloat() 107 return num.float64(), nil 108 } 109 110 func LiteralToValue(literal *sqlparser.Literal) (sqltypes.Value, error) { 111 lit, err := translateLiteral(literal, nil) 112 if err != nil { 113 return sqltypes.Value{}, err 114 } 115 return lit.Val.Value(), nil 116 } 117 118 // ToNative converts Value to a native go type. 119 // Decimal is returned as []byte. 120 func ToNative(v sqltypes.Value) (any, error) { 121 var out any 122 var err error 123 switch { 124 case v.Type() == sqltypes.Null: 125 // no-op 126 case v.IsSigned(): 127 return ToInt64(v) 128 case v.IsUnsigned(): 129 return ToUint64(v) 130 case v.IsFloat(): 131 return ToFloat64(v) 132 case v.IsQuoted() || v.Type() == sqltypes.Bit || v.Type() == sqltypes.Decimal: 133 out, err = v.ToBytes() 134 case v.Type() == sqltypes.Expression: 135 err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v cannot be converted to a go type", v) 136 } 137 return out, err 138 } 139 140 func compareNumeric(v1, v2 *EvalResult) (int, error) { 141 // upcast all <64 bit numeric types to 64 bit, e.g. int8 -> int64, uint8 -> uint64, float32 -> float64 142 // so we don't have to consider integer types which aren't 64 bit 143 v1.upcastNumeric() 144 v2.upcastNumeric() 145 146 // Equalize the types the same way MySQL does 147 // https://dev.mysql.com/doc/refman/8.0/en/type-conversion.html 148 switch v1.typeof() { 149 case sqltypes.Int64: 150 switch v2.typeof() { 151 case sqltypes.Uint64: 152 if v1.uint64() > math.MaxInt64 { 153 return -1, nil 154 } 155 v1.setUint64(v1.uint64()) 156 case sqltypes.Float64: 157 v1.setFloat(float64(v1.int64())) 158 case sqltypes.Decimal: 159 v1.setDecimal(decimal.NewFromInt(v1.int64()), 0) 160 } 161 case sqltypes.Uint64: 162 switch v2.typeof() { 163 case sqltypes.Int64: 164 if v2.uint64() > math.MaxInt64 { 165 return 1, nil 166 } 167 v2.setUint64(v2.uint64()) 168 case sqltypes.Float64: 169 v1.setFloat(float64(v1.uint64())) 170 case sqltypes.Decimal: 171 v1.setDecimal(decimal.NewFromUint(v1.uint64()), 0) 172 } 173 case sqltypes.Float64: 174 switch v2.typeof() { 175 case sqltypes.Int64: 176 v2.setFloat(float64(v2.int64())) 177 case sqltypes.Uint64: 178 if v1.float64() < 0 { 179 return -1, nil 180 } 181 v2.setFloat(float64(v2.uint64())) 182 case sqltypes.Decimal: 183 f, ok := v2.decimal().Float64() 184 if !ok { 185 return 0, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "DECIMAL value is out of range") 186 } 187 v2.setFloat(f) 188 } 189 case sqltypes.Decimal: 190 switch v2.typeof() { 191 case sqltypes.Int64: 192 v2.setDecimal(decimal.NewFromInt(v2.int64()), 0) 193 case sqltypes.Uint64: 194 v2.setDecimal(decimal.NewFromUint(v2.uint64()), 0) 195 case sqltypes.Float64: 196 f, ok := v1.decimal().Float64() 197 if !ok { 198 return 0, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.DataOutOfRange, "DECIMAL value is out of range") 199 } 200 v1.setFloat(f) 201 } 202 } 203 204 // Both values are of the same type. 205 switch v1.typeof() { 206 case sqltypes.Int64: 207 v1v, v2v := v1.int64(), v2.int64() 208 switch { 209 case v1v == v2v: 210 return 0, nil 211 case v1v < v2v: 212 return -1, nil 213 } 214 case sqltypes.Uint64: 215 switch { 216 case v1.uint64() == v2.uint64(): 217 return 0, nil 218 case v1.uint64() < v2.uint64(): 219 return -1, nil 220 } 221 case sqltypes.Float64: 222 v1v, v2v := v1.float64(), v2.float64() 223 switch { 224 case v1v == v2v: 225 return 0, nil 226 case v1v < v2v: 227 return -1, nil 228 } 229 case sqltypes.Decimal: 230 return v1.decimal().Cmp(v2.decimal()), nil 231 } 232 return 1, nil 233 } 234 235 func parseDate(expr *EvalResult) (t time.Time, err error) { 236 switch expr.typeof() { 237 case sqltypes.Date: 238 t, err = sqlparser.ParseDate(expr.string()) 239 case sqltypes.Timestamp, sqltypes.Datetime: 240 t, err = sqlparser.ParseDateTime(expr.string()) 241 case sqltypes.Time: 242 t, err = sqlparser.ParseTime(expr.string()) 243 } 244 return 245 } 246 247 // matchExprWithAnyDateFormat formats the given expr (usually a string) to a date using the first format 248 // that does not return an error. 249 func matchExprWithAnyDateFormat(expr *EvalResult) (t time.Time, err error) { 250 t, err = sqlparser.ParseDate(expr.string()) 251 if err == nil { 252 return 253 } 254 t, err = sqlparser.ParseDateTime(expr.string()) 255 if err == nil { 256 return 257 } 258 t, err = sqlparser.ParseTime(expr.string()) 259 return 260 } 261 262 // Date comparison based on: 263 // - https://dev.mysql.com/doc/refman/8.0/en/type-conversion.html 264 // - https://dev.mysql.com/doc/refman/8.0/en/date-and-time-type-conversion.html 265 func compareDates(l, r *EvalResult) (int, error) { 266 lTime, err := parseDate(l) 267 if err != nil { 268 return 0, err 269 } 270 rTime, err := parseDate(r) 271 if err != nil { 272 return 0, err 273 } 274 275 return compareGoTimes(lTime, rTime) 276 } 277 278 func compareDateAndString(l, r *EvalResult) (int, error) { 279 var lTime, rTime time.Time 280 var err error 281 switch { 282 case sqltypes.IsDate(l.typeof()): 283 lTime, err = parseDate(l) 284 if err != nil { 285 return 0, err 286 } 287 rTime, err = matchExprWithAnyDateFormat(r) 288 if err != nil { 289 return 0, err 290 } 291 case l.isTextual(): 292 rTime, err = parseDate(r) 293 if err != nil { 294 return 0, err 295 } 296 lTime, err = matchExprWithAnyDateFormat(l) 297 if err != nil { 298 return 0, err 299 } 300 } 301 return compareGoTimes(lTime, rTime) 302 } 303 304 func compareGoTimes(lTime, rTime time.Time) (int, error) { 305 if lTime.Before(rTime) { 306 return -1, nil 307 } 308 if lTime.After(rTime) { 309 return 1, nil 310 } 311 return 0, nil 312 } 313 314 // More on string collations coercibility on MySQL documentation: 315 // - https://dev.mysql.com/doc/refman/8.0/en/charset-collation-coercibility.html 316 func compareStrings(l, r *EvalResult) int { 317 coll, err := mergeCollations(l, r) 318 if err != nil { 319 throwEvalError(err) 320 } 321 collation := collations.Local().LookupByID(coll) 322 if collation == nil { 323 panic("unknown collation after coercion") 324 } 325 return collation.Collate(l.bytes(), r.bytes(), false) 326 }