github.com/dolthub/go-mysql-server@v0.18.0/sql/types/decimal.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package types
    16  
import (
	"fmt"
	"math"
	"math/big"
	"reflect"

	"github.com/dolthub/vitess/go/sqltypes"
	"github.com/dolthub/vitess/go/vt/proto/query"
	"github.com/shopspring/decimal"
	"gopkg.in/src-d/go-errors.v1"

	"github.com/dolthub/go-mysql-server/sql"
)
    29  
    30  const (
    31  	// DecimalTypeMaxPrecision returns the maximum precision allowed for the Decimal type.
    32  	DecimalTypeMaxPrecision = 65
    33  	// DecimalTypeMaxScale returns the maximum scale allowed for the Decimal type, assuming the
    34  	// maximum precision is used. For a maximum scale that is relative to the precision of a given
    35  	// decimal type, use its MaximumScale function.
    36  	DecimalTypeMaxScale = 30
    37  )
    38  
    39  var (
    40  	ErrConvertingToDecimal   = errors.NewKind("value %v is not a valid Decimal")
    41  	ErrConvertToDecimalLimit = errors.NewKind("Out of range value for column of Decimal type ")
    42  	ErrMarshalNullDecimal    = errors.NewKind("Decimal cannot marshal a null value")
    43  
    44  	decimalValueType = reflect.TypeOf(decimal.Decimal{})
    45  )
    46  
