github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/types/decimal.go (about)

     1  package types
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"database/sql/driver"
     7  	"errors"
     8  	"fmt"
     9  
    10  	"github.com/ericlagergren/decimal"
    11  )
    12  
    13  var (
    14  	// DecimalContext is a global context that will be used when creating
    15  	// decimals. It should be set once before any sqlboiler and then
    16  	// assumed to be read-only after sqlboiler's first use.
    17  	DecimalContext decimal.Context
    18  
    19  	nullBytes = []byte("null")
    20  )
    21  
    22  var (
    23  	_ driver.Valuer = Decimal{}
    24  	_ driver.Valuer = NullDecimal{}
    25  	_ sql.Scanner   = &Decimal{}
    26  	_ sql.Scanner   = &NullDecimal{}
    27  )
    28  
// Decimal is a DECIMAL in sql, wrapping *decimal.Big. Its zero value (a nil
// Big) is valid for use with both Value and Scan.
//
// Although decimal can represent NaN and Infinity, Value returns an error
// if an attempt is made to store either of these in the database.
//
// Because Decimal is not nullable, Value returns the string "0" when Big is
// nil, and Scan returns an error when asked to scan a SQL NULL into it.
// Use NullDecimal for nullable columns.
type Decimal struct {
	*decimal.Big
}
    40  
