github.com/piotrnar/gocoin@v0.0.0-20240512203912-faa0448c5e96/lib/secp256k1/schnorr.go (about)

     1  package secp256k1
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"encoding"
     7  	"encoding/hex"
     8  	"hash"
     9  )
    10  
    11  func SchnorrsigChallenge(e *Number, r32, msg32, pubkey32 []byte) {
    12  	s := ShaMidstateChallenge()
    13  
    14  	s.Write(r32)
    15  	s.Write(pubkey32)
    16  	s.Write(msg32)
    17  	e.SetBytes(s.Sum(nil))
    18  }
    19  
    20  func SchnorrVerify(pkey, sig, msg []byte) (ret bool) {
    21  	var rx Field
    22  	var pk, r XY
    23  	var rj, pkj XYZ
    24  	var _s, _e Number
    25  
    26  	rx.SetB32(sig[:32])
    27  	pk.ParseXOnlyPubkey(pkey)
    28  
    29  	SchnorrsigChallenge(&_e, sig[:32], msg, pkey)
    30  	_e.sub(&TheCurve.Order, &_e)
    31  
    32  	_s.SetBytes(sig[32:])
    33  	pkj.SetXY(&pk)
    34  	pkj.ECmult(&rj, &_e, &_s)
    35  
    36  	r.SetXYZ(&rj)
    37  	if r.Infinity {
    38  		return false
    39  	}
    40  
    41  	r.Y.Normalize()
    42  	if r.Y.IsOdd() {
    43  		return false
    44  	}
    45  
    46  	r.X.Normalize()
    47  	return rx.Equals(&r.X)
    48  }
    49  
    50  func get_n_minus(in []byte) []byte {
    51  	var n Number
    52  	n.SetBytes(in)
    53  	n.sub(&TheCurve.Order, &n)
    54  	return n.get_bin(32)
    55  }
    56  
    57  func SchnorrSign(m, sk, a []byte) []byte {
    58  	var xyz XYZ
    59  	var n, x Number
    60  	var P, R XY
    61  	var d, t, k, e, res []byte
    62  
    63  	n.SetBytes(sk) // d
    64  	if n.is_zero() || !n.is_below(&TheCurve.Order) {
    65  		println("SchnorrSign: d out of range")
    66  		return nil
    67  	}
    68  	ECmultGen(&xyz, &n)
    69  	P.SetXYZ(&xyz)
    70  	P.Y.Normalize()
    71  	if P.Y.IsOdd() {
    72  		d = get_n_minus(sk)
    73  	} else {
    74  		d = sk
    75  	}
    76  
    77  	s := ShaMidstateAux()
    78  	s.Write(a)
    79  	t = s.Sum(nil)
    80  	for i := range t {
    81  		t[i] ^= d[i]
    82  	}
    83  
    84  	s = ShaMidstateNonce()
    85  	s.Write(t)
    86  	P.X.Normalize()
    87  	P.X.GetB32(t)
    88  	s.Write(t)
    89  	s.Write(m)
    90  	k0 := s.Sum(nil)
    91  
    92  	n.SetBytes(k0)
    93  	n.mod(&TheCurve.Order)
    94  	if n.is_zero() {
    95  		println("SchnorrSign: k' is zero")
    96  		return nil
    97  	}
    98  	ECmultGen(&xyz, &n)
    99  	R.SetXYZ(&xyz)
   100  	R.Y.Normalize()
   101  	if R.Y.IsOdd() {
   102  		k = get_n_minus(k0)
   103  	} else {
   104  		k = k0
   105  	}
   106  
   107  	res = make([]byte, 64)
   108  	R.X.Normalize()
   109  	P.X.Normalize()
   110  	R.X.GetB32(res[:32])
   111  	P.X.GetB32(res[32:])
   112  	copy(t, res[32:]) // save public key for the verify function
   113  	s = ShaMidstateChallenge()
   114  	s.Write(res)
   115  	s.Write(m)
   116  	e = s.Sum(nil)
   117  
   118  	n.SetBytes(e)
   119  	if !n.is_below(&TheCurve.Order) {
   120  		n.sub(&n, &TheCurve.Order) // we need to use "e mod N"
   121  	}
   122  
   123  	// signature: ((e * d + k) mod N)
   124  	x.SetBytes(d)
   125  	n.mul(&n, &x)
   126  
   127  	x.SetBytes(k)
   128  	n.add(&n, &x)
   129  	n.mod(&TheCurve.Order)
   130  
   131  	copy(res[32:], n.get_bin(32))
   132  	if !SchnorrVerify(t, res, m) {
   133  		println("SchnorrSign: verify error", hex.EncodeToString(res))
   134  		return nil
   135  	}
   136  	return res
   137  }
   138  
   139  func CheckPayToContract(m_keydata, base, hash []byte, parity bool) bool {
   140  	var base_point XY
   141  	base_point.ParseXOnlyPubkey(base)
   142  	return base_point.XOnlyPubkeyTweakAddCheck(m_keydata, parity, hash)
   143  }
   144  
   145  func (pk *XY) XOnlyPubkeyTweakAddCheck(tweaked_pubkey32 []byte, tweaked_pk_parity bool, hash []byte) bool {
   146  	var pk_expected32 [32]byte
   147  	var tweak Number
   148  
   149  	tweak.SetBytes(hash)
   150  	if !pk.ECPublicTweakAdd(&tweak) {
   151  		return false
   152  	}
   153  	pk.X.Normalize()
   154  	pk.Y.Normalize()
   155  	pk.X.GetB32(pk_expected32[:])
   156  
   157  	if bytes.Equal(pk_expected32[:], tweaked_pubkey32) {
   158  		if pk.Y.IsOdd() == tweaked_pk_parity {
   159  			return true
   160  		}
   161  	}
   162  
   163  	return false
   164  }
   165  
   166  func (key *XY) ECPublicTweakAdd(tweak *Number) bool {
   167  	var pt, pt2 XYZ
   168  	var one Number
   169  	pt.SetXY(key)
   170  	one.SetInt64(1)
   171  	pt.ECmult(&pt2, &one, tweak)
   172  	if pt2.IsInfinity() {
   173  		return false
   174  	}
   175  	key.SetXYZ(&pt2)
   176  	return true
   177  }
   178  
   179  var _sha_midstate_challenge, _sha_midstate_aux, _sha_midstate_nonce []byte
   180  
   181  func ShaMidstateChallenge() hash.Hash {
   182  	s := sha256.New()
   183  	unmarshaler, _ := s.(encoding.BinaryUnmarshaler)
   184  	unmarshaler.UnmarshalBinary(_sha_midstate_challenge)
   185  	return s
   186  }
   187  
   188  func ShaMidstateAux() hash.Hash {
   189  	s := sha256.New()
   190  	unmarshaler, _ := s.(encoding.BinaryUnmarshaler)
   191  	unmarshaler.UnmarshalBinary(_sha_midstate_aux)
   192  	return s
   193  }
   194  
   195  func ShaMidstateNonce() hash.Hash {
   196  	s := sha256.New()
   197  	unmarshaler, _ := s.(encoding.BinaryUnmarshaler)
   198  	unmarshaler.UnmarshalBinary(_sha_midstate_nonce)
   199  	return s
   200  }
   201  
   202  func init() {
   203  	s := sha256.New()
   204  	s.Write([]byte("BIP0340/challenge"))
   205  	c := s.Sum(nil)
   206  	s.Reset()
   207  	s.Write(c)
   208  	s.Write(c)
   209  	marshaler, _ := s.(encoding.BinaryMarshaler)
   210  	_sha_midstate_challenge, _ = marshaler.MarshalBinary()
   211  
   212  	s.Reset()
   213  	s.Write([]byte("BIP0340/aux"))
   214  	c = s.Sum(nil)
   215  	s.Reset()
   216  	s.Write(c)
   217  	s.Write(c)
   218  	marshaler, _ = s.(encoding.BinaryMarshaler)
   219  	_sha_midstate_aux, _ = marshaler.MarshalBinary()
   220  
   221  	s.Reset()
   222  	s.Write([]byte("BIP0340/nonce"))
   223  	c = s.Sum(nil)
   224  	s.Reset()
   225  	s.Write(c)
   226  	s.Write(c)
   227  	marshaler, _ = s.(encoding.BinaryMarshaler)
   228  	_sha_midstate_nonce, _ = marshaler.MarshalBinary()
   229  }