github.com/gnolang/gno@v0.0.0-20240520182011-228e9d0192ce/examples/gno.land/p/demo/uint256/arithmetic.gno (about)

// arithmetic provides arithmetic operations for Uint objects.
// This includes basic binary operations such as addition, subtraction, multiplication, division, and modulo,
// as well as overflow-checking variants, negation, modular multiplication, and exponentiation.
// These functions are essential for numeric calculations using 256-bit unsigned integers.
     5  package uint256
     6  
     7  import (
     8  	"math/bits"
     9  )
    10  
    11  // Add sets z to the sum x+y
    12  func (z *Uint) Add(x, y *Uint) *Uint {
    13  	var carry uint64
    14  	z.arr[0], carry = bits.Add64(x.arr[0], y.arr[0], 0)
    15  	z.arr[1], carry = bits.Add64(x.arr[1], y.arr[1], carry)
    16  	z.arr[2], carry = bits.Add64(x.arr[2], y.arr[2], carry)
    17  	z.arr[3], _ = bits.Add64(x.arr[3], y.arr[3], carry)
    18  	return z
    19  }
    20  
    21  // AddOverflow sets z to the sum x+y, and returns z and whether overflow occurred
    22  func (z *Uint) AddOverflow(x, y *Uint) (*Uint, bool) {
    23  	var carry uint64
    24  	z.arr[0], carry = bits.Add64(x.arr[0], y.arr[0], 0)
    25  	z.arr[1], carry = bits.Add64(x.arr[1], y.arr[1], carry)
    26  	z.arr[2], carry = bits.Add64(x.arr[2], y.arr[2], carry)
    27  	z.arr[3], carry = bits.Add64(x.arr[3], y.arr[3], carry)
    28  	return z, carry != 0
    29  }
    30  
    31  // Sub sets z to the difference x-y
    32  func (z *Uint) Sub(x, y *Uint) *Uint {
    33  	var carry uint64
    34  	z.arr[0], carry = bits.Sub64(x.arr[0], y.arr[0], 0)
    35  	z.arr[1], carry = bits.Sub64(x.arr[1], y.arr[1], carry)
    36  	z.arr[2], carry = bits.Sub64(x.arr[2], y.arr[2], carry)
    37  	z.arr[3], _ = bits.Sub64(x.arr[3], y.arr[3], carry)
    38  	return z
    39  }
    40  
    41  // SubOverflow sets z to the difference x-y and returns z and true if the operation underflowed
    42  func (z *Uint) SubOverflow(x, y *Uint) (*Uint, bool) {
    43  	var carry uint64
    44  	z.arr[0], carry = bits.Sub64(x.arr[0], y.arr[0], 0)
    45  	z.arr[1], carry = bits.Sub64(x.arr[1], y.arr[1], carry)
    46  	z.arr[2], carry = bits.Sub64(x.arr[2], y.arr[2], carry)
    47  	z.arr[3], carry = bits.Sub64(x.arr[3], y.arr[3], carry)
    48  	return z, carry != 0
    49  }
    50  
    51  // Neg returns -x mod 2^256.
    52  func (z *Uint) Neg(x *Uint) *Uint {
    53  	return z.Sub(new(Uint), x)
    54  }
    55  
