github.com/dolthub/go-mysql-server@v0.18.0/sql/types/decimal.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package types 16 17 import ( 18 "fmt" 19 "math/big" 20 "reflect" 21 22 "github.com/dolthub/vitess/go/sqltypes" 23 "github.com/dolthub/vitess/go/vt/proto/query" 24 "github.com/shopspring/decimal" 25 "gopkg.in/src-d/go-errors.v1" 26 27 "github.com/dolthub/go-mysql-server/sql" 28 ) 29 30 const ( 31 // DecimalTypeMaxPrecision returns the maximum precision allowed for the Decimal type. 32 DecimalTypeMaxPrecision = 65 33 // DecimalTypeMaxScale returns the maximum scale allowed for the Decimal type, assuming the 34 // maximum precision is used. For a maximum scale that is relative to the precision of a given 35 // decimal type, use its MaximumScale function. 36 DecimalTypeMaxScale = 30 37 ) 38 39 var ( 40 ErrConvertingToDecimal = errors.NewKind("value %v is not a valid Decimal") 41 ErrConvertToDecimalLimit = errors.NewKind("Out of range value for column of Decimal type ") 42 ErrMarshalNullDecimal = errors.NewKind("Decimal cannot marshal a null value") 43 44 decimalValueType = reflect.TypeOf(decimal.Decimal{}) 45 ) 46 47 type DecimalType_ struct { 48 exclusiveUpperBound decimal.Decimal 49 definesColumn bool 50 precision uint8 51 scale uint8 52 } 53 54 // InternalDecimalType is a special DecimalType that is used internally for Decimal comparisons. Not intended for usage 55 // from integrators. 56 var InternalDecimalType sql.DecimalType = DecimalType_{ 57 exclusiveUpperBound: decimal.New(1, int32(65)), 58 definesColumn: false, 59 precision: 65, 60 scale: 30, 61 } 62 63 // CreateDecimalType creates a DecimalType for NON-TABLE-COLUMN. 64 func CreateDecimalType(precision uint8, scale uint8) (sql.DecimalType, error) { 65 return createDecimalType(precision, scale, false) 66 } 67 68 // CreateColumnDecimalType creates a DecimalType for VALID-TABLE-COLUMN. Creating a decimal type for a column ensures that 69 // when operating on instances of this type, the result will be restricted to the defined precision and scale. 70 func CreateColumnDecimalType(precision uint8, scale uint8) (sql.DecimalType, error) { 71 return createDecimalType(precision, scale, true) 72 } 73 74 // createDecimalType creates a DecimalType using given precision, scale 75 // and whether this type defines a valid table column. 76 func createDecimalType(precision uint8, scale uint8, definesColumn bool) (sql.DecimalType, error) { 77 if scale > DecimalTypeMaxScale { 78 return nil, fmt.Errorf("Too big scale %v specified. Maximum is %v.", scale, DecimalTypeMaxScale) 79 } 80 if precision > DecimalTypeMaxPrecision { 81 return nil, fmt.Errorf("Too big precision %v specified. Maximum is %v.", precision, DecimalTypeMaxPrecision) 82 } 83 if scale > precision { 84 return nil, fmt.Errorf("Scale %v cannot be larger than the precision %v", scale, precision) 85 } 86 87 if precision == 0 { 88 precision = 10 89 } 90 return DecimalType_{ 91 exclusiveUpperBound: decimal.New(1, int32(precision-scale)), 92 definesColumn: definesColumn, 93 precision: precision, 94 scale: scale, 95 }, nil 96 } 97 98 // MustCreateDecimalType is the same as CreateDecimalType except it panics on errors and for NON-TABLE-COLUMN. 99 func MustCreateDecimalType(precision uint8, scale uint8) sql.DecimalType { 100 dt, err := CreateDecimalType(precision, scale) 101 if err != nil { 102 panic(err) 103 } 104 return dt 105 } 106 107 // MustCreateColumnDecimalType is the same as CreateDecimalType except it panics on errors and for VALID-TABLE-COLUMN. 108 func MustCreateColumnDecimalType(precision uint8, scale uint8) sql.DecimalType { 109 dt, err := CreateColumnDecimalType(precision, scale) 110 if err != nil { 111 panic(err) 112 } 113 return dt 114 } 115 116 // Type implements Type interface. 117 func (t DecimalType_) Type() query.Type { 118 return sqltypes.Decimal 119 } 120 121 // Compare implements Type interface. 122 func (t DecimalType_) Compare(a interface{}, b interface{}) (int, error) { 123 if hasNulls, res := CompareNulls(a, b); hasNulls { 124 return res, nil 125 } 126 127 af, err := t.ConvertToNullDecimal(a) 128 if err != nil { 129 return 0, err 130 } 131 bf, err := t.ConvertToNullDecimal(b) 132 if err != nil { 133 return 0, err 134 } 135 136 return af.Decimal.Cmp(bf.Decimal), nil 137 } 138 139 // Convert implements Type interface. 140 func (t DecimalType_) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) { 141 dec, err := t.ConvertToNullDecimal(v) 142 if err != nil { 143 return nil, sql.OutOfRange, err 144 } 145 if !dec.Valid { 146 return nil, sql.InRange, nil 147 } 148 return t.BoundsCheck(dec.Decimal) 149 } 150 151 func (t DecimalType_) ConvertNoBoundsCheck(v interface{}) (decimal.Decimal, error) { 152 dec, err := t.ConvertToNullDecimal(v) 153 if err != nil { 154 return decimal.Decimal{}, err 155 } 156 if !dec.Valid { 157 return decimal.Decimal{}, nil 158 } 159 return dec.Decimal, nil 160 } 161 162 // ConvertToNullDecimal implements DecimalType interface. 163 func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, error) { 164 if v == nil { 165 return decimal.NullDecimal{}, nil 166 } 167 168 var res decimal.Decimal 169 170 switch value := v.(type) { 171 case bool: 172 if value { 173 return t.ConvertToNullDecimal(decimal.NewFromInt(1)) 174 } else { 175 return t.ConvertToNullDecimal(decimal.NewFromInt(0)) 176 } 177 case int: 178 return t.ConvertToNullDecimal(int64(value)) 179 case uint: 180 return t.ConvertToNullDecimal(uint64(value)) 181 case int8: 182 return t.ConvertToNullDecimal(int64(value)) 183 case uint8: 184 return t.ConvertToNullDecimal(uint64(value)) 185 case int16: 186 return t.ConvertToNullDecimal(int64(value)) 187 case uint16: 188 return t.ConvertToNullDecimal(uint64(value)) 189 case int32: 190 return t.ConvertToNullDecimal(decimal.NewFromInt32(value)) 191 case uint32: 192 return t.ConvertToNullDecimal(uint64(value)) 193 case int64: 194 return t.ConvertToNullDecimal(decimal.NewFromInt(value)) 195 case uint64: 196 return t.ConvertToNullDecimal(decimal.NewFromBigInt(new(big.Int).SetUint64(value), 0)) 197 case float32: 198 return t.ConvertToNullDecimal(decimal.NewFromFloat32(value)) 199 case float64: 200 return t.ConvertToNullDecimal(decimal.NewFromFloat(value)) 201 case string: 202 // TODO: implement truncation here 203 if len(value) == 0 { 204 return t.ConvertToNullDecimal(decimal.NewFromInt(0)) 205 } 206 var err error 207 res, err = decimal.NewFromString(value) 208 if err != nil { 209 // The decimal library cannot handle all of the different formats 210 bf, _, err := new(big.Float).SetPrec(217).Parse(value, 0) 211 if err != nil { 212 return decimal.NullDecimal{}, err 213 } 214 res, err = decimal.NewFromString(bf.Text('f', -1)) 215 if err != nil { 216 return decimal.NullDecimal{}, err 217 } 218 } 219 return t.ConvertToNullDecimal(res) 220 case *big.Float: 221 return t.ConvertToNullDecimal(value.Text('f', -1)) 222 case *big.Int: 223 return t.ConvertToNullDecimal(value.Text(10)) 224 case *big.Rat: 225 return t.ConvertToNullDecimal(new(big.Float).SetRat(value)) 226 case decimal.Decimal: 227 if t.definesColumn { 228 val, err := decimal.NewFromString(value.StringFixed(int32(t.scale))) 229 if err != nil { 230 return decimal.NullDecimal{}, err 231 } 232 res = val 233 } else { 234 res = value 235 } 236 case []uint8: 237 return t.ConvertToNullDecimal(string(value)) 238 case decimal.NullDecimal: 239 // This is the equivalent of passing in a nil 240 if !value.Valid { 241 return decimal.NullDecimal{}, nil 242 } 243 return t.ConvertToNullDecimal(value.Decimal) 244 case JSONDocument: 245 return t.ConvertToNullDecimal(value.Val) 246 default: 247 return decimal.NullDecimal{}, ErrConvertingToDecimal.New(v) 248 } 249 250 return decimal.NullDecimal{Decimal: res, Valid: true}, nil 251 } 252 253 func (t DecimalType_) BoundsCheck(v decimal.Decimal) (decimal.Decimal, sql.ConvertInRange, error) { 254 if -v.Exponent() > int32(t.scale) { 255 // TODO : add 'Data truncated' warning 256 v = v.Round(int32(t.scale)) 257 } 258 // TODO add shortcut for common case 259 // ex: certain num of bits fast tracks OK 260 if !v.Abs().LessThan(t.exclusiveUpperBound) { 261 return decimal.Decimal{}, sql.InRange, ErrConvertToDecimalLimit.New() 262 } 263 return v, sql.InRange, nil 264 } 265 266 // MustConvert implements the Type interface. 267 func (t DecimalType_) MustConvert(v interface{}) interface{} { 268 value, _, err := t.Convert(v) 269 if err != nil { 270 panic(err) 271 } 272 return value 273 } 274 275 // Equals implements the Type interface. 276 func (t DecimalType_) Equals(otherType sql.Type) bool { 277 if ot, ok := otherType.(DecimalType_); ok { 278 return t.precision == ot.precision && t.scale == ot.scale 279 } 280 return false 281 } 282 283 // MaxTextResponseByteLength implements the Type interface 284 func (t DecimalType_) MaxTextResponseByteLength(_ *sql.Context) uint32 { 285 if t.scale == 0 { 286 // if no digits are reserved for the right-hand side of the decimal point, 287 // just return precision plus one byte for sign 288 return uint32(t.precision + 1) 289 } else { 290 // otherwise return precision plus one byte for sign plus one byte for the decimal point 291 return uint32(t.precision + 2) 292 } 293 } 294 295 // Promote implements the Type interface. 296 func (t DecimalType_) Promote() sql.Type { 297 if t.definesColumn { 298 return MustCreateColumnDecimalType(DecimalTypeMaxPrecision, t.scale) 299 } 300 return MustCreateDecimalType(DecimalTypeMaxPrecision, t.scale) 301 } 302 303 // SQL implements Type interface. 304 func (t DecimalType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) { 305 if v == nil { 306 return sqltypes.NULL, nil 307 } 308 value, err := t.ConvertToNullDecimal(v) 309 if err != nil { 310 return sqltypes.Value{}, err 311 } 312 313 val := AppendAndSliceString(dest, t.DecimalValueStringFixed(value.Decimal)) 314 315 return sqltypes.MakeTrusted(sqltypes.Decimal, val), nil 316 } 317 318 // String implements Type interface. 319 func (t DecimalType_) String() string { 320 return fmt.Sprintf("decimal(%v,%v)", t.precision, t.scale) 321 } 322 323 // ValueType implements Type interface. 324 func (t DecimalType_) ValueType() reflect.Type { 325 return decimalValueType 326 } 327 328 // Zero implements Type interface. 329 func (t DecimalType_) Zero() interface{} { 330 return decimal.NewFromInt(0) 331 } 332 333 // CollationCoercibility implements sql.CollationCoercible interface. 334 func (DecimalType_) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 335 return sql.Collation_binary, 5 336 } 337 338 // ExclusiveUpperBound implements DecimalType interface. 339 func (t DecimalType_) ExclusiveUpperBound() decimal.Decimal { 340 return t.exclusiveUpperBound 341 } 342 343 // MaximumScale implements DecimalType interface. 344 func (t DecimalType_) MaximumScale() uint8 { 345 if t.precision >= DecimalTypeMaxScale { 346 return DecimalTypeMaxScale 347 } 348 return t.precision 349 } 350 351 // Precision implements DecimalType interface. 352 func (t DecimalType_) Precision() uint8 { 353 return t.precision 354 } 355 356 // Scale implements DecimalType interface. 357 func (t DecimalType_) Scale() uint8 { 358 return t.scale 359 } 360 361 // DecimalValueStringFixed returns string value for the given decimal value. If decimal type value is for valid table column only, 362 // it should use scale defined by the column. Otherwise, the result value should use its own precision and scale. 363 func (t DecimalType_) DecimalValueStringFixed(v decimal.Decimal) string { 364 if t.definesColumn { 365 return v.StringFixed(int32(t.scale)) 366 } else { 367 return v.StringFixed(v.Exponent() * -1) 368 } 369 }