amuz.es/src/go/misc@v1.0.1/types/big.go (about)

     1  package types
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"github.com/ericlagergren/decimal"
     6  	"github.com/pkg/errors"
     7  )
     8  
     9  // Decimal is a DECIMAL in sql. Its zero value is valid for use with both
    10  // Value and Scan.
    11  //
    12  // Although decimal can represent NaN and Infinity it will return an error
    13  // if an attempt to store these values in the database is made.
    14  //
    15  // Because it cannot be nil, when Big is nil Value() will return "0"
    16  // It will error if an attempt to Scan() a "null" value into it.
    17  type Decimal struct {
    18  	decimal.Big
    19  }
    20  
    21  // NullDecimal is the same as Decimal, but allows the Big pointer to be nil.
    22  // See docmentation for Decimal for more details.
    23  //
    24  // When going into a database, if Big is nil it's value will be "null".
    25  type NullDecimal struct {
    26  	*decimal.Big
    27  }
    28  
    29  // NewDecimal creates a new decimal from a decimal
    30  func NewDecimal(d *decimal.Big) (num Decimal) {
    31  	if d != nil {
    32  		num.Big.Copy(d)
    33  	}
    34  	return
    35  }
    36  
    37  // NewNullDecimal creates a new null decimal from a decimal
    38  func NewNullDecimal(d *decimal.Big) NullDecimal {
    39  	return NullDecimal{Big: d}
    40  }
    41  
    42  // Value implements driver.Valuer.
    43  func (d Decimal) Value() (driver.Value, error) {
    44  	return decimalValue(&d.Big, false)
    45  }
    46  
    47  // Scan implements sql.Scanner.
    48  func (d *Decimal) Scan(val any) (err error) {
    49  	_, err = decimalScan(&d.Big, val, false)
    50  	return
    51  }
    52  
    53  // Value implements driver.Valuer.
    54  func (n NullDecimal) Value() (driver.Value, error) {
    55  	return decimalValue(n.Big, true)
    56  }
    57  
    58  // Scan implements sql.Scanner.
    59  func (n *NullDecimal) Scan(val any) error {
    60  	newD, err := decimalScan(n.Big, val, true)
    61  	if err != nil {
    62  		return err
    63  	}
    64  
    65  	n.Big = newD
    66  	return nil
    67  }
    68  
    69  func decimalValue(d *decimal.Big, canNull bool) (driver.Value, error) {
    70  	if canNull && d == nil {
    71  		return nil, nil
    72  	}
    73  
    74  	if d.IsNaN(0) {
    75  		return nil, errors.New("refusing to allow NaN into database")
    76  	}
    77  	if d.IsInf(0) {
    78  		return nil, errors.New("refusing to allow infinity into database")
    79  	}
    80  
    81  	return d.String(), nil
    82  }
    83  
    84  func decimalScan(d *decimal.Big, val any, canNull bool) (*decimal.Big, error) {
    85  	if val == nil {
    86  		if !canNull {
    87  			return nil, errors.New("null cannot be scanned into decimal")
    88  		}
    89  
    90  		return nil, nil
    91  	}
    92  
    93  	if d == nil {
    94  		d = new(decimal.Big)
    95  	}
    96  
    97  	switch t := val.(type) {
    98  	case float64:
    99  		d.SetFloat64(t)
   100  		return d, nil
   101  	case string:
   102  		if _, ok := d.SetString(t); !ok {
   103  			if err := d.Context.Err(); err != nil {
   104  				return nil, err
   105  			}
   106  			return nil, errors.Errorf("invalid decimal syntax: %q", t)
   107  		}
   108  		return d, nil
   109  	case []byte:
   110  		if err := d.UnmarshalText(t); err != nil {
   111  			return nil, err
   112  		}
   113  		return d, nil
   114  	default:
   115  		return nil, errors.Errorf("cannot scan decimal value: %#v", val)
   116  	}
   117  }