// Mul sets z to the product x*y modulo 2^256 and returns z.
// Only the low four words of the full product are computed; any bits above
// bit 255 are discarded without notice (use MulOverflow to detect them).
// The result is accumulated in the local res and written to z only at the
// end, so z may alias x or y.
func (z *Uint) Mul(x, y *Uint) *Uint {
	var (
		res              Uint
		carry            uint64
		res1, res2, res3 uint64
	)

	// Row 0: x * y[0]. Word 3 keeps only the low half of x[3]*y[0]; its high
	// half would land in word 4, which lies outside the 256-bit result.
	carry, res.arr[0] = bits.Mul64(x.arr[0], y.arr[0])
	carry, res1 = umulHop(carry, x.arr[1], y.arr[0])
	carry, res2 = umulHop(carry, x.arr[2], y.arr[0])
	res3 = x.arr[3]*y.arr[0] + carry

	// Row 1: x * y[1], shifted up one word.
	carry, res.arr[1] = umulHop(res1, x.arr[0], y.arr[1])
	carry, res2 = umulStep(res2, x.arr[1], y.arr[1], carry)
	res3 = res3 + x.arr[2]*y.arr[1] + carry

	// Row 2: x * y[2], shifted up two words.
	carry, res.arr[2] = umulHop(res2, x.arr[0], y.arr[2])
	res3 = res3 + x.arr[1]*y.arr[2] + carry

	// Row 3: only x[0]*y[3] reaches word 3, and only its low half is kept.
	res.arr[3] = res3 + x.arr[0]*y.arr[3]

	return z.Set(&res)
}
    81  
    82  // MulOverflow sets z to the product x*y, and returns z and  whether overflow occurred
    83  func (z *Uint) MulOverflow(x, y *Uint) (*Uint, bool) {
    84  	p := umul(x, y)
    85  	copy(z.arr[:], p[:4])
    86  	return z, (p[4] | p[5] | p[6] | p[7]) != 0
    87  }
    88  
    89  // commented out for possible overflow
    90  // Div sets z to the quotient x/y for returns z.
    91  // If y == 0, z is set to 0
    92  func (z *Uint) Div(x, y *Uint) *Uint {
    93  	if y.IsZero() || y.Gt(x) {
    94  		return z.Clear()
    95  	}
    96  	if x.Eq(y) {
    97  		return z.SetOne()
    98  	}
    99  	// Shortcut some cases
   100  	if x.IsUint64() {
   101  		return z.SetUint64(x.Uint64() / y.Uint64())
   102  	}
   103  
   104  	// At this point, we know
   105  	// x/y ; x > y > 0
   106  
   107  	var quot Uint
   108  	udivrem(quot.arr[:], x.arr[:], y)
   109  	return z.Set(&quot)
   110  }
   111  
   112  // MulMod calculates the modulo-m multiplication of x and y and
   113  // returns z.
   114  // If m == 0, z is set to 0 (OBS: differs from the big.Int)
   115  func (z *Uint) MulMod(x, y, m *Uint) *Uint {
   116  	if x.IsZero() || y.IsZero() || m.IsZero() {
   117  		return z.Clear()
   118  	}
   119  	p := umul(x, y)
   120  
   121  	if m.arr[3] != 0 {
   122  		mu := Reciprocal(m)
   123  		r := reduce4(p, m, mu)
   124  		return z.Set(&r)
   125  	}
   126  
   127  	var (
   128  		pl Uint
   129  		ph Uint
   130  	)
   131  
   132  	pl = Uint{arr: [4]uint64{p[0], p[1], p[2], p[3]}}
   133  	ph = Uint{arr: [4]uint64{p[4], p[5], p[6], p[7]}}
   134  
   135  	// If the multiplication is within 256 bits use Mod().
   136  	if ph.IsZero() {
   137  		return z.Mod(&pl, m)
   138  	}
   139  
   140  	var quot [8]uint64
   141  	rem := udivrem(quot[:], p[:], m)
   142  	return z.Set(&rem)
   143  }
   144  
   145  // Mod sets z to the modulus x%y for y != 0 and returns z.
   146  // If y == 0, z is set to 0 (OBS: differs from the big.Uint)
   147  func (z *Uint) Mod(x, y *Uint) *Uint {
   148  	if x.IsZero() || y.IsZero() {
   149  		return z.Clear()
   150  	}
   151  	switch x.Cmp(y) {
   152  	case -1:
   153  		// x < y
   154  		copy(z.arr[:], x.arr[:])
   155  		return z
   156  	case 0:
   157  		// x == y
   158  		return z.Clear() // They are equal
   159  	}
   160  
   161  	// At this point:
   162  	// x != 0
   163  	// y != 0
   164  	// x > y
   165  
   166  	// Shortcut trivial case
   167  	if x.IsUint64() {
   168  		return z.SetUint64(x.Uint64() % y.Uint64())
   169  	}
   170  
   171  	var quot Uint
   172  	*z = udivrem(quot.arr[:], x.arr[:], y)
   173  	return z
   174  }
   175  
   176  // DivMod sets z to the quotient x div y and m to the modulus x mod y and returns the pair (z, m) for y != 0.
   177  // If y == 0, both z and m are set to 0 (OBS: differs from the big.Int)
   178  func (z *Uint) DivMod(x, y, m *Uint) (*Uint, *Uint) {
   179  	if y.IsZero() {
   180  		return z.Clear(), m.Clear()
   181  	}
   182  	var quot Uint
   183  	*m = udivrem(quot.arr[:], x.arr[:], y)
   184  	*z = quot
   185  	return z, m
   186  }
   187  
   188  // Exp sets z = base**exponent mod 2**256, and returns z.
   189  func (z *Uint) Exp(base, exponent *Uint) *Uint {
   190  	res := Uint{arr: [4]uint64{1, 0, 0, 0}}
   191  	multiplier := *base
   192  	expBitLen := exponent.BitLen()
   193  
   194  	curBit := 0
   195  	word := exponent.arr[0]
   196  	for ; curBit < expBitLen && curBit < 64; curBit++ {
   197  		if word&1 == 1 {
   198  			res.Mul(&res, &multiplier)
   199  		}
   200  		multiplier.squared()
   201  		word >>= 1
   202  	}
   203  
   204  	word = exponent.arr[1]
   205  	for ; curBit < expBitLen && curBit < 128; curBit++ {
   206  		if word&1 == 1 {
   207  			res.Mul(&res, &multiplier)
   208  		}
   209  		multiplier.squared()
   210  		word >>= 1
   211  	}
   212  
   213  	word = exponent.arr[2]
   214  	for ; curBit < expBitLen && curBit < 192; curBit++ {
   215  		if word&1 == 1 {
   216  			res.Mul(&res, &multiplier)
   217  		}
   218  		multiplier.squared()
   219  		word >>= 1
   220  	}
   221  
   222  	word = exponent.arr[3]
   223  	for ; curBit < expBitLen && curBit < 256; curBit++ {
   224  		if word&1 == 1 {
   225  			res.Mul(&res, &multiplier)
   226  		}
   227  		multiplier.squared()
   228  		word >>= 1
   229  	}
   230  	return z.Set(&res)
   231  }
   232  
