github.com/cloudflare/circl@v1.5.0/zk/dleq/dleq.go (about)

     1  // Package dleq provides zero-knowledge proofs of Discrete-Logarithm Equivalence (DLEQ).
     2  //
// This implementation is compatible with the one used for VOPRFs [1].
// It supports batching proofs to amortize the cost of proof generation and
// verification.
     6  //
     7  // References:
     8  //
     9  //	[1] RFC-9497: https://www.rfc-editor.org/info/rfc9497
    10  package dleq
    11  
import (
	"crypto"
	"encoding/binary"
	"errors"
	"io"

	"github.com/cloudflare/circl/group"
)
    19  
    20  const (
    21  	labelSeed         = "Seed-"
    22  	labelChallenge    = "Challenge"
    23  	labelComposite    = "Composite"
    24  	labelHashToScalar = "HashToScalar-"
    25  )
    26  
// Params holds the configuration shared by Prover and Verifier; both embed it.
type Params struct {
	// G is the group whose elements and scalars the proofs operate on.
	G   group.Group
	// H is the hash function used to derive the seed for batching
	// (composite) weights.
	H   crypto.Hash
	// DST is the domain-separation tag. It is appended to the "Seed-" and
	// "HashToScalar-" labels when hashing. NOTE(review): presumably this is
	// the RFC 9497 context string for VOPRF compatibility — confirm with
	// callers.
	DST []byte
}
    32  
// Proof is a (possibly batched) DLEQ proof made of a challenge scalar and a
// response scalar.
type Proof struct {
	// c is the challenge produced by hashing the proof transcript, and s is
	// the response s = r - c*k, where r is the prover's nonce and k the
	// secret scalar.
	c, s group.Scalar
}
    36  
// Prover generates DLEQ proofs using the embedded Params.
type Prover struct{ Params }
    38  
    39  func (p Prover) Prove(k group.Scalar, a, ka, b, kb group.Element, rnd io.Reader) (*Proof, error) {
    40  	return p.ProveBatch(k, a, ka, []group.Element{b}, []group.Element{kb}, rnd)
    41  }
    42  
    43  func (p Prover) ProveWithRandomness(k group.Scalar, a, ka, b, kb group.Element, rnd group.Scalar) (*Proof, error) {
    44  	return p.ProveBatchWithRandomness(k, a, ka, []group.Element{b}, []group.Element{kb}, rnd)
    45  }
    46  
    47  func (p Prover) ProveBatch(k group.Scalar, a, ka group.Element, bi, kbi []group.Element, rnd io.Reader) (*Proof, error) {
    48  	return p.ProveBatchWithRandomness(k, a, ka, bi, kbi, p.Params.G.RandomScalar(rnd))
    49  }
    50  
    51  func (p Prover) ProveBatchWithRandomness(
    52  	k group.Scalar,
    53  	a, ka group.Element,
    54  	bi, kbi []group.Element,
    55  	rnd group.Scalar,
    56  ) (*Proof, error) {
    57  	M, Z, err := p.computeComposites(k, ka, bi, kbi)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	kAm, err := ka.MarshalBinaryCompress()
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	a0, err := M.MarshalBinaryCompress()
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  
    72  	a1, err := Z.MarshalBinaryCompress()
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	t2 := p.G.NewElement().Mul(a, rnd)
    78  	a2, err := t2.MarshalBinaryCompress()
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	t3 := p.G.NewElement().Mul(M, rnd)
    84  	a3, err := t3.MarshalBinaryCompress()
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	cc := p.doChallenge([5][]byte{kAm, a0, a1, a2, a3})
    90  	ss := p.G.NewScalar()
    91  	ss.Mul(cc, k)
    92  	ss.Sub(rnd, ss)
    93  
    94  	return &Proof{cc, ss}, nil
    95  }
    96  
    97  func (p Params) computeComposites(
    98  	k group.Scalar,
    99  	ka group.Element,
   100  	bi []group.Element,
   101  	kbi []group.Element,
   102  ) (m, z group.Element, err error) {
   103  	kAm, err := ka.MarshalBinaryCompress()
   104  	if err != nil {
   105  		return nil, nil, err
   106  	}
   107  
   108  	lenBuf := []byte{0, 0}
   109  	H := p.H.New()
   110  
   111  	binary.BigEndian.PutUint16(lenBuf, uint16(len(kAm)))
   112  	mustWrite(H, lenBuf)
   113  	mustWrite(H, kAm)
   114  
   115  	seedDST := append(append([]byte{}, labelSeed...), p.DST...)
   116  	binary.BigEndian.PutUint16(lenBuf, uint16(len(seedDST)))
   117  	mustWrite(H, lenBuf)
   118  	mustWrite(H, seedDST)
   119  
   120  	seed := H.Sum(nil)
   121  
   122  	m = p.G.Identity()
   123  	z = p.G.Identity()
   124  	h2sDST := append(append([]byte{}, labelHashToScalar...), p.DST...)
   125  	for j := range bi {
   126  		h2Input := []byte{}
   127  
   128  		Bij, err := bi[j].MarshalBinaryCompress()
   129  		if err != nil {
   130  			return nil, nil, err
   131  		}
   132  
   133  		kBij, err := kbi[j].MarshalBinaryCompress()
   134  		if err != nil {
   135  			return nil, nil, err
   136  		}
   137  
   138  		binary.BigEndian.PutUint16(lenBuf, uint16(len(seed)))
   139  		h2Input = append(append(h2Input, lenBuf...), seed...)
   140  
   141  		binary.BigEndian.PutUint16(lenBuf, uint16(j))
   142  		h2Input = append(h2Input, lenBuf...)
   143  
   144  		binary.BigEndian.PutUint16(lenBuf, uint16(len(Bij)))
   145  		h2Input = append(append(h2Input, lenBuf...), Bij...)
   146  
   147  		binary.BigEndian.PutUint16(lenBuf, uint16(len(kBij)))
   148  		h2Input = append(append(h2Input, lenBuf...), kBij...)
   149  
   150  		h2Input = append(h2Input, labelComposite...)
   151  		dj := p.G.HashToScalar(h2Input, h2sDST)
   152  		Mj := p.G.NewElement()
   153  		Mj.Mul(bi[j], dj)
   154  		m.Add(m, Mj)
   155  
   156  		if k == nil {
   157  			Zj := p.G.NewElement()
   158  			Zj.Mul(kbi[j], dj)
   159  			z.Add(z, Zj)
   160  		}
   161  	}
   162  
   163  	if k != nil {
   164  		z.Mul(m, k)
   165  	}
   166  
   167  	return m, z, nil
   168  }
   169  
   170  func (p Params) doChallenge(a [5][]byte) group.Scalar {
   171  	h2Input := []byte{}
   172  	lenBuf := []byte{0, 0}
   173  
   174  	for i := range a {
   175  		binary.BigEndian.PutUint16(lenBuf, uint16(len(a[i])))
   176  		h2Input = append(append(h2Input, lenBuf...), a[i]...)
   177  	}
   178  
   179  	h2Input = append(h2Input, labelChallenge...)
   180  	dst := append(append([]byte{}, labelHashToScalar...), p.DST...)
   181  
   182  	return p.G.HashToScalar(h2Input, dst)
   183  }
   184  
