github.com/consensys/gnark-crypto@v0.14.0/internal/generator/test_vector_utils/small_rational/small-rational.go (about) 1 package small_rational 2 3 import ( 4 "crypto/rand" 5 "fmt" 6 "math/big" 7 "strconv" 8 "strings" 9 ) 10 11 const Bytes = 64 12 13 type SmallRational struct { 14 text string //For debugging purposes 15 numerator big.Int 16 denominator big.Int // By convention, denominator == 0 also indicates zero 17 } 18 19 var smallPrimes = []*big.Int{ 20 big.NewInt(2), big.NewInt(3), big.NewInt(5), 21 big.NewInt(7), big.NewInt(11), big.NewInt(13), 22 } 23 24 func bigDivides(p, a *big.Int) bool { 25 var remainder big.Int 26 remainder.Mod(a, p) 27 return remainder.BitLen() == 0 28 } 29 30 func (z *SmallRational) UpdateText() { 31 z.text = z.Text(10) 32 } 33 34 func (z *SmallRational) simplify() { 35 36 if z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0 { 37 return 38 } 39 40 var num, den big.Int 41 42 num.Set(&z.numerator) 43 den.Set(&z.denominator) 44 45 for _, p := range smallPrimes { 46 for bigDivides(p, &num) && bigDivides(p, &den) { 47 num.Div(&num, p) 48 den.Div(&den, p) 49 } 50 } 51 52 if bigDivides(&den, &num) { 53 num.Div(&num, &den) 54 den.SetInt64(1) 55 } 56 57 z.numerator = num 58 z.denominator = den 59 60 } 61 func (z *SmallRational) Square(x *SmallRational) *SmallRational { 62 var num, den big.Int 63 num.Mul(&x.numerator, &x.numerator) 64 den.Mul(&x.denominator, &x.denominator) 65 66 z.numerator = num 67 z.denominator = den 68 69 z.UpdateText() 70 71 return z 72 } 73 74 func (z *SmallRational) String() string { 75 z.text = z.Text(10) 76 return z.text 77 } 78 79 func (z *SmallRational) Add(x, y *SmallRational) *SmallRational { 80 if x.denominator.BitLen() == 0 { 81 *z = *y 82 } else if y.denominator.BitLen() == 0 { 83 *z = *x 84 } else { 85 //TODO: Exploit cases where one denom divides the other 86 var numDen, denNum big.Int 87 numDen.Mul(&x.numerator, &y.denominator) 88 denNum.Mul(&x.denominator, &y.numerator) 89 90 numDen.Add(&denNum, &numDen) 91 z.numerator = numDen //to avoid shallow copy problems 92 93 denNum.Mul(&x.denominator, &y.denominator) 94 z.denominator = denNum 95 z.simplify() 96 } 97 98 z.UpdateText() 99 100 return z 101 } 102 103 func (z *SmallRational) IsZero() bool { 104 return z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0 105 } 106 107 func (z *SmallRational) Inverse(x *SmallRational) *SmallRational { 108 if x.IsZero() { 109 *z = *x 110 } else { 111 *z = SmallRational{numerator: x.denominator, denominator: x.numerator} 112 z.UpdateText() 113 } 114 115 return z 116 } 117 118 func (z *SmallRational) Neg(x *SmallRational) *SmallRational { 119 z.numerator.Neg(&x.numerator) 120 z.denominator = x.denominator 121 122 if x.text == "" { 123 x.UpdateText() 124 } 125 126 if x.text[0] == '-' { 127 z.text = x.text[1:] 128 } else { 129 z.text = "-" + x.text 130 } 131 132 return z 133 } 134 135 func (z *SmallRational) Double(x *SmallRational) *SmallRational { 136 137 var y big.Int 138 139 if x.denominator.Bit(0) == 0 { 140 z.numerator = x.numerator 141 y.Rsh(&x.denominator, 1) 142 z.denominator = y 143 } else { 144 y.Lsh(&x.numerator, 1) 145 z.numerator = y 146 z.denominator = x.denominator 147 } 148 149 z.UpdateText() 150 151 return z 152 } 153 154 func (z *SmallRational) Sign() int { 155 return z.numerator.Sign() * z.denominator.Sign() 156 } 157 158 func (z *SmallRational) MarshalJSON() ([]byte, error) { 159 return []byte(z.String()), nil 160 } 161 162 func (z *SmallRational) UnmarshalJson(data []byte) error { 163 _, err := z.SetInterface(string(data)) 164 return err 165 } 166 167 func (z *SmallRational) Equal(x *SmallRational) bool { 168 return z.Cmp(x) == 0 169 } 170 171 func (z *SmallRational) Sub(x, y *SmallRational) *SmallRational { 172 var yNeg SmallRational 173 yNeg.Neg(y) 174 z.Add(x, &yNeg) 175 176 z.UpdateText() 177 return z 178 } 179 180 func (z *SmallRational) Cmp(x *SmallRational) int { 181 zSign, xSign := z.Sign(), x.Sign() 182 183 if zSign > xSign { 184 return 1 185 } 186 if zSign < xSign { 187 return -1 188 } 189 190 var Z, X big.Int 191 Z.Mul(&z.numerator, &x.denominator) 192 X.Mul(&x.numerator, &z.denominator) 193 194 Z.Abs(&Z) 195 X.Abs(&X) 196 197 return Z.Cmp(&X) * zSign 198 199 } 200 201 func BatchInvert(a []SmallRational) []SmallRational { 202 res := make([]SmallRational, len(a)) 203 for i := range a { 204 res[i].Inverse(&a[i]) 205 } 206 return res 207 } 208 209 func (z *SmallRational) Mul(x, y *SmallRational) *SmallRational { 210 var num, den big.Int 211 212 num.Mul(&x.numerator, &y.numerator) 213 den.Mul(&x.denominator, &y.denominator) 214 215 z.numerator = num 216 z.denominator = den 217 218 z.simplify() 219 z.UpdateText() 220 return z 221 } 222 223 func (z *SmallRational) SetOne() *SmallRational { 224 return z.SetInt64(1) 225 } 226 227 func (z *SmallRational) SetZero() *SmallRational { 228 return z.SetInt64(0) 229 } 230 231 func (z *SmallRational) SetInt64(i int64) *SmallRational { 232 z.numerator = *big.NewInt(i) 233 z.denominator = *big.NewInt(1) 234 z.text = strconv.FormatInt(i, 10) 235 return z 236 } 237 238 func (z *SmallRational) SetRandom() (*SmallRational, error) { 239 240 bytes := make([]byte, 1) 241 n, err := rand.Read(bytes) 242 if err != nil { 243 return nil, err 244 } 245 if n != len(bytes) { 246 return nil, fmt.Errorf("%d bytes read instead of %d", n, len(bytes)) 247 } 248 249 z.numerator = *big.NewInt(int64(bytes[0]%16) - 8) 250 z.denominator = *big.NewInt(int64((bytes[0]) / 16)) 251 252 z.simplify() 253 z.UpdateText() 254 255 return z, nil 256 } 257 258 func (z *SmallRational) SetUint64(i uint64) { 259 var num big.Int 260 num.SetUint64(i) 261 z.numerator = num 262 z.denominator = *big.NewInt(1) 263 z.text = strconv.FormatUint(i, 10) 264 } 265 266 func (z *SmallRational) IsOne() bool { 267 return z.numerator.Cmp(&z.denominator) == 0 && z.denominator.BitLen() != 0 268 } 269 270 func (z *SmallRational) Text(base int) string { 271 272 if z.denominator.BitLen() == 0 { 273 return "0" 274 } 275 276 if z.denominator.Sign() < 0 { 277 var num, den big.Int 278 num.Neg(&z.numerator) 279 den.Neg(&z.denominator) 280 z.numerator = num 281 z.denominator = den 282 } 283 284 if bigDivides(&z.denominator, &z.numerator) { 285 var num big.Int 286 num.Div(&z.numerator, &z.denominator) 287 z.numerator = num 288 z.denominator = *big.NewInt(1) 289 } 290 291 numerator := z.numerator.Text(base) 292 293 if z.denominator.IsInt64() && z.denominator.Int64() == 1 { 294 return numerator 295 } 296 297 return numerator + "/" + z.denominator.Text(base) 298 } 299 300 func (z *SmallRational) Set(x *SmallRational) *SmallRational { 301 *z = *x // shallow copy is safe because ops are never in place 302 return z 303 } 304 305 func (z *SmallRational) SetInterface(x interface{}) (*SmallRational, error) { 306 307 switch v := x.(type) { 308 case *SmallRational: 309 *z = *v 310 case SmallRational: 311 *z = v 312 case int64: 313 z.SetInt64(v) 314 case int: 315 z.SetInt64(int64(v)) 316 case float64: 317 asInt := int64(v) 318 if float64(asInt) != v { 319 return nil, fmt.Errorf("cannot currently parse float") 320 } 321 z.SetInt64(asInt) 322 case string: 323 z.text = v 324 sep := strings.Split(v, "/") 325 switch len(sep) { 326 case 1: 327 if asInt, err := strconv.Atoi(sep[0]); err == nil { 328 z.SetInt64(int64(asInt)) 329 } else { 330 return nil, err 331 } 332 case 2: 333 var err error 334 var num, denom int 335 num, err = strconv.Atoi(sep[0]) 336 if err != nil { 337 return nil, err 338 } 339 denom, err = strconv.Atoi(sep[1]) 340 if err != nil { 341 return nil, err 342 } 343 z.numerator = *big.NewInt(int64(num)) 344 z.denominator = *big.NewInt(int64(denom)) 345 default: 346 return nil, fmt.Errorf("cannot parse \"%s\"", v) 347 } 348 default: 349 return nil, fmt.Errorf("cannot parse %T", x) 350 } 351 352 return z, nil 353 } 354 355 func bigIntToBytesSigned(dst []byte, src big.Int) { 356 src.FillBytes(dst[1:]) 357 dst[0] = 0 358 if src.Sign() < 0 { 359 dst[0] = 255 360 } 361 } 362 363 func (z *SmallRational) Bytes() [Bytes]byte { 364 var res [Bytes]byte 365 bigIntToBytesSigned(res[:Bytes/2], z.numerator) 366 bigIntToBytesSigned(res[Bytes/2:], z.denominator) 367 return res 368 } 369 370 func bytesToBigIntSigned(src []byte) big.Int { 371 var res big.Int 372 res.SetBytes(src[1:]) 373 if src[0] != 0 { 374 res.Neg(&res) 375 } 376 return res 377 } 378 379 // BigInt returns sets dst to the value of z if it is an integer. 380 // if z is not an integer, nil is returned. 381 // if the given dst is nil, the address of the numerator is returned. 382 // if the given dst is non-nil, it is returned. 383 func (z *SmallRational) BigInt(dst *big.Int) *big.Int { 384 if z.denominator.Cmp(big.NewInt(1)) != 0 { 385 return nil 386 } 387 if dst == nil { 388 return &z.numerator 389 } 390 dst.Set(&z.numerator) 391 return dst 392 } 393 394 func (z *SmallRational) SetBytes(b []byte) { 395 if len(b) > Bytes/2 { 396 z.numerator = bytesToBigIntSigned(b[:Bytes/2]) 397 z.denominator = bytesToBigIntSigned(b[Bytes/2:]) 398 } else { 399 z.numerator.SetBytes(b) 400 z.denominator.SetInt64(1) 401 } 402 z.simplify() 403 z.UpdateText() 404 } 405 406 func Modulus() *big.Int { 407 res := big.NewInt(1) 408 res.Lsh(res, 64) 409 return res 410 }