// squared replaces z with z*z modulo 2^256.
// It follows the same row-by-row layout as Mul with both operands equal to z.
// Rows 0-2 are accumulated with umulHop/umulStep. The four word-3 cross
// products (z0*z3, z1*z2, z2*z1, z3*z0) are folded into one doubled sum, and
// only their low 64 bits are kept. The result is built in res because z is
// read throughout, and it is copied back with Set at the end.
func (z *Uint) squared() {
	var (
		res                    Uint
		carry0, carry1, carry2 uint64
		res1, res2             uint64
	)

	// Row 0: z * z[0]; carry0 ends as the row's carry into word 3.
	carry0, res.arr[0] = bits.Mul64(z.arr[0], z.arr[0])
	carry0, res1 = umulHop(carry0, z.arr[0], z.arr[1])
	carry0, res2 = umulHop(carry0, z.arr[0], z.arr[2])

	// Row 1: z * z[1], shifted up one word; carry1 goes into word 3.
	carry1, res.arr[1] = umulHop(res1, z.arr[0], z.arr[1])
	carry1, res2 = umulStep(res2, z.arr[1], z.arr[1], carry1)

	// Row 2: z * z[2], shifted up two words; carry2 goes into word 3.
	carry2, res.arr[2] = umulHop(res2, z.arr[0], z.arr[2])

	// Word 3: symmetric cross terms counted twice, plus the row carries.
	res.arr[3] = 2*(z.arr[0]*z.arr[3]+z.arr[1]*z.arr[2]) + carry0 + carry1 + carry2

	z.Set(&res)
}
   253  