// NullDecimal is the same as Decimal, but allows the Big pointer to be nil.
// See the documentation for Decimal for NaN/Infinity handling.
//
// A nil Big represents SQL NULL: when going into a database its value is
// nil (NULL), scanning a NULL sets Big to nil, and JSON encoding produces
// null.
type NullDecimal struct {
	*decimal.Big
}
    48  
    49  // NewDecimal creates a new decimal from a decimal
    50  func NewDecimal(d *decimal.Big) Decimal {
    51  	return Decimal{Big: d}
    52  }
    53  
    54  // NewNullDecimal creates a new null decimal from a decimal
    55  func NewNullDecimal(d *decimal.Big) NullDecimal {
    56  	return NullDecimal{Big: d}
    57  }
    58  
    59  // Value implements driver.Valuer.
    60  func (d Decimal) Value() (driver.Value, error) {
    61  	return decimalValue(d.Big, false)
    62  }
    63  
    64  // Scan implements sql.Scanner.
    65  func (d *Decimal) Scan(val interface{}) error {
    66  	newD, err := decimalScan(d.Big, val, false)
    67  	if err != nil {
    68  		return err
    69  	}
    70  
    71  	d.Big = newD
    72  	return nil
    73  }
    74  
    75  // UnmarshalJSON allows marshalling JSON into a null pointer
    76  func (d *Decimal) UnmarshalJSON(data []byte) error {
    77  	if d.Big == nil {
    78  		d.Big = new(decimal.Big)
    79  	}
    80  
    81  	return d.Big.UnmarshalJSON(data)
    82  }
    83  
    84  // Randomize implements sqlboiler's randomize interface
    85  func (d *Decimal) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) {
    86  	d.Big = randomDecimal(nextInt, fieldType, false)
    87  }
    88  
    89  // Value implements driver.Valuer.
    90  func (n NullDecimal) Value() (driver.Value, error) {
    91  	return decimalValue(n.Big, true)
    92  }
    93  
    94  // Scan implements sql.Scanner.
    95  func (n *NullDecimal) Scan(val interface{}) error {
    96  	newD, err := decimalScan(n.Big, val, true)
    97  	if err != nil {
    98  		return err
    99  	}
   100  
   101  	n.Big = newD
   102  	return nil
   103  }
   104  
   105  // UnmarshalJSON allows marshalling JSON into a null pointer
   106  func (n *NullDecimal) UnmarshalJSON(data []byte) error {
   107  	if bytes.Equal(data, nullBytes) {
   108  		if n != nil {
   109  			n.Big = nil
   110  		}
   111  		return nil
   112  	}
   113  
   114  	if n.Big == nil {
   115  		n.Big = decimal.WithContext(DecimalContext)
   116  	}
   117  
   118  	return n.Big.UnmarshalJSON(data)
   119  }
   120  
   121  // String impl
   122  func (n NullDecimal) String() string {
   123  	if n.Big == nil {
   124  		return "nil"
   125  	}
   126  	return n.Big.String()
   127  }
   128  
   129  func (n NullDecimal) Format(f fmt.State, verb rune) {
   130  	if n.Big == nil {
   131  		fmt.Fprint(f, "nil")
   132  		return
   133  	}
   134  	n.Big.Format(f, verb)
   135  }
   136  
   137  // MarshalJSON marshals a decimal value
   138  func (n NullDecimal) MarshalJSON() ([]byte, error) {
   139  	if n.Big == nil {
   140  		return nullBytes, nil
   141  	}
   142  
   143  	return n.Big.MarshalText()
   144  }
   145  
   146  // IsZero implements qmhelper.Nullable
   147  func (n NullDecimal) IsZero() bool {
   148  	return n.Big == nil
   149  }
   150  
   151  // Randomize implements sqlboiler's randomize interface
   152  func (n *NullDecimal) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) {
   153  	n.Big = randomDecimal(nextInt, fieldType, shouldBeNull)
   154  }
   155  
   156  func randomDecimal(nextInt func() int64, fieldType string, shouldBeNull bool) *decimal.Big {
   157  	if shouldBeNull {
   158  		return nil
   159  	}
   160  
   161  	randVal := fmt.Sprintf("%d.%d", nextInt()%10, nextInt()%10)
   162  	random, success := decimal.WithContext(DecimalContext).SetString(randVal)
   163  	if !success {
   164  		panic("randVal could not be turned into a decimal")
   165  	}
   166  
   167  	return random
   168  }
   169  
   170  func decimalValue(d *decimal.Big, canNull bool) (driver.Value, error) {
   171  	if d == nil {
   172  		if canNull {
   173  			return nil, nil
   174  		}
   175  
   176  		return "0", nil
   177  	}
   178  
   179  	if d.IsNaN(0) {
   180  		return nil, errors.New("refusing to allow NaN into database")
   181  	}
   182  	if d.IsInf(0) {
   183  		return nil, errors.New("refusing to allow infinity into database")
   184  	}
   185  
   186  	return d.String(), nil
   187  }
   188  
   189  func decimalScan(d *decimal.Big, val interface{}, canNull bool) (*decimal.Big, error) {
   190  	if val == nil {
   191  		if !canNull {
   192  			return nil, errors.New("null cannot be scanned into decimal")
   193  		}
   194  
   195  		return nil, nil
   196  	}
   197  
   198  	switch t := val.(type) {
   199  	case float64:
   200  		if d == nil {
   201  			d = decimal.WithContext(DecimalContext)
   202  		}
   203  		d.SetFloat64(t)
   204  		return d, nil
   205  	case int64:
   206  		return decimal.WithContext(DecimalContext).SetMantScale(t, 0), nil
   207  	case string:
   208  		if d == nil {
   209  			d = decimal.WithContext(DecimalContext)
   210  		}
   211  		if _, ok := d.SetString(t); !ok {
   212  			if err := d.Context.Err(); err != nil {
   213  				return nil, err
   214  			}
   215  			return nil, fmt.Errorf("invalid decimal syntax: %q", t)
   216  		}
   217  		return d, nil
   218  	case []byte:
   219  		if d == nil {
   220  			d = decimal.WithContext(DecimalContext)
   221  		}
   222  		if err := d.UnmarshalText(t); err != nil {
   223  			return nil, err
   224  		}
   225  		return d, nil
   226  	default:
   227  		return nil, fmt.Errorf("cannot scan decimal value: %#v", val)
   228  	}
   229  }