github.com/klaytn/klaytn@v1.12.1/crypto/bn256/cloudflare/lattice.go (about)

     1  // Copyright 2018 The klaytn Authors
     2  //
     3  // This file is derived from crypto/bn256/cloudflare/lattice.go (2018/06/04).
     4  // See LICENSE in the top directory for the original copyright and license.
     5  
     6  package bn256
     7  
     8  import (
     9  	"math/big"
    10  
    11  	"github.com/klaytn/klaytn/common"
    12  )
    13  
// half is ⌊Order/2⌋ (big.Int.Rsh is an arithmetic shift, i.e. floor division
// by 2). It is a package-level value so it is computed once at initialization.
var half = new(big.Int).Rsh(Order, 1)
    15  
// curveLattice is a 2-dimensional lattice basis that decompose/Multi use to
// split a scalar into two shorter components.
//
// vectors[j] is basis vector j. inverse equals (-vectors[1][1], vectors[0][1]),
// i.e. the first row of the adjugate of the basis matrix. decompose therefore
// takes k*inverse[j]/det as the coefficient of vectors[j] when <k, 0> is
// written in this basis, which is exact provided det is |det(vectors)|.
// NOTE(review): det is not re-derived here; for the BN256 group order it
// appears to equal 2*Order — confirm against Order before changing anything.
var curveLattice = &lattice{
	vectors: [][]*big.Int{
		{bigFromBase10("147946756881789319000765030803803410728"), bigFromBase10("147946756881789319010696353538189108491")},
		{bigFromBase10("147946756881789319020627676272574806254"), bigFromBase10("-147946756881789318990833708069417712965")},
	},
	inverse: []*big.Int{
		bigFromBase10("147946756881789318990833708069417712965"),
		bigFromBase10("147946756881789319010696353538189108491"),
	},
	det: bigFromBase10("43776485743678550444492811490514550177096728800832068687396408373151616991234"),
}
    27  
// targetLattice is a 4-dimensional lattice basis with the same layout as
// curveLattice: vectors[j] is basis vector j, and decompose uses
// round(k*inverse[j]/det) as the coefficient of vectors[j]. Here det is Order
// itself (copied so the shared Order value is never aliased).
// NOTE(review): the name suggests it serves exponentiation in the target group
// GT; the basis and inverse values are not re-derived here — verify against the
// upstream cloudflare/bn256 source before editing any constant.
var targetLattice = &lattice{
	vectors: [][]*big.Int{
		{bigFromBase10("9931322734385697761"), bigFromBase10("9931322734385697761"), bigFromBase10("9931322734385697763"), bigFromBase10("9931322734385697764")},
		{bigFromBase10("4965661367192848881"), bigFromBase10("4965661367192848881"), bigFromBase10("4965661367192848882"), bigFromBase10("-9931322734385697762")},
		{bigFromBase10("-9931322734385697762"), bigFromBase10("-4965661367192848881"), bigFromBase10("4965661367192848881"), bigFromBase10("-4965661367192848882")},
		{bigFromBase10("9931322734385697763"), bigFromBase10("-4965661367192848881"), bigFromBase10("-4965661367192848881"), bigFromBase10("-4965661367192848881")},
	},
	inverse: []*big.Int{
		bigFromBase10("734653495049373973658254490726798021314063399421879442165"),
		bigFromBase10("147946756881789319000765030803803410728"),
		bigFromBase10("-147946756881789319005730692170996259609"),
		bigFromBase10("1469306990098747947464455738335385361643788813749140841702"),
	},
	det: new(big.Int).Set(Order),
}
    43  
// lattice holds a lattice basis together with the data decompose needs for
// Babai rounding:
//   - vectors[j][i] is coordinate i of basis vector j;
//   - inverse[j] and det are combined by decompose as round(k*inverse[j]/det)
//     to obtain the integer coefficient of vectors[j].
//
// len(inverse) is the number of components decompose returns, so it is
// expected to match the number of basis vectors and their dimension.
type lattice struct {
	vectors [][]*big.Int
	inverse []*big.Int
	det     *big.Int
}
    49  
    50  // decompose takes a scalar mod Order as input and finds a short, positive decomposition of it wrt to the lattice basis.
    51  func (l *lattice) decompose(k *big.Int) []*big.Int {
    52  	n := len(l.inverse)
    53  
    54  	// Calculate closest vector in lattice to <k,0,0,...> with Babai's rounding.
    55  	c := make([]*big.Int, n)
    56  	for i := 0; i < n; i++ {
    57  		c[i] = new(big.Int).Mul(k, l.inverse[i])
    58  		round(c[i], l.det)
    59  	}
    60  
    61  	// Transform vectors according to c and subtract <k,0,0,...>.
    62  	out := make([]*big.Int, n)
    63  	temp := new(big.Int)
    64  
    65  	for i := 0; i < n; i++ {
    66  		out[i] = new(big.Int)
    67  
    68  		for j := 0; j < n; j++ {
    69  			temp.Mul(c[j], l.vectors[j][i])
    70  			out[i].Add(out[i], temp)
    71  		}
    72  
    73  		out[i].Neg(out[i])
    74  		out[i].Add(out[i], l.vectors[0][i]).Add(out[i], l.vectors[0][i])
    75  	}
    76  	out[0].Add(out[0], k)
    77  
    78  	return out
    79  }
    80  
    81  func (l *lattice) Precompute(add func(i, j uint)) {
    82  	n := uint(len(l.vectors))
    83  	total := uint(1) << n
    84  
    85  	for i := uint(0); i < n; i++ {
    86  		for j := uint(0); j < total; j++ {
    87  			if (j>>i)&1 == 1 {
    88  				add(i, j)
    89  			}
    90  		}
    91  	}
    92  }
    93  
    94  func (l *lattice) Multi(scalar *big.Int) []uint8 {
    95  	decomp := l.decompose(scalar)
    96  
    97  	maxLen := 0
    98  	for _, x := range decomp {
    99  		if x.BitLen() > maxLen {
   100  			maxLen = x.BitLen()
   101  		}
   102  	}
   103  
   104  	out := make([]uint8, maxLen)
   105  	for j, x := range decomp {
   106  		for i := 0; i < maxLen; i++ {
   107  			out[i] += uint8(x.Bit(i)) << uint(j)
   108  		}
   109  	}
   110  
   111  	return out
   112  }
   113  
   114  // round sets num to num/denom rounded to the nearest integer.
   115  func round(num, denom *big.Int) {
   116  	r := new(big.Int)
   117  	num.DivMod(num, denom, r)
   118  
   119  	if r.Cmp(half) == 1 {
   120  		num.Add(num, common.Big1)
   121  	}
   122  }