// udivrem divides u by d and produces both quotient and remainder.
// u is a little-endian word slice of at most 8 words; the normalized copy
// needs one extra word and lives in a 9-word buffer.
// Let uLen and dLen be the number of significant words of u and d. The
// quotient is written to quot[0 .. uLen-dLen], so quot needs at least
// uLen-dLen+1 words. Words of quot beyond those are not touched, so callers
// pass a zeroed quot.
// d must be non-zero: for d == 0, dLen stays 0 and d.arr[dLen-1] panics with
// an index out of range.
// It loosely follows the Knuth's division algorithm (sometimes referenced as "schoolbook" division) using 64-bit words.
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
func udivrem(quot, u []uint64, d *Uint) (rem Uint) {
	// Number of significant words in the divisor.
	var dLen int
	for i := len(d.arr) - 1; i >= 0; i-- {
		if d.arr[i] != 0 {
			dLen = i + 1
			break
		}
	}

	// Normalize: shift d left so the top bit of its leading word is set, as
	// udivremBy1/udivremKnuth expect. When shift == 0, the ">> (64 - shift)"
	// terms shift a uint64 by 64, which Go defines to be 0, so no bits leak in.
	shift := uint(bits.LeadingZeros64(d.arr[dLen-1]))

	var dnStorage Uint
	dn := dnStorage.arr[:dLen]
	for i := dLen - 1; i > 0; i-- {
		dn[i] = (d.arr[i] << shift) | (d.arr[i-1] >> (64 - shift))
	}
	dn[0] = d.arr[0] << shift

	// Number of significant words in the dividend.
	var uLen int
	for i := len(u) - 1; i >= 0; i-- {
		if u[i] != 0 {
			uLen = i + 1
			break
		}
	}

	// Dividend shorter than divisor: quotient is 0 (quot untouched), remainder is u.
	if uLen < dLen {
		copy(rem.arr[:], u)
		return rem
	}

	// Shift u by the same amount; the extra top word un[uLen] catches the bits
	// shifted out of u's leading word.
	var unStorage [9]uint64
	un := unStorage[:uLen+1]
	un[uLen] = u[uLen-1] >> (64 - shift)
	for i := uLen - 1; i > 0; i-- {
		un[i] = (u[i] << shift) | (u[i-1] >> (64 - shift))
	}
	un[0] = u[0] << shift

	// TODO: Skip the highest word of numerator if not significant.

	// Single-word divisor: the faster 2-by-1 loop suffices; undo normalization on the remainder.
	if dLen == 1 {
		r := udivremBy1(quot, un, dn[0])
		rem.SetUint64(r >> shift)
		return rem
	}

	udivremKnuth(quot, un, dn)

	// udivremKnuth leaves the normalized remainder in the low dLen words of un;
	// shift it back right by shift bits.
	for i := 0; i < dLen-1; i++ {
		rem.arr[i] = (un[i] >> shift) | (un[i+1] << (64 - shift))
	}
	rem.arr[dLen-1] = un[dLen-1] >> shift

	return rem
}
   314  
