github.com/ericlagergren/ctb@v0.0.0-20220810041818-96749d9c394d/lll/big.go (about) 1 package lll 2 3 import ( 4 "fmt" 5 "math/big" 6 ) 7 8 // T is an integer. 9 // 10 // Unlike math/big, T is a value type. 11 type T interface { 12 Sign() int 13 Cmp(T) int 14 CmpAbs(T) int 15 Add(T) T 16 Mul(T) T 17 Sub(T) T 18 Quo(T) T 19 String() string 20 } 21 22 func SetInt(z *big.Int, x T) { 23 switch x := x.(type) { 24 case *Int: 25 z.Set(&x.x) 26 case *Frac: 27 if !x.x.IsInt() { 28 SetInt(z, round(x)) 29 } else { 30 z.Set(x.x.Num()) 31 } 32 } 33 } 34 35 // Int is an integer. 36 // 37 // Int implements T. 38 type Int struct { 39 x big.Int 40 } 41 42 var _ T = (*Int)(nil) 43 44 // I64 creates an Int from x. 45 func I64(x int64) T { 46 var z Int 47 z.x.SetInt64(x) 48 return &z 49 } 50 51 // I copies x into a new Int. 52 func I(x *big.Int) T { 53 var z Int 54 z.x.Set(x) 55 return &z 56 } 57 58 func (x *Int) String() string { 59 return x.x.String() 60 } 61 62 func (x *Int) Sign() int { 63 return x.x.Sign() 64 } 65 66 func (x *Int) CmpAbs(y T) int { 67 switch { 68 case x.Sign() == 0 && y.Sign() == 0: 69 return 0 70 } 71 switch y := y.(type) { 72 case *Int: 73 return x.x.CmpAbs(&y.x) 74 case *Frac: 75 if y.x.IsInt() { 76 return x.x.CmpAbs(y.x.Num()) 77 } 78 var tmp big.Rat 79 tmp.SetInt(&x.x) 80 // Set sign(x) = sign(y) 81 if y.Sign() < 0 { 82 tmp.Neg(&tmp) 83 } else { 84 tmp.Abs(&tmp) 85 } 86 return tmp.Cmp(&y.x) 87 default: 88 panic(fmt.Sprintf("unknown type: %T", y)) 89 } 90 } 91 92 func (x *Int) Cmp(y T) int { 93 switch { 94 case x.Sign() < y.Sign(): 95 return -1 96 case x.Sign() > y.Sign(): 97 return +1 98 case x.Sign() == 0 && y.Sign() == 0: 99 return 0 100 } 101 switch y := y.(type) { 102 case *Int: 103 return x.x.Cmp(&y.x) 104 case *Frac: 105 var tmp big.Rat 106 tmp.SetInt(&x.x) 107 return tmp.Cmp(&y.x) 108 default: 109 panic(fmt.Sprintf("unknown type: %T", y)) 110 } 111 } 112 113 func (x *Int) Add(y T) T { 114 switch y := y.(type) { 115 case *Int: 116 var z Int 117 z.x.Add(&x.x, &y.x) 118 return &z 119 case *Frac: 120 if y.x.IsInt() { 121 var z Int 122 z.x.Add(&x.x, y.x.Num()) 123 return &z 124 } 125 var tmp big.Rat 126 tmp.SetInt(&x.x) 127 var z Frac 128 z.x.Add(&tmp, &y.x) 129 return &z 130 default: 131 panic(fmt.Sprintf("unknown type: %T", y)) 132 } 133 } 134 135 func (x *Int) Mul(y T) T { 136 switch y := y.(type) { 137 case *Int: 138 var z Int 139 z.x.Mul(&x.x, &y.x) 140 return &z 141 case *Frac: 142 if y.x.IsInt() { 143 var z Int 144 z.x.Mul(&x.x, y.x.Num()) 145 return &z 146 } 147 var z Frac 148 var tmp big.Rat 149 tmp.SetInt(&x.x) 150 z.x.Mul(&tmp, &y.x) 151 return &z 152 default: 153 panic(fmt.Sprintf("unknown type: %T", y)) 154 } 155 } 156 157 func (x *Int) Sub(y T) T { 158 switch y := y.(type) { 159 case *Int: 160 var z Int 161 z.x.Sub(&x.x, &y.x) 162 return &z 163 case *Frac: 164 if y.x.IsInt() { 165 var z Int 166 z.x.Sub(&x.x, y.x.Num()) 167 return &z 168 } 169 var tmp big.Rat 170 tmp.SetInt(&x.x) 171 var z Frac 172 z.x.Sub(&tmp, &y.x) 173 return &z 174 default: 175 panic(fmt.Sprintf("unknown type: %T", y)) 176 } 177 } 178 179 func (x *Int) Quo(y T) T { 180 switch y := y.(type) { 181 case *Int: 182 var z Frac 183 z.x.SetFrac(&x.x, &y.x) 184 return &z 185 case *Frac: 186 var tmp big.Rat 187 tmp.SetInt(&x.x) 188 var z Frac 189 z.x.Quo(&tmp, &y.x) 190 return &z 191 default: 192 panic(fmt.Sprintf("unknown type: %T", y)) 193 } 194 } 195 196 // Frac is a fraction (rational) number. 197 // 198 // Frac implements T. 199 type Frac struct { 200 x big.Rat 201 } 202 203 var _ T = (*Frac)(nil) 204 205 // F64 creates a Frac from a numerateor and denominator. 206 func F64(n, d int64) T { 207 if d == 1 { 208 return I64(n) 209 } 210 var z Frac 211 z.x.SetFrac64(n, d) 212 return &z 213 } 214 215 // F copies the numerator and denominator into a Frac. 216 func F(n, d *big.Int) T { 217 var z Frac 218 z.x.SetFrac(n, d) 219 return &z 220 } 221 222 func (x *Frac) Sign() int { 223 return x.x.Sign() 224 } 225 226 func (x *Frac) Cmp(y T) int { 227 switch { 228 case x.Sign() < y.Sign(): 229 return -1 230 case x.Sign() > y.Sign(): 231 return +1 232 case x.Sign() == 0 && y.Sign() == 0: 233 return 0 234 } 235 switch y := y.(type) { 236 case *Int: 237 var tmp big.Rat 238 tmp.SetInt(&y.x) 239 return x.x.Cmp(&tmp) 240 case *Frac: 241 return x.x.Cmp(&y.x) 242 default: 243 panic(fmt.Sprintf("unknown type: %T", y)) 244 } 245 } 246 247 func (x *Frac) CmpAbs(y T) int { 248 switch { 249 case x.Sign() == 0 && y.Sign() == 0: 250 return 0 251 } 252 switch y := y.(type) { 253 case *Int: 254 if x.x.IsInt() { 255 return y.x.CmpAbs(x.x.Num()) 256 } 257 r := +1 258 var tmp big.Rat 259 tmp.SetInt(&y.x) 260 // Set sign(y) = sign(x) 261 if x.Sign() < 0 { 262 r = -1 263 tmp.Neg(&tmp) 264 } else { 265 tmp.Abs(&tmp) 266 } 267 return x.x.Cmp(&tmp) * r 268 case *Frac: 269 if x.Sign() == y.Sign() { 270 return x.x.Cmp(&y.x) 271 } 272 r := +1 273 var tmp big.Rat 274 // Set sign(y) = sign(x) 275 if x.Sign() < 0 { 276 r = -1 277 tmp.Neg(&y.x) 278 } else { 279 tmp.Abs(&y.x) 280 } 281 return x.x.Cmp(&tmp) * r 282 default: 283 panic(fmt.Sprintf("unknown type: %T", y)) 284 } 285 } 286 287 func (x *Frac) Add(y T) T { 288 switch y := y.(type) { 289 case *Frac: 290 var z Frac 291 z.x.Add(&x.x, &y.x) 292 return &z 293 case *Int: 294 var tmp big.Rat 295 tmp.SetInt(&y.x) 296 var z Frac 297 z.x.Add(&x.x, &tmp) 298 return &z 299 default: 300 panic(fmt.Sprintf("unknown type: %T", y)) 301 } 302 } 303 304 func (x *Frac) Mul(y T) T { 305 switch y := y.(type) { 306 case *Frac: 307 var z Frac 308 z.x.Mul(&x.x, &y.x) 309 return &z 310 case *Int: 311 var tmp big.Rat 312 tmp.SetInt(&y.x) 313 var z Frac 314 z.x.Mul(&x.x, &tmp) 315 return &z 316 default: 317 panic(fmt.Sprintf("unknown type: %T", y)) 318 } 319 } 320 321 func (x *Frac) Sub(y T) T { 322 switch y := y.(type) { 323 case *Frac: 324 var z Frac 325 z.x.Sub(&x.x, &y.x) 326 return &z 327 case *Int: 328 var tmp big.Rat 329 tmp.SetInt(&y.x) 330 var z Frac 331 z.x.Sub(&x.x, &tmp) 332 return &z 333 default: 334 panic(fmt.Sprintf("unknown type: %T", y)) 335 } 336 } 337 338 func (x *Frac) Quo(y T) T { 339 switch y := y.(type) { 340 case *Frac: 341 var z Frac 342 z.x.Quo(&x.x, &y.x) 343 return &z 344 case *Int: 345 var tmp big.Rat 346 tmp.SetInt(&y.x) 347 var z Frac 348 z.x.Quo(&x.x, &tmp) 349 return &z 350 default: 351 panic(fmt.Sprintf("unknown type: %T", y)) 352 } 353 } 354 355 func (f *Frac) String() string { 356 return f.x.String() 357 } 358 359 var bigOne = big.NewInt(1) 360 361 func round(x T) T { 362 switch x := x.(type) { 363 case *Int: 364 return x 365 case *Frac: 366 if x.x.IsInt() { 367 return x 368 } 369 370 var z Int // result 371 var r big.Int // scratch 372 373 n := x.x.Num() 374 d := x.x.Denom() 375 376 // Rats are always normalized, meaning the following 377 // holds: 378 // if x.IsInt then n.Cmp(d) != 0 379 if n.CmpAbs(d) < 0 { 380 // Proper fraction. 381 if r.Add(n, n).CmpAbs(d) >= 0 { 382 z.x.Add(&z.x, bigOne) 383 } 384 // Round down to zero. 385 return &z 386 } 387 388 // Improper fraction. 389 z.x.QuoRem(n, d, &r) 390 // Is r >= 0.5? If so, round up away from zero. 391 if r.Add(&r, &r).CmpAbs(d) >= 0 { 392 if x.Sign() < 0 { 393 z.x.Sub(&z.x, bigOne) 394 } else { 395 z.x.Add(&z.x, bigOne) 396 } 397 } 398 return &z 399 default: 400 panic(fmt.Sprintf("unknown type: %T", x)) 401 } 402 } 403 404 func sq(x T) T { 405 return x.Mul(x) 406 } 407 408 var ( 409 one = I64(1) 410 half = F64(1, 2) 411 quart = F64(1, 4) 412 ) 413 414 // Reduction computes the Lenstra–Lenstra–Lovász 415 // lattice basis reduction algorithm. 416 // 417 // B is a lattice basis 418 // b0, b1, ... bn in Z^m 419 // delta must be in (1/4, 1), typically 3/4. 420 func Reduction(delta T, B [][]T) [][]T { 421 if delta.Cmp(quart) < 0 || delta.Cmp(one) >= 0 { 422 panic("delta out of range") 423 } 424 Bstar := gramSchmidt(nil, B) 425 mu := func(i, j int) T { 426 return projCoff(Bstar[j], B[i]) 427 } 428 n := len(B) 429 k := 1 430 for k < n { 431 for j := k - 1; j >= 0; j-- { 432 mukj := mu(k, j) 433 if mukj.CmpAbs(half) > 0 { 434 bj := scale(nil, B[j], round(mukj)) 435 B[k] = sub(B[k], B[k], bj) 436 Bstar = gramSchmidt(Bstar, B) 437 } 438 } 439 dmksq := delta.Sub(sq(mu(k, k-1))) 440 pbsk1 := sdot(Bstar[k-1]) 441 if sdot(Bstar[k]).Cmp(dmksq.Mul(pbsk1)) >= 0 { 442 k++ 443 } else { 444 B[k], B[k-1] = B[k-1], B[k] 445 Bstar = gramSchmidt(Bstar, B) 446 k-- 447 if k < 1 { 448 k = 1 449 } 450 } 451 } 452 return B 453 } 454 455 func gramSchmidt(u, v [][]T) [][]T { 456 u = u[:0] 457 for _, vi := range v { 458 ui := vi 459 for _, uj := range u { 460 // ui -= uj*vi 461 uj = proj(nil, uj, vi) 462 ui = sub(nil, ui, uj) 463 } 464 if len(ui) > 0 { 465 u = append(u, ui) 466 } 467 } 468 return u 469 } 470 471 // scale is 472 // for i := range x { 473 // z[i] = x[i] * c 474 // } 475 func scale(z, x []T, c T) []T { 476 z = zmake(z, len(x)) 477 for i := range x { 478 z[i] = x[i].Mul(c) 479 } 480 return z 481 } 482 483 // mul is 484 // for i := range x { 485 // z[i] = x[i] * y[i] 486 // } 487 func mul(z, x, y []T) []T { 488 z = zmake(z, len(x)) 489 for i := range x { 490 z[i] = x[i].Mul(y[i]) 491 } 492 return z 493 } 494 495 // sub is 496 // for i := range x { 497 // z[i] = x[i] - y[i] 498 // } 499 func sub(z, x, y []T) []T { 500 z = zmake(z, len(x)) 501 for i := range x { 502 z[i] = x[i].Sub(y[i]) 503 } 504 return z 505 } 506 507 // proj is 508 // c := projCoff(x, y) 509 // scale(z, x, c) 510 func proj(z, x, y []T) []T { 511 z = zmake(z, len(x)) 512 return scale(z, x, projCoff(x, y)) 513 } 514 515 // projCoff is 516 // dot(x, y) / sdot(x) 517 func projCoff(x, y []T) T { 518 return dot(x, y).Quo(sdot(x)) 519 } 520 521 // dot is 522 // for i := range x { 523 // sum += x[i] * y[i] 524 // } 525 func dot(x, y []T) T { 526 sum := I64(0) 527 for i := range x { 528 sum = sum.Add(x[i].Mul(y[i])) 529 } 530 return sum 531 } 532 533 // sdot is 534 // dot(x, x) 535 func sdot(x []T) T { 536 return dot(x, x) 537 } 538 539 func zmake(z []T, n int) []T { 540 if n <= cap(z) { 541 return z[:n] 542 } 543 return make([]T, n) 544 }