code.vegaprotocol.io/vega@v0.79.0/libs/num/uint.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package num 17 18 import ( 19 "database/sql/driver" 20 "errors" 21 "fmt" 22 "math/big" 23 "sort" 24 25 "github.com/holiman/uint256" 26 ) 27 28 var ( 29 // max uint256 value. 30 big1 = big.NewInt(1) 31 maxU256 = new(big.Int).Sub(new(big.Int).Lsh(big1, 256), big1) 32 33 // initialise max variable. 34 maxUint = setMaxUint() 35 zero = NewUint(0) 36 one = NewUint(1) 37 ) 38 39 // Uint A wrapper for a big unsigned int. 40 type Uint struct { 41 u uint256.Int 42 } 43 44 // NewUint creates a new Uint with the value of the 45 // uint64 passed as a parameter. 46 func NewUint(val uint64) *Uint { 47 return &Uint{*uint256.NewInt(val)} 48 } 49 50 func UintOne() *Uint { 51 return one.Clone() 52 } 53 54 func UintZero() *Uint { 55 return zero.Clone() 56 } 57 58 // only called once, to initialise maxUint. 59 func setMaxUint() *Uint { 60 b, _ := UintFromBig(maxU256) 61 return b 62 } 63 64 // MaxUint returns max value for uint256. 65 func MaxUint() *Uint { 66 return maxUint.Clone() 67 } 68 69 // Min returns the smallest of the 2 numbers. 70 func Min(a, b *Uint) *Uint { 71 if a.LT(b) { 72 return a.Clone() 73 } 74 return b.Clone() 75 } 76 77 // Max returns the largest of the 2 numbers. 78 func Max(a, b *Uint) *Uint { 79 if a.GT(b) { 80 return a.Clone() 81 } 82 return b.Clone() 83 } 84 85 // UintFromHex instantiate a uint from and hex string. 86 func UintFromHex(hex string) (*Uint, error) { 87 u, err := uint256.FromHex(hex) 88 if err != nil { 89 return nil, err 90 } 91 return &Uint{*u}, nil 92 } 93 94 // UintFromBig construct a new Uint with a big.Int 95 // returns true if overflow happened. 96 func UintFromBig(b *big.Int) (*Uint, bool) { 97 u, ok := uint256.FromBig(b) 98 // ok means an overflow happened 99 if ok { 100 return NewUint(0), true 101 } 102 return &Uint{*u}, false 103 } 104 105 // UintFromBig construct a new Uint with a big.Int 106 // panics if overflow happened. 107 func MustUintFromBig(b *big.Int) *Uint { 108 u, ok := uint256.FromBig(b) 109 // ok means an overflow happened 110 if ok { 111 panic("uint underflow") 112 } 113 return &Uint{*u} 114 } 115 116 // UintFromBytes allows for the conversion from Uint.Bytes() back to a Uint. 117 func UintFromBytes(b []byte) *Uint { 118 u := &Uint{ 119 u: uint256.Int{}, 120 } 121 u.u.SetBytes(b) 122 return u 123 } 124 125 // UintFromDecimal returns a decimal version of the Uint, setting the bool to true if overflow occurred. 126 func UintFromDecimal(d Decimal) (*Uint, bool) { 127 u, ok := d.Uint() 128 return &Uint{*u}, ok 129 } 130 131 func UintFromDecimalWithFraction(d Decimal) (*Uint, Decimal) { 132 u, ok := UintFromDecimal(d) 133 if ok { 134 return u, Decimal{} 135 } 136 return u, DecimalPart(d) 137 } 138 139 // UintFromUint64 allows for the conversion from uint64. 140 func UintFromUint64(ui uint64) *Uint { 141 u := &Uint{ 142 u: uint256.Int{}, 143 } 144 u.u.SetUint64(ui) 145 return u 146 } 147 148 // UnmarshalJSON implements the json.Unmarshaler interface. 149 func (u *Uint) UnmarshalJSON(numericBytes []byte) error { 150 if string(numericBytes) == "null" { 151 return nil 152 } 153 154 str, err := unquoteIfQuoted(numericBytes) 155 if err != nil { 156 return fmt.Errorf("error decoding string '%s': %s", numericBytes, err) 157 } 158 159 numeric, overflown := UintFromString(str, 10) 160 if overflown { 161 return errors.New("overflowing value") 162 } 163 *u = *numeric 164 return nil 165 } 166 167 // MarshalJSON implements the json.Marshaler interface. 168 func (u Uint) MarshalJSON() ([]byte, error) { 169 return []byte(u.String()), nil 170 } 171 172 // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. As a string representation 173 // is already used when encoding to text, this method stores that string as []byte. 174 func (u *Uint) UnmarshalBinary(data []byte) error { 175 u.u.SetBytes(data) 176 return nil 177 } 178 179 // MarshalBinary implements the encoding.BinaryMarshaler interface. 180 func (u Uint) MarshalBinary() (data []byte, err error) { 181 return u.u.Bytes(), nil 182 } 183 184 // Scan implements the sql.Scanner interface for database deserialization. 185 func (u *Uint) Scan(value interface{}) error { 186 return u.u.Scan(value) 187 } 188 189 // Value implements the driver.Valuer interface for database serialization. 190 func (u Uint) Value() (driver.Value, error) { 191 return u.String(), nil 192 } 193 194 // ToDecimal returns the value of the Uint as a Decimal. 195 func (u *Uint) ToDecimal() Decimal { 196 return DecimalFromUint(u) 197 } 198 199 // UintFromString created a new Uint from a string 200 // interpreted using the give base. 201 // A big.Int is used to read the string, so 202 // all error related to big.Int parsing applied here. 203 // will return true if an error/overflow happened. 204 func UintFromString(str string, base int) (*Uint, bool) { 205 b, ok := big.NewInt(0).SetString(str, base) 206 if !ok { 207 return NewUint(0), true 208 } 209 return UintFromBig(b) 210 } 211 212 // MustUintFromString creates a new Uint from a string 213 // interpreted using the given base. 214 // A big.Int is used to read the string, so 215 // all errors related to big.Int parsing are applied here. 216 // The core will panic if an error/overflow happens. 217 func MustUintFromString(str string, base int) *Uint { 218 b, ok := big.NewInt(0).SetString(str, base) 219 if !ok { 220 panic("uint underflow") 221 } 222 return MustUintFromBig(b) 223 } 224 225 // Sum just removes the need to write num.NewUint(0).Sum(x, y, z) 226 // so you can write num.Sum(x, y, z) instead, equivalent to x + y + z. 227 func Sum(vals ...*Uint) *Uint { 228 return NewUint(0).AddSum(vals...) 229 } 230 231 func (u *Uint) Set(oth *Uint) *Uint { 232 u.u.Set(&oth.u) 233 return u 234 } 235 236 func (u *Uint) SetUint64(val uint64) *Uint { 237 u.u.SetUint64(val) 238 return u 239 } 240 241 func (u Uint) Uint64() uint64 { 242 return u.u.Uint64() 243 } 244 245 func (u Uint) BigInt() *big.Int { 246 return u.u.ToBig() 247 } 248 249 func (u Uint) Float64() float64 { 250 d := DecimalFromUint(&u) 251 retVal, _ := d.Float64() 252 return retVal 253 } 254 255 // Add will add x and y then store the result 256 // into u 257 // this is equivalent to: 258 // `u = x + y` 259 // u is returned for convenience, no 260 // new variable is created. 261 func (u *Uint) Add(x, y *Uint) *Uint { 262 u.u.Add(&x.u, &y.u) 263 return u 264 } 265 266 // AddUint64 will add x and y then store the result 267 // into u 268 // this is equivalent to: 269 // `u = x + y` 270 // u is returned for convenience, no 271 // new variable is created. 272 func (u *Uint) AddUint64(x *Uint, y uint64) *Uint { 273 u.u.AddUint64(&x.u, y) 274 return u 275 } 276 277 // AddSum adds multiple values at the same time to a given uint 278 // so x.AddSum(y, z) is equivalent to x + y + z. 279 func (u *Uint) AddSum(vals ...*Uint) *Uint { 280 for _, x := range vals { 281 u.u.Add(&u.u, &x.u) 282 } 283 return u 284 } 285 286 // AddOverflow will subtract y to x then store the result 287 // into u 288 // this is equivalent to: 289 // `u = x - y` 290 // u is returned for convenience, no 291 // new variable is created. 292 // False is returned if an overflow occurred. 293 func (u *Uint) AddOverflow(x, y *Uint) (*Uint, bool) { 294 _, ok := u.u.AddOverflow(&x.u, &y.u) 295 return u, ok 296 } 297 298 // Sub will subtract y from x then store the result 299 // into u 300 // this is equivalent to: 301 // `u = x - y` 302 // u is returned for convenience, no 303 // new variable is created. 304 func (u *Uint) Sub(x, y *Uint) *Uint { 305 u.u.Sub(&x.u, &y.u) 306 return u 307 } 308 309 // SubOverflow will subtract y to x then store the result 310 // into u 311 // this is equivalent to: 312 // `u = x - y` 313 // u is returned for convenience, no 314 // new variable is created. 315 // False is returned if an overflow occurred. 316 func (u *Uint) SubOverflow(x, y *Uint) (*Uint, bool) { 317 _, ok := u.u.SubOverflow(&x.u, &y.u) 318 return u, ok 319 } 320 321 // Delta will subtract y from x and store the result 322 // unless x-y overflowed, in which case the neg field will be set 323 // and the result of y - x is set instead. 324 func (u *Uint) Delta(x, y *Uint) (*Uint, bool) { 325 // y is the bigger value - swap the two 326 if y.GT(x) { 327 _ = u.Sub(y, x) 328 return u, true 329 } 330 _ = u.Sub(x, y) 331 return u, false 332 } 333 334 // DeltaI will subtract y from x and store the result. 335 func (u *Uint) DeltaI(x, y *Uint) *Int { 336 d, s := u.Delta(x, y) 337 return IntFromUint(d, !s) 338 } 339 340 // Mul will multiply x and y then store the result 341 // into u 342 // this is equivalent to: 343 // `u = x * y` 344 // u is returned for convenience, no 345 // new variable is created. 346 func (u *Uint) Mul(x, y *Uint) *Uint { 347 u.u.Mul(&x.u, &y.u) 348 return u 349 } 350 351 // Div will divide x by y then store the result 352 // into u 353 // this is equivalent to: 354 // `u = x / y` 355 // u is returned for convenience, no 356 // new variable is created. 357 func (u *Uint) Div(x, y *Uint) *Uint { 358 u.u.Div(&x.u, &y.u) 359 return u 360 } 361 362 // Mod sets u to the modulus x%y for y != 0 and returns u. 363 // If y == 0, u is set to 0. 364 func (u *Uint) Mod(x, y *Uint) *Uint { 365 u.u.Mod(&x.u, &y.u) 366 return u 367 } 368 369 func (u *Uint) Exp(x, y *Uint) *Uint { 370 u.u.Exp(&x.u, &y.u) 371 return u 372 } 373 374 // Sqrt calculates the integer-square root of the given Uint. 375 func (u *Uint) SqrtInt(x *Uint) *Uint { 376 u.u.Sqrt(&x.u) 377 return u 378 } 379 380 // Sqrt calculates the square root in decimals of the given Uint. 381 func (u *Uint) Sqrt(x *Uint) Decimal { 382 if x.IsZero() { 383 return DecimalZero() 384 } 385 // integer sqrt is a good approximation 386 r := UintOne().SqrtInt(x).ToDecimal() 387 388 // so now lets do a few iterations using Heron's Method to get closer 389 // r_i = (r + u/r) / 2 390 ud := x.ToDecimal() 391 for i := 0; i < 6; i++ { 392 r = r.Add(ud.Div(r)).Div(DecimalFromInt64(2)) 393 } 394 395 return r 396 } 397 398 // LT with check if the value stored in u is 399 // lesser than oth 400 // this is equivalent to: 401 // `u < oth`. 402 func (u Uint) LT(oth *Uint) bool { 403 return u.u.Lt(&oth.u) 404 } 405 406 // LTUint64 with check if the value stored in u is 407 // lesser than oth 408 // this is equivalent to: 409 // `u < oth`. 410 func (u Uint) LTUint64(oth uint64) bool { 411 return u.u.LtUint64(oth) 412 } 413 414 // LTE with check if the value stored in u is 415 // lesser than or equal to oth 416 // this is equivalent to: 417 // `u <= oth`. 418 func (u Uint) LTE(oth *Uint) bool { 419 return u.u.Lt(&oth.u) || u.u.Eq(&oth.u) 420 } 421 422 // LTEUint64 with check if the value stored in u is 423 // lesser than or equal to oth 424 // this is equivalent to: 425 // `u <= oth`. 426 func (u Uint) LTEUint64(oth uint64) bool { 427 return u.u.LtUint64(oth) || u.EQUint64(oth) 428 } 429 430 // EQ with check if the value stored in u is 431 // equal to oth 432 // this is equivalent to: 433 // `u == oth`. 434 func (u Uint) EQ(oth *Uint) bool { 435 return u.u.Eq(&oth.u) 436 } 437 438 // EQUint64 with check if the value stored in u is 439 // equal to oth 440 // this is equivalent to: 441 // `u == oth`. 442 func (u Uint) EQUint64(oth uint64) bool { 443 return u.u.Eq(uint256.NewInt(oth)) 444 } 445 446 // NEQ with check if the value stored in u is 447 // different than oth 448 // this is equivalent to: 449 // `u != oth`. 450 func (u Uint) NEQ(oth *Uint) bool { 451 return !u.u.Eq(&oth.u) 452 } 453 454 // NEQUint64 with check if the value stored in u is 455 // different than oth 456 // this is equivalent to: 457 // `u != oth`. 458 func (u Uint) NEQUint64(oth uint64) bool { 459 return !u.u.Eq(uint256.NewInt(oth)) 460 } 461 462 // GT with check if the value stored in u is 463 // greater than oth 464 // this is equivalent to: 465 // `u > oth`. 466 func (u Uint) GT(oth *Uint) bool { 467 return u.u.Gt(&oth.u) 468 } 469 470 // GTUint64 with check if the value stored in u is 471 // greater than oth 472 // this is equivalent to: 473 // `u > oth`. 474 func (u Uint) GTUint64(oth uint64) bool { 475 return u.u.GtUint64(oth) 476 } 477 478 // GTE with check if the value stored in u is 479 // greater than or equal to oth 480 // this is equivalent to: 481 // `u >= oth`. 482 func (u Uint) GTE(oth *Uint) bool { 483 return u.u.Gt(&oth.u) || u.u.Eq(&oth.u) 484 } 485 486 // GTEUint64 with check if the value stored in u is 487 // greater than or equal to oth 488 // this is equivalent to: 489 // `u >= oth`. 490 func (u Uint) GTEUint64(oth uint64) bool { 491 return u.u.GtUint64(oth) || u.EQUint64(oth) 492 } 493 494 // IsZero return whether u == 0 or not. 495 func (u Uint) IsZero() bool { 496 return u.u.IsZero() 497 } 498 499 // IsNegative returns whether the value is < 0. 500 func (u Uint) IsNegative() bool { 501 return u.u.Sign() == -1 502 } 503 504 // Copy create a copy of the uint 505 // this if the equivalent to: 506 // u = x. 507 func (u *Uint) Copy(x *Uint) *Uint { 508 u.u = x.u 509 return u 510 } 511 512 // Clone create copy of this value 513 // this is the equivalent to: 514 // x := u. 515 func (u Uint) Clone() *Uint { 516 return &Uint{u.u} 517 } 518 519 // Hex returns the hexadecimal representation 520 // of the stored value. 521 func (u Uint) Hex() string { 522 return u.u.Hex() 523 } 524 525 // String returns the stored value as a string 526 // this is internally using big.Int.String(). 527 func (u Uint) String() string { 528 return u.u.ToBig().String() 529 } 530 531 // Format implement fmt.Formatter. 532 func (u Uint) Format(s fmt.State, ch rune) { 533 u.u.Format(s, ch) 534 } 535 536 // Bytes return the internal representation 537 // of the Uint as [32]bytes, BigEndian encoded 538 // array. 539 func (u Uint) Bytes() [32]byte { 540 return u.u.Bytes32() 541 } 542 543 // UintToUint64 convert a uint to uint64 544 // return 0 if nil. 545 func UintToUint64(u *Uint) uint64 { 546 if u != nil { 547 return u.Uint64() 548 } 549 return 0 550 } 551 552 // UintToString convert a uint to uint64 553 // return "0" if nil. 554 func UintToString(u *Uint) string { 555 if u != nil { 556 return u.String() 557 } 558 return "0" 559 } 560 561 // Median calculates the median of the slice of uints. 562 // it is assumed that no nils are allowed, no zeros are allowed. 563 func Median(nums []*Uint) *Uint { 564 if nums == nil { 565 return nil 566 } 567 numsCopy := make([]*Uint, 0, len(nums)) 568 for _, u := range nums { 569 if u != nil && !u.IsZero() { 570 numsCopy = append(numsCopy, u.Clone()) 571 } 572 } 573 sort.Slice(numsCopy, func(i, j int) bool { 574 return numsCopy[i].LT(numsCopy[j]) 575 }) 576 if len(numsCopy) == 0 { 577 return nil 578 } 579 580 mid := len(numsCopy) / 2 581 if len(numsCopy)%2 == 1 { 582 return numsCopy[mid] 583 } 584 return UintZero().Div(Sum(numsCopy[mid], numsCopy[mid-1]), NewUint(2)) 585 } 586 587 func unquoteIfQuoted(value interface{}) (string, error) { 588 var bytes []byte 589 590 switch v := value.(type) { 591 case string: 592 bytes = []byte(v) 593 case []byte: 594 bytes = v 595 default: 596 return "", fmt.Errorf("could not convert value '%+v' to byte array of type '%T'", 597 value, value) 598 } 599 600 // If the amount is quoted, strip the quotes 601 if len(bytes) > 2 && bytes[0] == '"' && bytes[len(bytes)-1] == '"' { 602 bytes = bytes[1 : len(bytes)-1] 603 } 604 return string(bytes), nil 605 }