// umul computes full 256 x 256 -> 512 multiplication.
// The product is returned as eight words, least significant first (res[0]
// holds the low 64 bits of x[0]*y[0]).
// It is built row by row, one row per word of y. Each row adds x*y[i], shifted
// up by i words, into the running sums. carry4, carry5 and carry6 carry the top
// word of rows 0, 1 and 2 into the next row.
func umul(x, y *Uint) [8]uint64 {
	var (
		res                           [8]uint64
		carry, carry4, carry5, carry6 uint64
		res1, res2, res3, res4, res5  uint64
	)

	// Row 0: x * y[0] -> words 0..4.
	carry, res[0] = bits.Mul64(x.arr[0], y.arr[0])
	carry, res1 = umulHop(carry, x.arr[1], y.arr[0])
	carry, res2 = umulHop(carry, x.arr[2], y.arr[0])
	carry4, res3 = umulHop(carry, x.arr[3], y.arr[0])

	// Row 1: x * y[1] -> words 1..5; word 1 is final afterwards.
	carry, res[1] = umulHop(res1, x.arr[0], y.arr[1])
	carry, res2 = umulStep(res2, x.arr[1], y.arr[1], carry)
	carry, res3 = umulStep(res3, x.arr[2], y.arr[1], carry)
	carry5, res4 = umulStep(carry4, x.arr[3], y.arr[1], carry)

	// Row 2: x * y[2] -> words 2..6; word 2 is final afterwards.
	carry, res[2] = umulHop(res2, x.arr[0], y.arr[2])
	carry, res3 = umulStep(res3, x.arr[1], y.arr[2], carry)
	carry, res4 = umulStep(res4, x.arr[2], y.arr[2], carry)
	carry6, res5 = umulStep(carry5, x.arr[3], y.arr[2], carry)

	// Row 3: x * y[3] -> words 3..7; all remaining words become final.
	carry, res[3] = umulHop(res3, x.arr[0], y.arr[3])
	carry, res[4] = umulStep(res4, x.arr[1], y.arr[3], carry)
	carry, res[5] = umulStep(res5, x.arr[2], y.arr[3], carry)
	res[7], res[6] = umulStep(carry6, x.arr[3], y.arr[3], carry)

	return res
}
   345  
   346  // umulStep computes (hi * 2^64 + lo) = z + (x * y) + carry.
   347  func umulStep(z, x, y, carry uint64) (hi, lo uint64) {
   348  	hi, lo = bits.Mul64(x, y)
   349  	lo, carry = bits.Add64(lo, carry, 0)
   350  	hi, _ = bits.Add64(hi, 0, carry)
   351  	lo, carry = bits.Add64(lo, z, 0)
   352  	hi, _ = bits.Add64(hi, 0, carry)
   353  	return hi, lo
   354  }
   355  
   356  // umulHop computes (hi * 2^64 + lo) = z + (x * y)
   357  func umulHop(z, x, y uint64) (hi, lo uint64) {
   358  	hi, lo = bits.Mul64(x, y)
   359  	lo, carry := bits.Add64(lo, z, 0)
   360  	hi, _ = bits.Add64(hi, 0, carry)
   361  	return hi, lo
   362  }
   363  
   364  // udivremBy1 divides u by single normalized word d and produces both quotient and remainder.
   365  // The quotient is stored in provided quot.
   366  func udivremBy1(quot, u []uint64, d uint64) (rem uint64) {
   367  	reciprocal := reciprocal2by1(d)
   368  	rem = u[len(u)-1] // Set the top word as remainder.
   369  	for j := len(u) - 2; j >= 0; j-- {
   370  		quot[j], rem = udivrem2by1(rem, u[j], d, reciprocal)
   371  	}
   372  	return rem
   373  }
   374  