// Verifier checks DLEQ proofs using the embedded Params.
type Verifier struct{ Params }
   186  
   187  func (v Verifier) Verify(a, ka, b, kb group.Element, p *Proof) bool {
   188  	return v.VerifyBatch(a, ka, []group.Element{b}, []group.Element{kb}, p)
   189  }
   190  
   191  func (v Verifier) VerifyBatch(a, ka group.Element, bi, kbi []group.Element, p *Proof) bool {
   192  	g := v.Params.G
   193  	M, Z, err := v.Params.computeComposites(nil, ka, bi, kbi)
   194  	if err != nil {
   195  		return false
   196  	}
   197  
   198  	sA := g.NewElement().Mul(a, p.s)
   199  	ckA := g.NewElement().Mul(ka, p.c)
   200  	t2 := g.NewElement().Add(sA, ckA)
   201  	sM := g.NewElement().Mul(M, p.s)
   202  	cZ := g.NewElement().Mul(Z, p.c)
   203  	t3 := g.NewElement().Add(sM, cZ)
   204  
   205  	kAm, err := ka.MarshalBinaryCompress()
   206  	if err != nil {
   207  		return false
   208  	}
   209  
   210  	a0, err := M.MarshalBinaryCompress()
   211  	if err != nil {
   212  		return false
   213  	}
   214  	a1, err := Z.MarshalBinaryCompress()
   215  	if err != nil {
   216  		return false
   217  	}
   218  	a2, err := t2.MarshalBinaryCompress()
   219  	if err != nil {
   220  		return false
   221  	}
   222  	a3, err := t3.MarshalBinaryCompress()
   223  	if err != nil {
   224  		return false
   225  	}
   226  
   227  	gotC := v.Params.doChallenge([5][]byte{kAm, a0, a1, a2, a3})
   228  
   229  	return gotC.IsEqual(p.c)
   230  }
   231  
   232  func (p *Proof) MarshalBinary() ([]byte, error) {
   233  	g := p.c.Group()
   234  	scalarSize := int(g.Params().ScalarLength)
   235  	output := make([]byte, 0, 2*scalarSize)
   236  
   237  	serC, err := p.c.MarshalBinary()
   238  	if err != nil {
   239  		return nil, err
   240  	}
   241  	output = append(output, serC...)
   242  
   243  	serS, err := p.s.MarshalBinary()
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  	output = append(output, serS...)
   248  
   249  	return output, nil
   250  }
   251  
   252  func (p *Proof) UnmarshalBinary(g group.Group, data []byte) error {
   253  	scalarSize := int(g.Params().ScalarLength)
   254  	if len(data) < 2*scalarSize {
   255  		return io.ErrShortBuffer
   256  	}
   257  
   258  	c := g.NewScalar()
   259  	err := c.UnmarshalBinary(data[:scalarSize])
   260  	if err != nil {
   261  		return err
   262  	}
   263  
   264  	s := g.NewScalar()
   265  	err = s.UnmarshalBinary(data[scalarSize : 2*scalarSize])
   266  	if err != nil {
   267  		return err
   268  	}
   269  
   270  	p.c = c
   271  	p.s = s
   272  
   273  	return nil
   274  }
   275  
   276  func mustWrite(h io.Writer, bytes []byte) {
   277  	bytesLen, err := h.Write(bytes)
   278  	if err != nil {
   279  		panic(err)
   280  	}
   281  	if len(bytes) != bytesLen {
   282  		panic("dleq: failed to write on hash")
   283  	}
   284  }