github.com/ethereum/go-ethereum@v1.16.1/crypto/bn256/cloudflare/lattice.go (about) 1 package bn256 2 3 import ( 4 "math/big" 5 ) 6 7 var half = new(big.Int).Rsh(Order, 1) 8 9 var curveLattice = &lattice{ 10 vectors: [][]*big.Int{ 11 {bigFromBase10("147946756881789319000765030803803410728"), bigFromBase10("147946756881789319010696353538189108491")}, 12 {bigFromBase10("147946756881789319020627676272574806254"), bigFromBase10("-147946756881789318990833708069417712965")}, 13 }, 14 inverse: []*big.Int{ 15 bigFromBase10("147946756881789318990833708069417712965"), 16 bigFromBase10("147946756881789319010696353538189108491"), 17 }, 18 det: bigFromBase10("43776485743678550444492811490514550177096728800832068687396408373151616991234"), 19 } 20 21 var targetLattice = &lattice{ 22 vectors: [][]*big.Int{ 23 {bigFromBase10("9931322734385697761"), bigFromBase10("9931322734385697761"), bigFromBase10("9931322734385697763"), bigFromBase10("9931322734385697764")}, 24 {bigFromBase10("4965661367192848881"), bigFromBase10("4965661367192848881"), bigFromBase10("4965661367192848882"), bigFromBase10("-9931322734385697762")}, 25 {bigFromBase10("-9931322734385697762"), bigFromBase10("-4965661367192848881"), bigFromBase10("4965661367192848881"), bigFromBase10("-4965661367192848882")}, 26 {bigFromBase10("9931322734385697763"), bigFromBase10("-4965661367192848881"), bigFromBase10("-4965661367192848881"), bigFromBase10("-4965661367192848881")}, 27 }, 28 inverse: []*big.Int{ 29 bigFromBase10("734653495049373973658254490726798021314063399421879442165"), 30 bigFromBase10("147946756881789319000765030803803410728"), 31 bigFromBase10("-147946756881789319005730692170996259609"), 32 bigFromBase10("1469306990098747947464455738335385361643788813749140841702"), 33 }, 34 det: new(big.Int).Set(Order), 35 } 36 37 type lattice struct { 38 vectors [][]*big.Int 39 inverse []*big.Int 40 det *big.Int 41 } 42 43 // decompose takes a scalar mod Order as input and finds a short, positive decomposition of it wrt to the lattice basis. 44 func (l *lattice) decompose(k *big.Int) []*big.Int { 45 n := len(l.inverse) 46 47 // Calculate closest vector in lattice to <k,0,0,...> with Babai's rounding. 48 c := make([]*big.Int, n) 49 for i := 0; i < n; i++ { 50 c[i] = new(big.Int).Mul(k, l.inverse[i]) 51 round(c[i], l.det) 52 } 53 54 // Transform vectors according to c and subtract <k,0,0,...>. 55 out := make([]*big.Int, n) 56 temp := new(big.Int) 57 58 for i := 0; i < n; i++ { 59 out[i] = new(big.Int) 60 61 for j := 0; j < n; j++ { 62 temp.Mul(c[j], l.vectors[j][i]) 63 out[i].Add(out[i], temp) 64 } 65 66 out[i].Neg(out[i]) 67 out[i].Add(out[i], l.vectors[0][i]).Add(out[i], l.vectors[0][i]) 68 } 69 out[0].Add(out[0], k) 70 71 return out 72 } 73 74 func (l *lattice) Precompute(add func(i, j uint)) { 75 n := uint(len(l.vectors)) 76 total := uint(1) << n 77 78 for i := uint(0); i < n; i++ { 79 for j := uint(0); j < total; j++ { 80 if (j>>i)&1 == 1 { 81 add(i, j) 82 } 83 } 84 } 85 } 86 87 func (l *lattice) Multi(scalar *big.Int) []uint8 { 88 decomp := l.decompose(scalar) 89 90 maxLen := 0 91 for _, x := range decomp { 92 if x.BitLen() > maxLen { 93 maxLen = x.BitLen() 94 } 95 } 96 97 out := make([]uint8, maxLen) 98 for j, x := range decomp { 99 for i := 0; i < maxLen; i++ { 100 out[i] += uint8(x.Bit(i)) << uint(j) 101 } 102 } 103 104 return out 105 } 106 107 // round sets num to num/denom rounded to the nearest integer. 108 func round(num, denom *big.Int) { 109 r := new(big.Int) 110 num.DivMod(num, denom, r) 111 112 if r.Cmp(half) == 1 { 113 num.Add(num, big.NewInt(1)) 114 } 115 }