// udivremKnuth divides u by the multi-word divisor d using Knuth's Algorithm D
// (TAOCP Vol. 2, 4.3.1).
// d must be normalized (top bit of its last word set) and have at least two
// words. u is the normalized dividend with an extra leading word, as built by
// udivrem.
// Quotient digit j is written to quot[j] for j = len(u)-len(d)-1 down to 0,
// i.e. len(u)-len(d) words. u is updated in place: on return its low len(d)
// words hold the remainder, still shifted left by the normalization.
func udivremKnuth(quot, u, d []uint64) {
	// The top two divisor words drive the quotient-digit estimate.
	dh := d[len(d)-1]
	dl := d[len(d)-2]
	reciprocal := reciprocal2by1(dh)

	// One quotient digit per iteration, most significant first.
	for j := len(u) - len(d) - 1; j >= 0; j-- {
		// Top three words of the current partial-remainder window.
		u2 := u[j+len(d)]
		u1 := u[j+len(d)-1]
		u0 := u[j+len(d)-2]

		var qhat, rhat uint64
		if u2 >= dh { // <u2, u1> / dh would not fit in one word: clamp to the largest digit.
			qhat = ^uint64(0)
			// TODO: Add "qhat one too big" adjustment (not needed for correctness, but helps avoiding "add back" case).
		} else {
			// Estimate from the top two words, then refine once with dl (Knuth step D3).
			qhat, rhat = udivrem2by1(u2, u1, dh, reciprocal)
			ph, pl := bits.Mul64(qhat, dl)
			if ph > rhat || (ph == rhat && pl > u0) {
				qhat--
				// TODO: Add "qhat one too big" adjustment (not needed for correctness, but helps avoiding "add back" case).
			}
		}

		// Multiply and subtract qhat*d from the window u[j:j+len(d)]; borrow is
		// what must still be taken out of the window's top word.
		borrow := subMulTo(u[j:], d, qhat)
		u[j+len(d)] = u2 - borrow
		if u2 < borrow { // qhat was one too big: too much subtracted, add d back once.
			qhat--
			u[j+len(d)] += addTo(u[j:], d)
		}

		quot[j] = qhat // Store quotient digit.
	}
}
   412  
   413  // isBitSet returns true if bit n-th is set, where n = 0 is LSB.
   414  // The n must be <= 255.
   415  func (z *Uint) isBitSet(n uint) bool {
   416  	return (z.arr[n/64] & (1 << (n % 64))) != 0
   417  }
   418  
   419  // addTo computes x += y.
   420  // Requires len(x) >= len(y).
   421  func addTo(x, y []uint64) uint64 {
   422  	var carry uint64
   423  	for i := 0; i < len(y); i++ {
   424  		x[i], carry = bits.Add64(x[i], y[i], carry)
   425  	}
   426  	return carry
   427  }
   428  
   429  // subMulTo computes x -= y * multiplier.
   430  // Requires len(x) >= len(y).
   431  func subMulTo(x, y []uint64, multiplier uint64) uint64 {
   432  	var borrow uint64
   433  	for i := 0; i < len(y); i++ {
   434  		s, carry1 := bits.Sub64(x[i], borrow, 0)
   435  		ph, pl := bits.Mul64(y[i], multiplier)
   436  		t, carry2 := bits.Sub64(s, pl, 0)
   437  		x[i] = t
   438  		borrow = ph + carry1 + carry2
   439  	}
   440  	return borrow
   441  }
   442  
   443  // reciprocal2by1 computes <^d, ^0> / d.
   444  func reciprocal2by1(d uint64) uint64 {
   445  	reciprocal, _ := bits.Div64(^d, ^uint64(0), d)
   446  	return reciprocal
   447  }
   448  
   449  // udivrem2by1 divides <uh, ul> / d and produces both quotient and remainder.
   450  // It uses the provided d's reciprocal.
   451  // Implementation ported from https://github.com/chfast/intx and is based on
   452  // "Improved division by invariant integers", Algorithm 4.
   453  func udivrem2by1(uh, ul, d, reciprocal uint64) (quot, rem uint64) {
   454  	qh, ql := bits.Mul64(reciprocal, uh)
   455  	ql, carry := bits.Add64(ql, ul, 0)
   456  	qh, _ = bits.Add64(qh, uh, carry)
   457  	qh++
   458  
   459  	r := ul - qh*d
   460  
   461  	if r > ql {
   462  		qh--
   463  		r += d
   464  	}
   465  
   466  	if r >= d {
   467  		qh++
   468  		r -= d
   469  	}
   470  
   471  	return qh, r
   472  }