github.com/consensys/gnark-crypto@v0.14.0/internal/generator/test_vector_utils/small_rational/small-rational.go (about)

     1  package small_rational
     2  
     3  import (
     4  	"crypto/rand"
     5  	"fmt"
     6  	"math/big"
     7  	"strconv"
     8  	"strings"
     9  )
    10  
    11  const Bytes = 64
    12  
    13  type SmallRational struct {
    14  	text        string //For debugging purposes
    15  	numerator   big.Int
    16  	denominator big.Int // By convention, denominator == 0 also indicates zero
    17  }
    18  
    19  var smallPrimes = []*big.Int{
    20  	big.NewInt(2), big.NewInt(3), big.NewInt(5),
    21  	big.NewInt(7), big.NewInt(11), big.NewInt(13),
    22  }
    23  
    24  func bigDivides(p, a *big.Int) bool {
    25  	var remainder big.Int
    26  	remainder.Mod(a, p)
    27  	return remainder.BitLen() == 0
    28  }
    29  
    30  func (z *SmallRational) UpdateText() {
    31  	z.text = z.Text(10)
    32  }
    33  
    34  func (z *SmallRational) simplify() {
    35  
    36  	if z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0 {
    37  		return
    38  	}
    39  
    40  	var num, den big.Int
    41  
    42  	num.Set(&z.numerator)
    43  	den.Set(&z.denominator)
    44  
    45  	for _, p := range smallPrimes {
    46  		for bigDivides(p, &num) && bigDivides(p, &den) {
    47  			num.Div(&num, p)
    48  			den.Div(&den, p)
    49  		}
    50  	}
    51  
    52  	if bigDivides(&den, &num) {
    53  		num.Div(&num, &den)
    54  		den.SetInt64(1)
    55  	}
    56  
    57  	z.numerator = num
    58  	z.denominator = den
    59  
    60  }
    61  func (z *SmallRational) Square(x *SmallRational) *SmallRational {
    62  	var num, den big.Int
    63  	num.Mul(&x.numerator, &x.numerator)
    64  	den.Mul(&x.denominator, &x.denominator)
    65  
    66  	z.numerator = num
    67  	z.denominator = den
    68  
    69  	z.UpdateText()
    70  
    71  	return z
    72  }
    73  
    74  func (z *SmallRational) String() string {
    75  	z.text = z.Text(10)
    76  	return z.text
    77  }
    78  
    79  func (z *SmallRational) Add(x, y *SmallRational) *SmallRational {
    80  	if x.denominator.BitLen() == 0 {
    81  		*z = *y
    82  	} else if y.denominator.BitLen() == 0 {
    83  		*z = *x
    84  	} else {
    85  		//TODO: Exploit cases where one denom divides the other
    86  		var numDen, denNum big.Int
    87  		numDen.Mul(&x.numerator, &y.denominator)
    88  		denNum.Mul(&x.denominator, &y.numerator)
    89  
    90  		numDen.Add(&denNum, &numDen)
    91  		z.numerator = numDen //to avoid shallow copy problems
    92  
    93  		denNum.Mul(&x.denominator, &y.denominator)
    94  		z.denominator = denNum
    95  		z.simplify()
    96  	}
    97  
    98  	z.UpdateText()
    99  
   100  	return z
   101  }
   102  
   103  func (z *SmallRational) IsZero() bool {
   104  	return z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0
   105  }
   106  
   107  func (z *SmallRational) Inverse(x *SmallRational) *SmallRational {
   108  	if x.IsZero() {
   109  		*z = *x
   110  	} else {
   111  		*z = SmallRational{numerator: x.denominator, denominator: x.numerator}
   112  		z.UpdateText()
   113  	}
   114  
   115  	return z
   116  }
   117  
   118  func (z *SmallRational) Neg(x *SmallRational) *SmallRational {
   119  	z.numerator.Neg(&x.numerator)
   120  	z.denominator = x.denominator
   121  
   122  	if x.text == "" {
   123  		x.UpdateText()
   124  	}
   125  
   126  	if x.text[0] == '-' {
   127  		z.text = x.text[1:]
   128  	} else {
   129  		z.text = "-" + x.text
   130  	}
   131  
   132  	return z
   133  }
   134  
   135  func (z *SmallRational) Double(x *SmallRational) *SmallRational {
   136  
   137  	var y big.Int
   138  
   139  	if x.denominator.Bit(0) == 0 {
   140  		z.numerator = x.numerator
   141  		y.Rsh(&x.denominator, 1)
   142  		z.denominator = y
   143  	} else {
   144  		y.Lsh(&x.numerator, 1)
   145  		z.numerator = y
   146  		z.denominator = x.denominator
   147  	}
   148  
   149  	z.UpdateText()
   150  
   151  	return z
   152  }
   153  
   154  func (z *SmallRational) Sign() int {
   155  	return z.numerator.Sign() * z.denominator.Sign()
   156  }
   157  
   158  func (z *SmallRational) MarshalJSON() ([]byte, error) {
   159  	return []byte(z.String()), nil
   160  }
   161  
   162  func (z *SmallRational) UnmarshalJson(data []byte) error {
   163  	_, err := z.SetInterface(string(data))
   164  	return err
   165  }
   166  
   167  func (z *SmallRational) Equal(x *SmallRational) bool {
   168  	return z.Cmp(x) == 0
   169  }
   170  
   171  func (z *SmallRational) Sub(x, y *SmallRational) *SmallRational {
   172  	var yNeg SmallRational
   173  	yNeg.Neg(y)
   174  	z.Add(x, &yNeg)
   175  
   176  	z.UpdateText()
   177  	return z
   178  }
   179  
   180  func (z *SmallRational) Cmp(x *SmallRational) int {
   181  	zSign, xSign := z.Sign(), x.Sign()
   182  
   183  	if zSign > xSign {
   184  		return 1
   185  	}
   186  	if zSign < xSign {
   187  		return -1
   188  	}
   189  
   190  	var Z, X big.Int
   191  	Z.Mul(&z.numerator, &x.denominator)
   192  	X.Mul(&x.numerator, &z.denominator)
   193  
   194  	Z.Abs(&Z)
   195  	X.Abs(&X)
   196  
   197  	return Z.Cmp(&X) * zSign
   198  
   199  }
   200  
   201  func BatchInvert(a []SmallRational) []SmallRational {
   202  	res := make([]SmallRational, len(a))
   203  	for i := range a {
   204  		res[i].Inverse(&a[i])
   205  	}
   206  	return res
   207  }
   208  
   209  func (z *SmallRational) Mul(x, y *SmallRational) *SmallRational {
   210  	var num, den big.Int
   211  
   212  	num.Mul(&x.numerator, &y.numerator)
   213  	den.Mul(&x.denominator, &y.denominator)
   214  
   215  	z.numerator = num
   216  	z.denominator = den
   217  
   218  	z.simplify()
   219  	z.UpdateText()
   220  	return z
   221  }
   222  
   223  func (z *SmallRational) SetOne() *SmallRational {
   224  	return z.SetInt64(1)
   225  }
   226  
   227  func (z *SmallRational) SetZero() *SmallRational {
   228  	return z.SetInt64(0)
   229  }
   230  
   231  func (z *SmallRational) SetInt64(i int64) *SmallRational {
   232  	z.numerator = *big.NewInt(i)
   233  	z.denominator = *big.NewInt(1)
   234  	z.text = strconv.FormatInt(i, 10)
   235  	return z
   236  }
   237  
   238  func (z *SmallRational) SetRandom() (*SmallRational, error) {
   239  
   240  	bytes := make([]byte, 1)
   241  	n, err := rand.Read(bytes)
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  	if n != len(bytes) {
   246  		return nil, fmt.Errorf("%d bytes read instead of %d", n, len(bytes))
   247  	}
   248  
   249  	z.numerator = *big.NewInt(int64(bytes[0]%16) - 8)
   250  	z.denominator = *big.NewInt(int64((bytes[0]) / 16))
   251  
   252  	z.simplify()
   253  	z.UpdateText()
   254  
   255  	return z, nil
   256  }
   257  
   258  func (z *SmallRational) SetUint64(i uint64) {
   259  	var num big.Int
   260  	num.SetUint64(i)
   261  	z.numerator = num
   262  	z.denominator = *big.NewInt(1)
   263  	z.text = strconv.FormatUint(i, 10)
   264  }
   265  
   266  func (z *SmallRational) IsOne() bool {
   267  	return z.numerator.Cmp(&z.denominator) == 0 && z.denominator.BitLen() != 0
   268  }
   269  
   270  func (z *SmallRational) Text(base int) string {
   271  
   272  	if z.denominator.BitLen() == 0 {
   273  		return "0"
   274  	}
   275  
   276  	if z.denominator.Sign() < 0 {
   277  		var num, den big.Int
   278  		num.Neg(&z.numerator)
   279  		den.Neg(&z.denominator)
   280  		z.numerator = num
   281  		z.denominator = den
   282  	}
   283  
   284  	if bigDivides(&z.denominator, &z.numerator) {
   285  		var num big.Int
   286  		num.Div(&z.numerator, &z.denominator)
   287  		z.numerator = num
   288  		z.denominator = *big.NewInt(1)
   289  	}
   290  
   291  	numerator := z.numerator.Text(base)
   292  
   293  	if z.denominator.IsInt64() && z.denominator.Int64() == 1 {
   294  		return numerator
   295  	}
   296  
   297  	return numerator + "/" + z.denominator.Text(base)
   298  }
   299  
   300  func (z *SmallRational) Set(x *SmallRational) *SmallRational {
   301  	*z = *x // shallow copy is safe because ops are never in place
   302  	return z
   303  }
   304  
   305  func (z *SmallRational) SetInterface(x interface{}) (*SmallRational, error) {
   306  
   307  	switch v := x.(type) {
   308  	case *SmallRational:
   309  		*z = *v
   310  	case SmallRational:
   311  		*z = v
   312  	case int64:
   313  		z.SetInt64(v)
   314  	case int:
   315  		z.SetInt64(int64(v))
   316  	case float64:
   317  		asInt := int64(v)
   318  		if float64(asInt) != v {
   319  			return nil, fmt.Errorf("cannot currently parse float")
   320  		}
   321  		z.SetInt64(asInt)
   322  	case string:
   323  		z.text = v
   324  		sep := strings.Split(v, "/")
   325  		switch len(sep) {
   326  		case 1:
   327  			if asInt, err := strconv.Atoi(sep[0]); err == nil {
   328  				z.SetInt64(int64(asInt))
   329  			} else {
   330  				return nil, err
   331  			}
   332  		case 2:
   333  			var err error
   334  			var num, denom int
   335  			num, err = strconv.Atoi(sep[0])
   336  			if err != nil {
   337  				return nil, err
   338  			}
   339  			denom, err = strconv.Atoi(sep[1])
   340  			if err != nil {
   341  				return nil, err
   342  			}
   343  			z.numerator = *big.NewInt(int64(num))
   344  			z.denominator = *big.NewInt(int64(denom))
   345  		default:
   346  			return nil, fmt.Errorf("cannot parse \"%s\"", v)
   347  		}
   348  	default:
   349  		return nil, fmt.Errorf("cannot parse %T", x)
   350  	}
   351  
   352  	return z, nil
   353  }
   354  
   355  func bigIntToBytesSigned(dst []byte, src big.Int) {
   356  	src.FillBytes(dst[1:])
   357  	dst[0] = 0
   358  	if src.Sign() < 0 {
   359  		dst[0] = 255
   360  	}
   361  }
   362  
   363  func (z *SmallRational) Bytes() [Bytes]byte {
   364  	var res [Bytes]byte
   365  	bigIntToBytesSigned(res[:Bytes/2], z.numerator)
   366  	bigIntToBytesSigned(res[Bytes/2:], z.denominator)
   367  	return res
   368  }
   369  
   370  func bytesToBigIntSigned(src []byte) big.Int {
   371  	var res big.Int
   372  	res.SetBytes(src[1:])
   373  	if src[0] != 0 {
   374  		res.Neg(&res)
   375  	}
   376  	return res
   377  }
   378  
   379  // BigInt returns sets dst to the value of z if it is an integer.
   380  // if z is not an integer, nil is returned.
   381  // if the given dst is nil, the address of the numerator is returned.
   382  // if the given dst is non-nil, it is returned.
   383  func (z *SmallRational) BigInt(dst *big.Int) *big.Int {
   384  	if z.denominator.Cmp(big.NewInt(1)) != 0 {
   385  		return nil
   386  	}
   387  	if dst == nil {
   388  		return &z.numerator
   389  	}
   390  	dst.Set(&z.numerator)
   391  	return dst
   392  }
   393  
   394  func (z *SmallRational) SetBytes(b []byte) {
   395  	if len(b) > Bytes/2 {
   396  		z.numerator = bytesToBigIntSigned(b[:Bytes/2])
   397  		z.denominator = bytesToBigIntSigned(b[Bytes/2:])
   398  	} else {
   399  		z.numerator.SetBytes(b)
   400  		z.denominator.SetInt64(1)
   401  	}
   402  	z.simplify()
   403  	z.UpdateText()
   404  }
   405  
   406  func Modulus() *big.Int {
   407  	res := big.NewInt(1)
   408  	res.Lsh(res, 64)
   409  	return res
   410  }