// DecimalType_ is the sql.DecimalType implementation describing a DECIMAL(precision, scale) type.
// Instances are normally built through CreateDecimalType / CreateColumnDecimalType.
type DecimalType_ struct {
	// exclusiveUpperBound is the value the absolute value of a bounds-checked decimal must stay strictly
	// below. createDecimalType sets it to 10^(precision-scale); InternalDecimalType uses 10^65.
	exclusiveUpperBound decimal.Decimal
	// definesColumn reports whether this type describes a real table column. When true, decimal inputs
	// are rounded to scale in ConvertToNullDecimal and always formatted with scale fractional digits.
	definesColumn       bool
	// precision is the maximum total number of digits (at most DecimalTypeMaxPrecision).
	precision           uint8
	// scale is the number of digits kept after the decimal point (at most DecimalTypeMaxScale).
	scale               uint8
}
    53  
    54  // InternalDecimalType is a special DecimalType that is used internally for Decimal comparisons. Not intended for usage
    55  // from integrators.
    56  var InternalDecimalType sql.DecimalType = DecimalType_{
    57  	exclusiveUpperBound: decimal.New(1, int32(65)),
    58  	definesColumn:       false,
    59  	precision:           65,
    60  	scale:               30,
    61  }
    62  
    63  // CreateDecimalType creates a DecimalType for NON-TABLE-COLUMN.
    64  func CreateDecimalType(precision uint8, scale uint8) (sql.DecimalType, error) {
    65  	return createDecimalType(precision, scale, false)
    66  }
    67  
    68  // CreateColumnDecimalType creates a DecimalType for VALID-TABLE-COLUMN. Creating a decimal type for a column ensures that
    69  // when operating on instances of this type, the result will be restricted to the defined precision and scale.
    70  func CreateColumnDecimalType(precision uint8, scale uint8) (sql.DecimalType, error) {
    71  	return createDecimalType(precision, scale, true)
    72  }
    73  
    74  // createDecimalType creates a DecimalType using given precision, scale
    75  // and whether this type defines a valid table column.
    76  func createDecimalType(precision uint8, scale uint8, definesColumn bool) (sql.DecimalType, error) {
    77  	if scale > DecimalTypeMaxScale {
    78  		return nil, fmt.Errorf("Too big scale %v specified. Maximum is %v.", scale, DecimalTypeMaxScale)
    79  	}
    80  	if precision > DecimalTypeMaxPrecision {
    81  		return nil, fmt.Errorf("Too big precision %v specified. Maximum is %v.", precision, DecimalTypeMaxPrecision)
    82  	}
    83  	if scale > precision {
    84  		return nil, fmt.Errorf("Scale %v cannot be larger than the precision %v", scale, precision)
    85  	}
    86  
    87  	if precision == 0 {
    88  		precision = 10
    89  	}
    90  	return DecimalType_{
    91  		exclusiveUpperBound: decimal.New(1, int32(precision-scale)),
    92  		definesColumn:       definesColumn,
    93  		precision:           precision,
    94  		scale:               scale,
    95  	}, nil
    96  }
    97  
    98  // MustCreateDecimalType is the same as CreateDecimalType except it panics on errors and for NON-TABLE-COLUMN.
    99  func MustCreateDecimalType(precision uint8, scale uint8) sql.DecimalType {
   100  	dt, err := CreateDecimalType(precision, scale)
   101  	if err != nil {
   102  		panic(err)
   103  	}
   104  	return dt
   105  }
   106  
   107  // MustCreateColumnDecimalType is the same as CreateDecimalType except it panics on errors and for VALID-TABLE-COLUMN.
   108  func MustCreateColumnDecimalType(precision uint8, scale uint8) sql.DecimalType {
   109  	dt, err := CreateColumnDecimalType(precision, scale)
   110  	if err != nil {
   111  		panic(err)
   112  	}
   113  	return dt
   114  }
   115  
   116  // Type implements Type interface.
   117  func (t DecimalType_) Type() query.Type {
   118  	return sqltypes.Decimal
   119  }
   120  
   121  // Compare implements Type interface.
   122  func (t DecimalType_) Compare(a interface{}, b interface{}) (int, error) {
   123  	if hasNulls, res := CompareNulls(a, b); hasNulls {
   124  		return res, nil
   125  	}
   126  
   127  	af, err := t.ConvertToNullDecimal(a)
   128  	if err != nil {
   129  		return 0, err
   130  	}
   131  	bf, err := t.ConvertToNullDecimal(b)
   132  	if err != nil {
   133  		return 0, err
   134  	}
   135  
   136  	return af.Decimal.Cmp(bf.Decimal), nil
   137  }
   138  
   139  // Convert implements Type interface.
   140  func (t DecimalType_) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) {
   141  	dec, err := t.ConvertToNullDecimal(v)
   142  	if err != nil {
   143  		return nil, sql.OutOfRange, err
   144  	}
   145  	if !dec.Valid {
   146  		return nil, sql.InRange, nil
   147  	}
   148  	return t.BoundsCheck(dec.Decimal)
   149  }
   150  
   151  func (t DecimalType_) ConvertNoBoundsCheck(v interface{}) (decimal.Decimal, error) {
   152  	dec, err := t.ConvertToNullDecimal(v)
   153  	if err != nil {
   154  		return decimal.Decimal{}, err
   155  	}
   156  	if !dec.Valid {
   157  		return decimal.Decimal{}, nil
   158  	}
   159  	return dec.Decimal, nil
   160  }
   161  
   162  // ConvertToNullDecimal implements DecimalType interface.
   163  func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, error) {
   164  	if v == nil {
   165  		return decimal.NullDecimal{}, nil
   166  	}
   167  
   168  	var res decimal.Decimal
   169  
   170  	switch value := v.(type) {
   171  	case bool:
   172  		if value {
   173  			return t.ConvertToNullDecimal(decimal.NewFromInt(1))
   174  		} else {
   175  			return t.ConvertToNullDecimal(decimal.NewFromInt(0))
   176  		}
   177  	case int:
   178  		return t.ConvertToNullDecimal(int64(value))
   179  	case uint:
   180  		return t.ConvertToNullDecimal(uint64(value))
   181  	case int8:
   182  		return t.ConvertToNullDecimal(int64(value))
   183  	case uint8:
   184  		return t.ConvertToNullDecimal(uint64(value))
   185  	case int16:
   186  		return t.ConvertToNullDecimal(int64(value))
   187  	case uint16:
   188  		return t.ConvertToNullDecimal(uint64(value))
   189  	case int32:
   190  		return t.ConvertToNullDecimal(decimal.NewFromInt32(value))
   191  	case uint32:
   192  		return t.ConvertToNullDecimal(uint64(value))
   193  	case int64:
   194  		return t.ConvertToNullDecimal(decimal.NewFromInt(value))
   195  	case uint64:
   196  		return t.ConvertToNullDecimal(decimal.NewFromBigInt(new(big.Int).SetUint64(value), 0))
   197  	case float32:
   198  		return t.ConvertToNullDecimal(decimal.NewFromFloat32(value))
   199  	case float64:
   200  		return t.ConvertToNullDecimal(decimal.NewFromFloat(value))
   201  	case string:
   202  		// TODO: implement truncation here
   203  		if len(value) == 0 {
   204  			return t.ConvertToNullDecimal(decimal.NewFromInt(0))
   205  		}
   206  		var err error
   207  		res, err = decimal.NewFromString(value)
   208  		if err != nil {
   209  			// The decimal library cannot handle all of the different formats
   210  			bf, _, err := new(big.Float).SetPrec(217).Parse(value, 0)
   211  			if err != nil {
   212  				return decimal.NullDecimal{}, err
   213  			}
   214  			res, err = decimal.NewFromString(bf.Text('f', -1))
   215  			if err != nil {
   216  				return decimal.NullDecimal{}, err
   217  			}
   218  		}
   219  		return t.ConvertToNullDecimal(res)
   220  	case *big.Float:
   221  		return t.ConvertToNullDecimal(value.Text('f', -1))
   222  	case *big.Int:
   223  		return t.ConvertToNullDecimal(value.Text(10))
   224  	case *big.Rat:
   225  		return t.ConvertToNullDecimal(new(big.Float).SetRat(value))
   226  	case decimal.Decimal:
   227  		if t.definesColumn {
   228  			val, err := decimal.NewFromString(value.StringFixed(int32(t.scale)))
   229  			if err != nil {
   230  				return decimal.NullDecimal{}, err
   231  			}
   232  			res = val
   233  		} else {
   234  			res = value
   235  		}
   236  	case []uint8:
   237  		return t.ConvertToNullDecimal(string(value))
   238  	case decimal.NullDecimal:
   239  		// This is the equivalent of passing in a nil
   240  		if !value.Valid {
   241  			return decimal.NullDecimal{}, nil
   242  		}
   243  		return t.ConvertToNullDecimal(value.Decimal)
   244  	case JSONDocument:
   245  		return t.ConvertToNullDecimal(value.Val)
   246  	default:
   247  		return decimal.NullDecimal{}, ErrConvertingToDecimal.New(v)
   248  	}
   249  
   250  	return decimal.NullDecimal{Decimal: res, Valid: true}, nil
   251  }
   252  
   253  func (t DecimalType_) BoundsCheck(v decimal.Decimal) (decimal.Decimal, sql.ConvertInRange, error) {
   254  	if -v.Exponent() > int32(t.scale) {
   255  		// TODO : add 'Data truncated' warning
   256  		v = v.Round(int32(t.scale))
   257  	}
   258  	// TODO add shortcut for common case
   259  	// ex: certain num of bits fast tracks OK
   260  	if !v.Abs().LessThan(t.exclusiveUpperBound) {
   261  		return decimal.Decimal{}, sql.InRange, ErrConvertToDecimalLimit.New()
   262  	}
   263  	return v, sql.InRange, nil
   264  }
   265  
   266  // MustConvert implements the Type interface.
   267  func (t DecimalType_) MustConvert(v interface{}) interface{} {
   268  	value, _, err := t.Convert(v)
   269  	if err != nil {
   270  		panic(err)
   271  	}
   272  	return value
   273  }
   274  
   275  // Equals implements the Type interface.
   276  func (t DecimalType_) Equals(otherType sql.Type) bool {
   277  	if ot, ok := otherType.(DecimalType_); ok {
   278  		return t.precision == ot.precision && t.scale == ot.scale
   279  	}
   280  	return false
   281  }
   282  
   283  // MaxTextResponseByteLength implements the Type interface
   284  func (t DecimalType_) MaxTextResponseByteLength(_ *sql.Context) uint32 {
   285  	if t.scale == 0 {
   286  		// if no digits are reserved for the right-hand side of the decimal point,
   287  		// just return precision plus one byte for sign
   288  		return uint32(t.precision + 1)
   289  	} else {
   290  		// otherwise return precision plus one byte for sign plus one byte for the decimal point
   291  		return uint32(t.precision + 2)
   292  	}
   293  }
   294  
   295  // Promote implements the Type interface.
   296  func (t DecimalType_) Promote() sql.Type {
   297  	if t.definesColumn {
   298  		return MustCreateColumnDecimalType(DecimalTypeMaxPrecision, t.scale)
   299  	}
   300  	return MustCreateDecimalType(DecimalTypeMaxPrecision, t.scale)
   301  }
   302  
   303  // SQL implements Type interface.
   304  func (t DecimalType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) {
   305  	if v == nil {
   306  		return sqltypes.NULL, nil
   307  	}
   308  	value, err := t.ConvertToNullDecimal(v)
   309  	if err != nil {
   310  		return sqltypes.Value{}, err
   311  	}
   312  
   313  	val := AppendAndSliceString(dest, t.DecimalValueStringFixed(value.Decimal))
   314  
   315  	return sqltypes.MakeTrusted(sqltypes.Decimal, val), nil
   316  }
   317  
   318  // String implements Type interface.
   319  func (t DecimalType_) String() string {
   320  	return fmt.Sprintf("decimal(%v,%v)", t.precision, t.scale)
   321  }
   322  
   323  // ValueType implements Type interface.
   324  func (t DecimalType_) ValueType() reflect.Type {
   325  	return decimalValueType
   326  }
   327  
   328  // Zero implements Type interface.
   329  func (t DecimalType_) Zero() interface{} {
   330  	return decimal.NewFromInt(0)
   331  }
   332  
   333  // CollationCoercibility implements sql.CollationCoercible interface.
   334  func (DecimalType_) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   335  	return sql.Collation_binary, 5
   336  }
   337  
   338  // ExclusiveUpperBound implements DecimalType interface.
   339  func (t DecimalType_) ExclusiveUpperBound() decimal.Decimal {
   340  	return t.exclusiveUpperBound
   341  }
   342  
   343  // MaximumScale implements DecimalType interface.
   344  func (t DecimalType_) MaximumScale() uint8 {
   345  	if t.precision >= DecimalTypeMaxScale {
   346  		return DecimalTypeMaxScale
   347  	}
   348  	return t.precision
   349  }
   350  
   351  // Precision implements DecimalType interface.
   352  func (t DecimalType_) Precision() uint8 {
   353  	return t.precision
   354  }
   355  
   356  // Scale implements DecimalType interface.
   357  func (t DecimalType_) Scale() uint8 {
   358  	return t.scale
   359  }
   360  
   361  // DecimalValueStringFixed returns string value for the given decimal value. If decimal type value is for valid table column only,
   362  // it should use scale defined by the column. Otherwise, the result value should use its own precision and scale.
   363  func (t DecimalType_) DecimalValueStringFixed(v decimal.Decimal) string {
   364  	if t.definesColumn {
   365  		return v.StringFixed(int32(t.scale))
   366  	} else {
   367  		return v.StringFixed(v.Exponent() * -1)
   368  	}
   369  }