github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/utilities/crypto/bn256/cloudflare/lattice.go (about)

     1  package bn256
     2  
     3  import (
     4  	"math/big"
     5  )
     6  
     7  var half = new(big.Int).Rsh(Order, 1)
     8  
     9  var curveLattice = &lattice{
    10  	vectors: [][]*big.Int{
    11  		{bigFromBase10("147946756881789319000765030803803410728"), bigFromBase10("147946756881789319010696353538189108491")},
    12  		{bigFromBase10("147946756881789319020627676272574806254"), bigFromBase10("-147946756881789318990833708069417712965")},
    13  	},
    14  	inverse: []*big.Int{
    15  		bigFromBase10("147946756881789318990833708069417712965"),
    16  		bigFromBase10("147946756881789319010696353538189108491"),
    17  	},
    18  	det: bigFromBase10("43776485743678550444492811490514550177096728800832068687396408373151616991234"),
    19  }
    20  
    21  var targetLattice = &lattice{
    22  	vectors: [][]*big.Int{
    23  		{bigFromBase10("9931322734385697761"), bigFromBase10("9931322734385697761"), bigFromBase10("9931322734385697763"), bigFromBase10("9931322734385697764")},
    24  		{bigFromBase10("4965661367192848881"), bigFromBase10("4965661367192848881"), bigFromBase10("4965661367192848882"), bigFromBase10("-9931322734385697762")},
    25  		{bigFromBase10("-9931322734385697762"), bigFromBase10("-4965661367192848881"), bigFromBase10("4965661367192848881"), bigFromBase10("-4965661367192848882")},
    26  		{bigFromBase10("9931322734385697763"), bigFromBase10("-4965661367192848881"), bigFromBase10("-4965661367192848881"), bigFromBase10("-4965661367192848881")},
    27  	},
    28  	inverse: []*big.Int{
    29  		bigFromBase10("734653495049373973658254490726798021314063399421879442165"),
    30  		bigFromBase10("147946756881789319000765030803803410728"),
    31  		bigFromBase10("-147946756881789319005730692170996259609"),
    32  		bigFromBase10("1469306990098747947464455738335385361643788813749140841702"),
    33  	},
    34  	det: new(big.Int).Set(Order),
    35  }
    36  
    37  type lattice struct {
    38  	vectors [][]*big.Int
    39  	inverse []*big.Int
    40  	det     *big.Int
    41  }
    42  
    43  // decompose takes a scalar mod Order as input and finds a short, positive decomposition of it wrt to the lattice basis.
    44  func (l *lattice) decompose(k *big.Int) []*big.Int {
    45  	n := len(l.inverse)
    46  
    47  	// Calculate closest vector in lattice to <k,0,0,...> with Babai's rounding.
    48  	c := make([]*big.Int, n)
    49  	for i := 0; i < n; i++ {
    50  		c[i] = new(big.Int).Mul(k, l.inverse[i])
    51  		round(c[i], l.det)
    52  	}
    53  
    54  	// Transform vectors according to c and subtract <k,0,0,...>.
    55  	out := make([]*big.Int, n)
    56  	temp := new(big.Int)
    57  
    58  	for i := 0; i < n; i++ {
    59  		out[i] = new(big.Int)
    60  
    61  		for j := 0; j < n; j++ {
    62  			temp.Mul(c[j], l.vectors[j][i])
    63  			out[i].Add(out[i], temp)
    64  		}
    65  
    66  		out[i].Neg(out[i])
    67  		out[i].Add(out[i], l.vectors[0][i]).Add(out[i], l.vectors[0][i])
    68  	}
    69  	out[0].Add(out[0], k)
    70  
    71  	return out
    72  }
    73  
    74  func (l *lattice) Precompute(add func(i, j uint)) {
    75  	n := uint(len(l.vectors))
    76  	total := uint(1) << n
    77  
    78  	for i := uint(0); i < n; i++ {
    79  		for j := uint(0); j < total; j++ {
    80  			if (j>>i)&1 == 1 {
    81  				add(i, j)
    82  			}
    83  		}
    84  	}
    85  }
    86  
    87  func (l *lattice) Multi(scalar *big.Int) []uint8 {
    88  	decomp := l.decompose(scalar)
    89  
    90  	maxLen := 0
    91  	for _, x := range decomp {
    92  		if x.BitLen() > maxLen {
    93  			maxLen = x.BitLen()
    94  		}
    95  	}
    96  
    97  	out := make([]uint8, maxLen)
    98  	for j, x := range decomp {
    99  		for i := 0; i < maxLen; i++ {
   100  			out[i] += uint8(x.Bit(i)) << uint(j)
   101  		}
   102  	}
   103  
   104  	return out
   105  }
   106  
   107  // round sets num to num/denom rounded to the nearest integer.
   108  func round(num, denom *big.Int) {
   109  	r := new(big.Int)
   110  	num.DivMod(num, denom, r)
   111  
   112  	if r.Cmp(half) == 1 {
   113  		num.Add(num, big.NewInt(1))
   114  	}
   115  }