code.vegaprotocol.io/vega@v0.79.0/libs/num/uint.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package num
    17  
    18  import (
    19  	"database/sql/driver"
    20  	"errors"
    21  	"fmt"
    22  	"math/big"
    23  	"sort"
    24  
    25  	"github.com/holiman/uint256"
    26  )
    27  
    28  var (
    29  	// max uint256 value.
    30  	big1    = big.NewInt(1)
    31  	maxU256 = new(big.Int).Sub(new(big.Int).Lsh(big1, 256), big1)
    32  
    33  	// initialise max variable.
    34  	maxUint = setMaxUint()
    35  	zero    = NewUint(0)
    36  	one     = NewUint(1)
    37  )
    38  
    39  // Uint A wrapper for a big unsigned int.
    40  type Uint struct {
    41  	u uint256.Int
    42  }
    43  
    44  // NewUint creates a new Uint with the value of the
    45  // uint64 passed as a parameter.
    46  func NewUint(val uint64) *Uint {
    47  	return &Uint{*uint256.NewInt(val)}
    48  }
    49  
    50  func UintOne() *Uint {
    51  	return one.Clone()
    52  }
    53  
    54  func UintZero() *Uint {
    55  	return zero.Clone()
    56  }
    57  
    58  // only called once, to initialise maxUint.
    59  func setMaxUint() *Uint {
    60  	b, _ := UintFromBig(maxU256)
    61  	return b
    62  }
    63  
    64  // MaxUint returns max value for uint256.
    65  func MaxUint() *Uint {
    66  	return maxUint.Clone()
    67  }
    68  
    69  // Min returns the smallest of the 2 numbers.
    70  func Min(a, b *Uint) *Uint {
    71  	if a.LT(b) {
    72  		return a.Clone()
    73  	}
    74  	return b.Clone()
    75  }
    76  
    77  // Max returns the largest of the 2 numbers.
    78  func Max(a, b *Uint) *Uint {
    79  	if a.GT(b) {
    80  		return a.Clone()
    81  	}
    82  	return b.Clone()
    83  }
    84  
    85  // UintFromHex instantiate a uint from and hex string.
    86  func UintFromHex(hex string) (*Uint, error) {
    87  	u, err := uint256.FromHex(hex)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	return &Uint{*u}, nil
    92  }
    93  
    94  // UintFromBig construct a new Uint with a big.Int
    95  // returns true if overflow happened.
    96  func UintFromBig(b *big.Int) (*Uint, bool) {
    97  	u, ok := uint256.FromBig(b)
    98  	// ok means an overflow happened
    99  	if ok {
   100  		return NewUint(0), true
   101  	}
   102  	return &Uint{*u}, false
   103  }
   104  
   105  // UintFromBig construct a new Uint with a big.Int
   106  // panics if overflow happened.
   107  func MustUintFromBig(b *big.Int) *Uint {
   108  	u, ok := uint256.FromBig(b)
   109  	// ok means an overflow happened
   110  	if ok {
   111  		panic("uint underflow")
   112  	}
   113  	return &Uint{*u}
   114  }
   115  
   116  // UintFromBytes allows for the conversion from Uint.Bytes() back to a Uint.
   117  func UintFromBytes(b []byte) *Uint {
   118  	u := &Uint{
   119  		u: uint256.Int{},
   120  	}
   121  	u.u.SetBytes(b)
   122  	return u
   123  }
   124  
   125  // UintFromDecimal returns a decimal version of the Uint, setting the bool to true if overflow occurred.
   126  func UintFromDecimal(d Decimal) (*Uint, bool) {
   127  	u, ok := d.Uint()
   128  	return &Uint{*u}, ok
   129  }
   130  
   131  func UintFromDecimalWithFraction(d Decimal) (*Uint, Decimal) {
   132  	u, ok := UintFromDecimal(d)
   133  	if ok {
   134  		return u, Decimal{}
   135  	}
   136  	return u, DecimalPart(d)
   137  }
   138  
   139  // UintFromUint64 allows for the conversion from uint64.
   140  func UintFromUint64(ui uint64) *Uint {
   141  	u := &Uint{
   142  		u: uint256.Int{},
   143  	}
   144  	u.u.SetUint64(ui)
   145  	return u
   146  }
   147  
   148  // UnmarshalJSON implements the json.Unmarshaler interface.
   149  func (u *Uint) UnmarshalJSON(numericBytes []byte) error {
   150  	if string(numericBytes) == "null" {
   151  		return nil
   152  	}
   153  
   154  	str, err := unquoteIfQuoted(numericBytes)
   155  	if err != nil {
   156  		return fmt.Errorf("error decoding string '%s': %s", numericBytes, err)
   157  	}
   158  
   159  	numeric, overflown := UintFromString(str, 10)
   160  	if overflown {
   161  		return errors.New("overflowing value")
   162  	}
   163  	*u = *numeric
   164  	return nil
   165  }
   166  
   167  // MarshalJSON implements the json.Marshaler interface.
   168  func (u Uint) MarshalJSON() ([]byte, error) {
   169  	return []byte(u.String()), nil
   170  }
   171  
   172  // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. As a string representation
   173  // is already used when encoding to text, this method stores that string as []byte.
   174  func (u *Uint) UnmarshalBinary(data []byte) error {
   175  	u.u.SetBytes(data)
   176  	return nil
   177  }
   178  
   179  // MarshalBinary implements the encoding.BinaryMarshaler interface.
   180  func (u Uint) MarshalBinary() (data []byte, err error) {
   181  	return u.u.Bytes(), nil
   182  }
   183  
   184  // Scan implements the sql.Scanner interface for database deserialization.
   185  func (u *Uint) Scan(value interface{}) error {
   186  	return u.u.Scan(value)
   187  }
   188  
   189  // Value implements the driver.Valuer interface for database serialization.
   190  func (u Uint) Value() (driver.Value, error) {
   191  	return u.String(), nil
   192  }
   193  
   194  // ToDecimal returns the value of the Uint as a Decimal.
   195  func (u *Uint) ToDecimal() Decimal {
   196  	return DecimalFromUint(u)
   197  }
   198  
   199  // UintFromString created a new Uint from a string
   200  // interpreted using the give base.
   201  // A big.Int is used to read the string, so
   202  // all error related to big.Int parsing applied here.
   203  // will return true if an error/overflow happened.
   204  func UintFromString(str string, base int) (*Uint, bool) {
   205  	b, ok := big.NewInt(0).SetString(str, base)
   206  	if !ok {
   207  		return NewUint(0), true
   208  	}
   209  	return UintFromBig(b)
   210  }
   211  
   212  // MustUintFromString creates a new Uint from a string
   213  // interpreted using the given base.
   214  // A big.Int is used to read the string, so
   215  // all errors related to big.Int parsing are applied here.
   216  // The core will panic if an error/overflow happens.
   217  func MustUintFromString(str string, base int) *Uint {
   218  	b, ok := big.NewInt(0).SetString(str, base)
   219  	if !ok {
   220  		panic("uint underflow")
   221  	}
   222  	return MustUintFromBig(b)
   223  }
   224  
   225  // Sum just removes the need to write num.NewUint(0).Sum(x, y, z)
   226  // so you can write num.Sum(x, y, z) instead, equivalent to x + y + z.
   227  func Sum(vals ...*Uint) *Uint {
   228  	return NewUint(0).AddSum(vals...)
   229  }
   230  
   231  func (u *Uint) Set(oth *Uint) *Uint {
   232  	u.u.Set(&oth.u)
   233  	return u
   234  }
   235  
   236  func (u *Uint) SetUint64(val uint64) *Uint {
   237  	u.u.SetUint64(val)
   238  	return u
   239  }
   240  
   241  func (u Uint) Uint64() uint64 {
   242  	return u.u.Uint64()
   243  }
   244  
   245  func (u Uint) BigInt() *big.Int {
   246  	return u.u.ToBig()
   247  }
   248  
   249  func (u Uint) Float64() float64 {
   250  	d := DecimalFromUint(&u)
   251  	retVal, _ := d.Float64()
   252  	return retVal
   253  }
   254  
   255  // Add will add x and y then store the result
   256  // into u
   257  // this is equivalent to:
   258  // `u = x + y`
   259  // u is returned for convenience, no
   260  // new variable is created.
   261  func (u *Uint) Add(x, y *Uint) *Uint {
   262  	u.u.Add(&x.u, &y.u)
   263  	return u
   264  }
   265  
   266  // AddUint64 will add x and y then store the result
   267  // into u
   268  // this is equivalent to:
   269  // `u = x + y`
   270  // u is returned for convenience, no
   271  // new variable is created.
   272  func (u *Uint) AddUint64(x *Uint, y uint64) *Uint {
   273  	u.u.AddUint64(&x.u, y)
   274  	return u
   275  }
   276  
   277  // AddSum adds multiple values at the same time to a given uint
   278  // so x.AddSum(y, z) is equivalent to x + y + z.
   279  func (u *Uint) AddSum(vals ...*Uint) *Uint {
   280  	for _, x := range vals {
   281  		u.u.Add(&u.u, &x.u)
   282  	}
   283  	return u
   284  }
   285  
   286  // AddOverflow will subtract y to x then store the result
   287  // into u
   288  // this is equivalent to:
   289  // `u = x - y`
   290  // u is returned for convenience, no
   291  // new variable is created.
   292  // False is returned if an overflow occurred.
   293  func (u *Uint) AddOverflow(x, y *Uint) (*Uint, bool) {
   294  	_, ok := u.u.AddOverflow(&x.u, &y.u)
   295  	return u, ok
   296  }
   297  
   298  // Sub will subtract y from x then store the result
   299  // into u
   300  // this is equivalent to:
   301  // `u = x - y`
   302  // u is returned for convenience, no
   303  // new variable is created.
   304  func (u *Uint) Sub(x, y *Uint) *Uint {
   305  	u.u.Sub(&x.u, &y.u)
   306  	return u
   307  }
   308  
   309  // SubOverflow will subtract y to x then store the result
   310  // into u
   311  // this is equivalent to:
   312  // `u = x - y`
   313  // u is returned for convenience, no
   314  // new variable is created.
   315  // False is returned if an overflow occurred.
   316  func (u *Uint) SubOverflow(x, y *Uint) (*Uint, bool) {
   317  	_, ok := u.u.SubOverflow(&x.u, &y.u)
   318  	return u, ok
   319  }
   320  
   321  // Delta will subtract y from x and store the result
   322  // unless x-y overflowed, in which case the neg field will be set
   323  // and the result of y - x is set instead.
   324  func (u *Uint) Delta(x, y *Uint) (*Uint, bool) {
   325  	// y is the bigger value - swap the two
   326  	if y.GT(x) {
   327  		_ = u.Sub(y, x)
   328  		return u, true
   329  	}
   330  	_ = u.Sub(x, y)
   331  	return u, false
   332  }
   333  
   334  // DeltaI will subtract y from x and store the result.
   335  func (u *Uint) DeltaI(x, y *Uint) *Int {
   336  	d, s := u.Delta(x, y)
   337  	return IntFromUint(d, !s)
   338  }
   339  
   340  // Mul will multiply x and y then store the result
   341  // into u
   342  // this is equivalent to:
   343  // `u = x * y`
   344  // u is returned for convenience, no
   345  // new variable is created.
   346  func (u *Uint) Mul(x, y *Uint) *Uint {
   347  	u.u.Mul(&x.u, &y.u)
   348  	return u
   349  }
   350  
   351  // Div will divide x by y then store the result
   352  // into u
   353  // this is equivalent to:
   354  // `u = x / y`
   355  // u is returned for convenience, no
   356  // new variable is created.
   357  func (u *Uint) Div(x, y *Uint) *Uint {
   358  	u.u.Div(&x.u, &y.u)
   359  	return u
   360  }
   361  
   362  // Mod sets u to the modulus x%y for y != 0 and returns u.
   363  // If y == 0, u is set to 0.
   364  func (u *Uint) Mod(x, y *Uint) *Uint {
   365  	u.u.Mod(&x.u, &y.u)
   366  	return u
   367  }
   368  
   369  func (u *Uint) Exp(x, y *Uint) *Uint {
   370  	u.u.Exp(&x.u, &y.u)
   371  	return u
   372  }
   373  
   374  // Sqrt calculates the integer-square root of the given Uint.
   375  func (u *Uint) SqrtInt(x *Uint) *Uint {
   376  	u.u.Sqrt(&x.u)
   377  	return u
   378  }
   379  
   380  // Sqrt calculates the square root in decimals of the given Uint.
   381  func (u *Uint) Sqrt(x *Uint) Decimal {
   382  	if x.IsZero() {
   383  		return DecimalZero()
   384  	}
   385  	// integer sqrt is a good approximation
   386  	r := UintOne().SqrtInt(x).ToDecimal()
   387  
   388  	// so now lets do a few iterations using Heron's Method to get closer
   389  	// r_i = (r + u/r) / 2
   390  	ud := x.ToDecimal()
   391  	for i := 0; i < 6; i++ {
   392  		r = r.Add(ud.Div(r)).Div(DecimalFromInt64(2))
   393  	}
   394  
   395  	return r
   396  }
   397  
   398  // LT with check if the value stored in u is
   399  // lesser than oth
   400  // this is equivalent to:
   401  // `u < oth`.
   402  func (u Uint) LT(oth *Uint) bool {
   403  	return u.u.Lt(&oth.u)
   404  }
   405  
   406  // LTUint64 with check if the value stored in u is
   407  // lesser than oth
   408  // this is equivalent to:
   409  // `u < oth`.
   410  func (u Uint) LTUint64(oth uint64) bool {
   411  	return u.u.LtUint64(oth)
   412  }
   413  
   414  // LTE with check if the value stored in u is
   415  // lesser than or equal to oth
   416  // this is equivalent to:
   417  // `u <= oth`.
   418  func (u Uint) LTE(oth *Uint) bool {
   419  	return u.u.Lt(&oth.u) || u.u.Eq(&oth.u)
   420  }
   421  
   422  // LTEUint64 with check if the value stored in u is
   423  // lesser than or equal to oth
   424  // this is equivalent to:
   425  // `u <= oth`.
   426  func (u Uint) LTEUint64(oth uint64) bool {
   427  	return u.u.LtUint64(oth) || u.EQUint64(oth)
   428  }
   429  
   430  // EQ with check if the value stored in u is
   431  // equal to oth
   432  // this is equivalent to:
   433  // `u == oth`.
   434  func (u Uint) EQ(oth *Uint) bool {
   435  	return u.u.Eq(&oth.u)
   436  }
   437  
   438  // EQUint64 with check if the value stored in u is
   439  // equal to oth
   440  // this is equivalent to:
   441  // `u == oth`.
   442  func (u Uint) EQUint64(oth uint64) bool {
   443  	return u.u.Eq(uint256.NewInt(oth))
   444  }
   445  
   446  // NEQ with check if the value stored in u is
   447  // different than oth
   448  // this is equivalent to:
   449  // `u != oth`.
   450  func (u Uint) NEQ(oth *Uint) bool {
   451  	return !u.u.Eq(&oth.u)
   452  }
   453  
   454  // NEQUint64 with check if the value stored in u is
   455  // different than oth
   456  // this is equivalent to:
   457  // `u != oth`.
   458  func (u Uint) NEQUint64(oth uint64) bool {
   459  	return !u.u.Eq(uint256.NewInt(oth))
   460  }
   461  
   462  // GT with check if the value stored in u is
   463  // greater than oth
   464  // this is equivalent to:
   465  // `u > oth`.
   466  func (u Uint) GT(oth *Uint) bool {
   467  	return u.u.Gt(&oth.u)
   468  }
   469  
   470  // GTUint64 with check if the value stored in u is
   471  // greater than oth
   472  // this is equivalent to:
   473  // `u > oth`.
   474  func (u Uint) GTUint64(oth uint64) bool {
   475  	return u.u.GtUint64(oth)
   476  }
   477  
   478  // GTE with check if the value stored in u is
   479  // greater than or equal to oth
   480  // this is equivalent to:
   481  // `u >= oth`.
   482  func (u Uint) GTE(oth *Uint) bool {
   483  	return u.u.Gt(&oth.u) || u.u.Eq(&oth.u)
   484  }
   485  
   486  // GTEUint64 with check if the value stored in u is
   487  // greater than or equal to oth
   488  // this is equivalent to:
   489  // `u >= oth`.
   490  func (u Uint) GTEUint64(oth uint64) bool {
   491  	return u.u.GtUint64(oth) || u.EQUint64(oth)
   492  }
   493  
   494  // IsZero return whether u == 0 or not.
   495  func (u Uint) IsZero() bool {
   496  	return u.u.IsZero()
   497  }
   498  
   499  // IsNegative returns whether the value is < 0.
   500  func (u Uint) IsNegative() bool {
   501  	return u.u.Sign() == -1
   502  }
   503  
   504  // Copy create a copy of the uint
   505  // this if the equivalent to:
   506  // u = x.
   507  func (u *Uint) Copy(x *Uint) *Uint {
   508  	u.u = x.u
   509  	return u
   510  }
   511  
   512  // Clone create copy of this value
   513  // this is the equivalent to:
   514  // x := u.
   515  func (u Uint) Clone() *Uint {
   516  	return &Uint{u.u}
   517  }
   518  
   519  // Hex returns the hexadecimal representation
   520  // of the stored value.
   521  func (u Uint) Hex() string {
   522  	return u.u.Hex()
   523  }
   524  
   525  // String returns the stored value as a string
   526  // this is internally using big.Int.String().
   527  func (u Uint) String() string {
   528  	return u.u.ToBig().String()
   529  }
   530  
   531  // Format implement fmt.Formatter.
   532  func (u Uint) Format(s fmt.State, ch rune) {
   533  	u.u.Format(s, ch)
   534  }
   535  
   536  // Bytes return the internal representation
   537  // of the Uint as [32]bytes, BigEndian encoded
   538  // array.
   539  func (u Uint) Bytes() [32]byte {
   540  	return u.u.Bytes32()
   541  }
   542  
   543  // UintToUint64 convert a uint to uint64
   544  // return 0 if nil.
   545  func UintToUint64(u *Uint) uint64 {
   546  	if u != nil {
   547  		return u.Uint64()
   548  	}
   549  	return 0
   550  }
   551  
   552  // UintToString convert a uint to uint64
   553  // return "0" if nil.
   554  func UintToString(u *Uint) string {
   555  	if u != nil {
   556  		return u.String()
   557  	}
   558  	return "0"
   559  }
   560  
   561  // Median calculates the median of the slice of uints.
   562  // it is assumed that no nils are allowed, no zeros are allowed.
   563  func Median(nums []*Uint) *Uint {
   564  	if nums == nil {
   565  		return nil
   566  	}
   567  	numsCopy := make([]*Uint, 0, len(nums))
   568  	for _, u := range nums {
   569  		if u != nil && !u.IsZero() {
   570  			numsCopy = append(numsCopy, u.Clone())
   571  		}
   572  	}
   573  	sort.Slice(numsCopy, func(i, j int) bool {
   574  		return numsCopy[i].LT(numsCopy[j])
   575  	})
   576  	if len(numsCopy) == 0 {
   577  		return nil
   578  	}
   579  
   580  	mid := len(numsCopy) / 2
   581  	if len(numsCopy)%2 == 1 {
   582  		return numsCopy[mid]
   583  	}
   584  	return UintZero().Div(Sum(numsCopy[mid], numsCopy[mid-1]), NewUint(2))
   585  }
   586  
   587  func unquoteIfQuoted(value interface{}) (string, error) {
   588  	var bytes []byte
   589  
   590  	switch v := value.(type) {
   591  	case string:
   592  		bytes = []byte(v)
   593  	case []byte:
   594  		bytes = v
   595  	default:
   596  		return "", fmt.Errorf("could not convert value '%+v' to byte array of type '%T'",
   597  			value, value)
   598  	}
   599  
   600  	// If the amount is quoted, strip the quotes
   601  	if len(bytes) > 2 && bytes[0] == '"' && bytes[len(bytes)-1] == '"' {
   602  		bytes = bytes[1 : len(bytes)-1]
   603  	}
   604  	return string(bytes), nil
   605  }