github.com/cloudflare/circl@v1.5.0/zk/dleq/dleq.go (about) 1 // Package dleq provides zero-knowledge proofs of Discrete-Logarithm Equivalence (DLEQ). 2 // 3 // This implementation is compatible with the one used for VOPRFs [1]. 4 // It supports batching proofs to amortize the cost of the proof generation and 5 // verification. 6 // 7 // References: 8 // 9 // [1] RFC-9497: https://www.rfc-editor.org/info/rfc9497 10 package dleq 11 12 import ( 13 "crypto" 14 "encoding/binary" 15 "io" 16 17 "github.com/cloudflare/circl/group" 18 ) 19 20 const ( 21 labelSeed = "Seed-" 22 labelChallenge = "Challenge" 23 labelComposite = "Composite" 24 labelHashToScalar = "HashToScalar-" 25 ) 26 27 type Params struct { 28 G group.Group 29 H crypto.Hash 30 DST []byte 31 } 32 33 type Proof struct { 34 c, s group.Scalar 35 } 36 37 type Prover struct{ Params } 38 39 func (p Prover) Prove(k group.Scalar, a, ka, b, kb group.Element, rnd io.Reader) (*Proof, error) { 40 return p.ProveBatch(k, a, ka, []group.Element{b}, []group.Element{kb}, rnd) 41 } 42 43 func (p Prover) ProveWithRandomness(k group.Scalar, a, ka, b, kb group.Element, rnd group.Scalar) (*Proof, error) { 44 return p.ProveBatchWithRandomness(k, a, ka, []group.Element{b}, []group.Element{kb}, rnd) 45 } 46 47 func (p Prover) ProveBatch(k group.Scalar, a, ka group.Element, bi, kbi []group.Element, rnd io.Reader) (*Proof, error) { 48 return p.ProveBatchWithRandomness(k, a, ka, bi, kbi, p.Params.G.RandomScalar(rnd)) 49 } 50 51 func (p Prover) ProveBatchWithRandomness( 52 k group.Scalar, 53 a, ka group.Element, 54 bi, kbi []group.Element, 55 rnd group.Scalar, 56 ) (*Proof, error) { 57 M, Z, err := p.computeComposites(k, ka, bi, kbi) 58 if err != nil { 59 return nil, err 60 } 61 62 kAm, err := ka.MarshalBinaryCompress() 63 if err != nil { 64 return nil, err 65 } 66 67 a0, err := M.MarshalBinaryCompress() 68 if err != nil { 69 return nil, err 70 } 71 72 a1, err := Z.MarshalBinaryCompress() 73 if err != nil { 74 return nil, err 75 } 76 77 t2 := p.G.NewElement().Mul(a, rnd) 78 a2, err := t2.MarshalBinaryCompress() 79 if err != nil { 80 return nil, err 81 } 82 83 t3 := p.G.NewElement().Mul(M, rnd) 84 a3, err := t3.MarshalBinaryCompress() 85 if err != nil { 86 return nil, err 87 } 88 89 cc := p.doChallenge([5][]byte{kAm, a0, a1, a2, a3}) 90 ss := p.G.NewScalar() 91 ss.Mul(cc, k) 92 ss.Sub(rnd, ss) 93 94 return &Proof{cc, ss}, nil 95 } 96 97 func (p Params) computeComposites( 98 k group.Scalar, 99 ka group.Element, 100 bi []group.Element, 101 kbi []group.Element, 102 ) (m, z group.Element, err error) { 103 kAm, err := ka.MarshalBinaryCompress() 104 if err != nil { 105 return nil, nil, err 106 } 107 108 lenBuf := []byte{0, 0} 109 H := p.H.New() 110 111 binary.BigEndian.PutUint16(lenBuf, uint16(len(kAm))) 112 mustWrite(H, lenBuf) 113 mustWrite(H, kAm) 114 115 seedDST := append(append([]byte{}, labelSeed...), p.DST...) 116 binary.BigEndian.PutUint16(lenBuf, uint16(len(seedDST))) 117 mustWrite(H, lenBuf) 118 mustWrite(H, seedDST) 119 120 seed := H.Sum(nil) 121 122 m = p.G.Identity() 123 z = p.G.Identity() 124 h2sDST := append(append([]byte{}, labelHashToScalar...), p.DST...) 125 for j := range bi { 126 h2Input := []byte{} 127 128 Bij, err := bi[j].MarshalBinaryCompress() 129 if err != nil { 130 return nil, nil, err 131 } 132 133 kBij, err := kbi[j].MarshalBinaryCompress() 134 if err != nil { 135 return nil, nil, err 136 } 137 138 binary.BigEndian.PutUint16(lenBuf, uint16(len(seed))) 139 h2Input = append(append(h2Input, lenBuf...), seed...) 140 141 binary.BigEndian.PutUint16(lenBuf, uint16(j)) 142 h2Input = append(h2Input, lenBuf...) 143 144 binary.BigEndian.PutUint16(lenBuf, uint16(len(Bij))) 145 h2Input = append(append(h2Input, lenBuf...), Bij...) 146 147 binary.BigEndian.PutUint16(lenBuf, uint16(len(kBij))) 148 h2Input = append(append(h2Input, lenBuf...), kBij...) 149 150 h2Input = append(h2Input, labelComposite...) 151 dj := p.G.HashToScalar(h2Input, h2sDST) 152 Mj := p.G.NewElement() 153 Mj.Mul(bi[j], dj) 154 m.Add(m, Mj) 155 156 if k == nil { 157 Zj := p.G.NewElement() 158 Zj.Mul(kbi[j], dj) 159 z.Add(z, Zj) 160 } 161 } 162 163 if k != nil { 164 z.Mul(m, k) 165 } 166 167 return m, z, nil 168 } 169 170 func (p Params) doChallenge(a [5][]byte) group.Scalar { 171 h2Input := []byte{} 172 lenBuf := []byte{0, 0} 173 174 for i := range a { 175 binary.BigEndian.PutUint16(lenBuf, uint16(len(a[i]))) 176 h2Input = append(append(h2Input, lenBuf...), a[i]...) 177 } 178 179 h2Input = append(h2Input, labelChallenge...) 180 dst := append(append([]byte{}, labelHashToScalar...), p.DST...) 181 182 return p.G.HashToScalar(h2Input, dst) 183 } 184 185 type Verifier struct{ Params } 186 187 func (v Verifier) Verify(a, ka, b, kb group.Element, p *Proof) bool { 188 return v.VerifyBatch(a, ka, []group.Element{b}, []group.Element{kb}, p) 189 } 190 191 func (v Verifier) VerifyBatch(a, ka group.Element, bi, kbi []group.Element, p *Proof) bool { 192 g := v.Params.G 193 M, Z, err := v.Params.computeComposites(nil, ka, bi, kbi) 194 if err != nil { 195 return false 196 } 197 198 sA := g.NewElement().Mul(a, p.s) 199 ckA := g.NewElement().Mul(ka, p.c) 200 t2 := g.NewElement().Add(sA, ckA) 201 sM := g.NewElement().Mul(M, p.s) 202 cZ := g.NewElement().Mul(Z, p.c) 203 t3 := g.NewElement().Add(sM, cZ) 204 205 kAm, err := ka.MarshalBinaryCompress() 206 if err != nil { 207 return false 208 } 209 210 a0, err := M.MarshalBinaryCompress() 211 if err != nil { 212 return false 213 } 214 a1, err := Z.MarshalBinaryCompress() 215 if err != nil { 216 return false 217 } 218 a2, err := t2.MarshalBinaryCompress() 219 if err != nil { 220 return false 221 } 222 a3, err := t3.MarshalBinaryCompress() 223 if err != nil { 224 return false 225 } 226 227 gotC := v.Params.doChallenge([5][]byte{kAm, a0, a1, a2, a3}) 228 229 return gotC.IsEqual(p.c) 230 } 231 232 func (p *Proof) MarshalBinary() ([]byte, error) { 233 g := p.c.Group() 234 scalarSize := int(g.Params().ScalarLength) 235 output := make([]byte, 0, 2*scalarSize) 236 237 serC, err := p.c.MarshalBinary() 238 if err != nil { 239 return nil, err 240 } 241 output = append(output, serC...) 242 243 serS, err := p.s.MarshalBinary() 244 if err != nil { 245 return nil, err 246 } 247 output = append(output, serS...) 248 249 return output, nil 250 } 251 252 func (p *Proof) UnmarshalBinary(g group.Group, data []byte) error { 253 scalarSize := int(g.Params().ScalarLength) 254 if len(data) < 2*scalarSize { 255 return io.ErrShortBuffer 256 } 257 258 c := g.NewScalar() 259 err := c.UnmarshalBinary(data[:scalarSize]) 260 if err != nil { 261 return err 262 } 263 264 s := g.NewScalar() 265 err = s.UnmarshalBinary(data[scalarSize : 2*scalarSize]) 266 if err != nil { 267 return err 268 } 269 270 p.c = c 271 p.s = s 272 273 return nil 274 } 275 276 func mustWrite(h io.Writer, bytes []byte) { 277 bytesLen, err := h.Write(bytes) 278 if err != nil { 279 panic(err) 280 } 281 if len(bytes) != bytesLen { 282 panic("dleq: failed to write on hash") 283 } 284 }