github.com/dusk-network/dusk-crypto@v0.1.3/rangeproof/innerproduct/innerproduct.go (about) 1 package innerproduct 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "io" 8 "math/bits" 9 10 ristretto "github.com/bwesterb/go-ristretto" 11 "github.com/dusk-network/dusk-crypto/rangeproof/fiatshamir" 12 "github.com/dusk-network/dusk-crypto/rangeproof/vector" 13 ) 14 15 // This is a reference of the innerProduct implementation at rust 16 17 // Proof represents an innner product proof struct 18 type Proof struct { 19 L, R []ristretto.Point 20 A, B ristretto.Scalar // a and b are capitalised so that they are exported, in paper it is `a``b` 21 } 22 23 // Generate generates an inner product proof or an error 24 // if proof cannot be constucted 25 func Generate(GVec, HVec []ristretto.Point, aVec, bVec, HprimeFactors []ristretto.Scalar, Q ristretto.Point) (*Proof, error) { 26 n := uint32(len(GVec)) 27 28 // XXX : When n is not a power of two, will the bulletproof struct pad it 29 // or will the inner product proof struct? 30 if !isPower2(uint32(n)) { 31 return nil, errors.New("[IPProof]: size of n (NM) is not a power of 2") 32 } 33 34 a := make([]ristretto.Scalar, len(aVec)) 35 copy(a, aVec) 36 b := make([]ristretto.Scalar, len(bVec)) 37 copy(b, bVec) 38 G := make([]ristretto.Point, len(GVec)) 39 copy(G, GVec) 40 H := make([]ristretto.Point, len(HVec)) 41 copy(H, HVec) 42 43 hs := fiatshamir.HashCacher{Cache: []byte{}} 44 45 lgN := bits.TrailingZeros(nextPow2(uint(n))) 46 47 Lj := make([]ristretto.Point, 0, lgN) 48 Rj := make([]ristretto.Point, 0, lgN) 49 50 if n != 1 { 51 n = n / 2 52 53 aL, aR, err := vector.SplitScalars(a, n) 54 bL, bR, err := vector.SplitScalars(b, n) 55 GL, GR, err := vector.SplitPoints(G, n) 56 HL, HR, err := vector.SplitPoints(H, n) 57 58 cL, err := vector.InnerProduct(aL, bR) 59 if err != nil { 60 61 } 62 cR, err := vector.InnerProduct(aR, bL) 63 if err != nil { 64 return nil, err 65 } 66 67 // L = aL * GR + bR * HL * HPrime[0..n] + cL * Q = e1 + e2 + e3 68 69 e1, err := vector.Exp(aL, GR, int(n), 1) 70 if err != nil { 71 return nil, err 72 73 } 74 75 bRYi := make([]ristretto.Scalar, len(bR)) 76 copy(bRYi, bR) 77 78 for i := range bRYi { 79 bRYi[i].Mul(&bRYi[i], &HprimeFactors[i]) 80 } 81 82 e2, err := vector.Exp(bRYi, HL, int(n), 1) 83 if err != nil { 84 return nil, err 85 } 86 87 var e3 ristretto.Point 88 e3.ScalarMult(&Q, &cL) 89 90 var L ristretto.Point 91 L.SetZero() 92 L.Add(&e1, &e2) 93 L.Add(&L, &e3) 94 95 Lj = append(Lj, L) 96 97 // R = aR * GL + bL * HR * HPrimeFactors[n .. 2n] + cR * Q = e4 + e5 + e6 98 99 e4, err := vector.Exp(aR, GL, int(n), 1) 100 if err != nil { 101 return nil, err 102 } 103 104 bLYi := make([]ristretto.Scalar, len(bL)) 105 copy(bLYi, bL) 106 107 for i := range bLYi { 108 bLYi[i].Mul(&bLYi[i], &HprimeFactors[uint32(i)+n]) 109 } 110 111 e5, err := vector.Exp(bLYi, HR, int(n), 1) 112 if err != nil { 113 return nil, err 114 } 115 116 var e6 ristretto.Point 117 e6.ScalarMult(&Q, &cR) 118 119 var R ristretto.Point 120 R.SetZero() 121 R.Add(&e4, &e5) 122 R.Add(&R, &e6) 123 Rj = append(Rj, R) 124 125 hs.Append(L.Bytes(), R.Bytes()) 126 127 u := hs.Derive() 128 var uinv ristretto.Scalar 129 uinv.Inverse(&u) 130 131 var a1, a2, b1, b2, h1a, h2a ristretto.Scalar 132 var g1, g2, h1, h2 ristretto.Point 133 134 for i := uint32(0); i < n; i++ { 135 136 a1.Mul(&aL[i], &u) 137 a2.Mul(&aR[i], &uinv) 138 aL[i].Add(&a1, &a2) 139 140 b1.Mul(&bL[i], &uinv) 141 b2.Mul(&bR[i], &u) 142 bL[i].Add(&b1, &b2) 143 144 g1.ScalarMult(&GL[i], &uinv) 145 g2.ScalarMult(&GR[i], &u) 146 GL[i].Add(&g1, &g2) 147 148 h1a.Mul(&HprimeFactors[i], &u) 149 h1.ScalarMult(&HL[i], &h1a) 150 h2a.Mul(&HprimeFactors[i+n], &uinv) 151 h2.ScalarMult(&HR[i], &h2a) 152 HL[i].Add(&h1, &h2) 153 } 154 155 a = aL 156 b = bL 157 G = GL 158 H = HL 159 } 160 161 for n > 1 { 162 163 n = n / 2 164 165 aL, aR, err := vector.SplitScalars(a, n) 166 bL, bR, err := vector.SplitScalars(b, n) 167 GL, GR, err := vector.SplitPoints(G, n) 168 HL, HR, err := vector.SplitPoints(H, n) 169 170 cL, err := vector.InnerProduct(aL, bR) 171 if err != nil { 172 return nil, err 173 } 174 cR, err := vector.InnerProduct(aR, bL) 175 if err != nil { 176 return nil, err 177 } 178 179 // L = aL * GR + bR * HL + cL * Q = e1 + e2 + e3 180 181 e1, err := vector.Exp(aL, GR, int(n), 1) 182 if err != nil { 183 return nil, err 184 } 185 e2, err := vector.Exp(bR, HL, int(n), 1) 186 if err != nil { 187 return nil, err 188 } 189 var e3 ristretto.Point 190 e3.ScalarMult(&Q, &cL) 191 192 var L ristretto.Point 193 L.SetZero() 194 L.Add(&e1, &e2) 195 L.Add(&L, &e3) 196 197 Lj = append(Lj, L) 198 199 // R = aR * GL + bL * HR + cR * Q = e4 + e5 + e6 200 201 e4, err := vector.Exp(aR, GL, int(n), 1) 202 if err != nil { 203 return nil, err 204 } 205 e5, err := vector.Exp(bL, HR, int(n), 1) 206 if err != nil { 207 return nil, err 208 } 209 var e6 ristretto.Point 210 e6.ScalarMult(&Q, &cR) 211 212 var R ristretto.Point 213 R.SetZero() 214 R.Add(&e4, &e5) 215 R.Add(&R, &e6) 216 Rj = append(Rj, R) 217 218 hs.Append(L.Bytes(), R.Bytes()) 219 220 u := hs.Derive() 221 var uinv ristretto.Scalar 222 uinv.Inverse(&u) 223 224 // aL = aL * u + aR *uinv = a1 + a2 - aprime 225 // bL = bR * u + bL *uinv = b1 + b2 - bprime 226 // GL = GL * uinv + GR * u = g1 + g2 - gprime 227 // HL = HL * u + HR * uinv = h1 + h2 - hprime 228 229 var a1, a2, b1, b2 ristretto.Scalar 230 var g1, g2, h1, h2 ristretto.Point 231 232 for i := uint32(0); i < n; i++ { 233 234 a1.Mul(&aL[i], &u) 235 a2.Mul(&aR[i], &uinv) 236 aL[i].Add(&a1, &a2) 237 238 b1.Mul(&bL[i], &uinv) 239 b2.Mul(&bR[i], &u) 240 bL[i].Add(&b1, &b2) 241 242 g1.ScalarMult(&GL[i], &uinv) 243 g2.ScalarMult(&GR[i], &u) 244 GL[i].Add(&g1, &g2) 245 246 h1.ScalarMult(&HL[i], &u) 247 h2.ScalarMult(&HR[i], &uinv) 248 HL[i].Add(&h1, &h2) 249 } 250 251 a = aL 252 b = bL 253 G = GL 254 H = HL 255 } 256 257 return &Proof{ 258 L: Lj, 259 R: Rj, 260 A: a[len(a)-1], 261 B: b[len(b)-1], 262 }, nil 263 } 264 265 // VerifScalars generates the challenge squared, the inverse challenge squared 266 // and s for a given inner product proof 267 func (proof *Proof) VerifScalars() ([]ristretto.Scalar, []ristretto.Scalar, []ristretto.Scalar) { 268 // generate scalars for verification 269 270 if len(proof.L) != len(proof.R) { 271 return nil, nil, nil 272 } 273 274 lgN := len(proof.L) 275 n := uint32(1 << uint(lgN)) 276 277 hs := fiatshamir.HashCacher{Cache: []byte{}} 278 279 // 1. compute x's 280 xChals := make([]ristretto.Scalar, 0, lgN) 281 for k := range proof.L { 282 hs.Append(proof.L[k].Bytes(), proof.R[k].Bytes()) 283 xChals = append(xChals, hs.Derive()) 284 } 285 286 // 2. compute inverse of x's 287 invXChals := make([]ristretto.Scalar, 0, lgN) 288 289 var invProd ristretto.Scalar // this will be the product of all of the inverses 290 invProd.SetOne() 291 292 for k := range xChals { 293 294 var xinv ristretto.Scalar 295 xinv.Inverse(&xChals[k]) 296 297 invProd.Mul(&invProd, &xinv) 298 299 invXChals = append(invXChals, xinv) 300 } 301 302 // 3. compute x^2 and inv(x)^2 303 chalSq := make([]ristretto.Scalar, 0, lgN) 304 invChalSq := make([]ristretto.Scalar, 0, lgN) 305 306 for k := range xChals { 307 var sq ristretto.Scalar 308 var invSq ristretto.Scalar 309 310 sq.Square(&xChals[k]) 311 invSq.Square(&invXChals[k]) 312 313 chalSq = append(chalSq, sq) 314 invChalSq = append(invChalSq, invSq) 315 } 316 317 // 4. compute s 318 s := make([]ristretto.Scalar, 0, n) 319 320 // push the inverse product 321 s = append(s, invProd) 322 323 for i := uint32(1); i < n; i++ { 324 325 lgI := 32 - 1 - bits.LeadingZeros32(i) 326 k := uint32(1 << uint(lgI)) 327 328 uLgISq := chalSq[(lgN-1)-lgI] 329 330 var sRes ristretto.Scalar 331 sRes.Mul(&s[i-k], &uLgISq) 332 s = append(s, sRes) 333 } 334 335 return chalSq, invChalSq, s 336 } 337 338 // Verify is used for unit tests and verifies that a given proof evaluates to the point P 339 func (proof *Proof) Verify(G, H, L, R []ristretto.Point, HprimeFactor []ristretto.Scalar, Q, P ristretto.Point, n int) bool { 340 uSq, uInvSq, s := proof.VerifScalars() 341 342 sInv := make([]ristretto.Scalar, len(s)) 343 copy(sInv, s) 344 345 // reverse s 346 for i, j := 0, len(sInv)-1; i < j; i, j = i+1, j-1 { 347 sInv[i], sInv[j] = sInv[j], sInv[i] 348 } 349 350 aTimesS := vector.MulScalar(s, proof.A) 351 hTimesbDivS := vector.MulScalar(sInv, proof.B) 352 for i, bDivS := range hTimesbDivS { 353 hTimesbDivS[i].Mul(&bDivS, &HprimeFactor[i]) 354 } 355 356 negUSq := make([]ristretto.Scalar, len(uSq)) 357 for i := range negUSq { 358 negUSq[i].Neg(&uSq[i]) 359 } 360 361 negUInvSq := make([]ristretto.Scalar, len(uInvSq)) 362 for i := range negUInvSq { 363 negUInvSq[i].Neg(&uInvSq[i]) 364 } 365 366 // Scalars 367 scalars := make([]ristretto.Scalar, 0) 368 369 var baseC ristretto.Scalar 370 baseC.Mul(&proof.A, &proof.B) 371 372 scalars = append(scalars, baseC) 373 scalars = append(scalars, aTimesS...) 374 scalars = append(scalars, hTimesbDivS...) 375 scalars = append(scalars, negUSq...) 376 scalars = append(scalars, negUInvSq...) 377 378 // Points 379 points := make([]ristretto.Point, 0) 380 points = append(points, Q) 381 points = append(points, G...) 382 points = append(points, H...) 383 points = append(points, proof.L...) 384 points = append(points, proof.R...) 385 386 have, err := vector.Exp(scalars, points, n, 1) 387 if err != nil { 388 return false 389 } 390 return have.Equals(&P) 391 } 392 393 // Encode a Proof 394 func (proof *Proof) Encode(w io.Writer) error { 395 396 err := binary.Write(w, binary.BigEndian, proof.A.Bytes()) 397 if err != nil { 398 return err 399 } 400 err = binary.Write(w, binary.BigEndian, proof.B.Bytes()) 401 if err != nil { 402 return err 403 } 404 lenL := uint32(len(proof.L)) 405 406 for i := uint32(0); i < lenL; i++ { 407 err = binary.Write(w, binary.BigEndian, proof.L[i].Bytes()) 408 if err != nil { 409 return err 410 } 411 err = binary.Write(w, binary.BigEndian, proof.R[i].Bytes()) 412 if err != nil { 413 return err 414 } 415 } 416 return nil 417 } 418 419 // Decode a Proof 420 func (proof *Proof) Decode(r io.Reader) error { 421 if proof == nil { 422 return errors.New("struct is nil") 423 } 424 425 var ABytes, BBytes [32]byte 426 err := binary.Read(r, binary.BigEndian, &ABytes) 427 if err != nil { 428 return err 429 } 430 err = binary.Read(r, binary.BigEndian, &BBytes) 431 if err != nil { 432 return err 433 } 434 proof.A.SetBytes(&ABytes) 435 proof.B.SetBytes(&BBytes) 436 437 buf := &bytes.Buffer{} 438 _, err = buf.ReadFrom(r) 439 if err != nil { 440 return err 441 } 442 numBytes := len(buf.Bytes()) 443 if numBytes%32 != 0 { 444 return errors.New("proof was not formatted correctly") 445 } 446 lenL := uint32(numBytes / 64) 447 448 proof.L = make([]ristretto.Point, lenL) 449 proof.R = make([]ristretto.Point, lenL) 450 451 for i := uint32(0); i < lenL; i++ { 452 var LBytes, RBytes [32]byte 453 err = binary.Read(buf, binary.BigEndian, &LBytes) 454 if err != nil { 455 return err 456 } 457 err = binary.Read(buf, binary.BigEndian, &RBytes) 458 if err != nil { 459 return err 460 } 461 proof.L[i].SetBytes(&LBytes) 462 proof.R[i].SetBytes(&RBytes) 463 } 464 465 return nil 466 } 467 468 // Equals test another proof for equality 469 func (proof *Proof) Equals(other Proof) bool { 470 if ok := proof.A.Equals(&other.A); !ok { 471 return false 472 } 473 474 if ok := proof.B.Equals(&other.B); !ok { 475 return false 476 } 477 478 for i := range proof.L { 479 if ok := proof.L[i].Equals(&other.L[i]); !ok { 480 return false 481 } 482 483 if ok := proof.R[i].Equals(&other.R[i]); !ok { 484 return false 485 } 486 } 487 488 return true 489 } 490 491 func nextPow2(n uint) uint { 492 n-- 493 n |= n >> 1 494 n |= n >> 2 495 n |= n >> 4 496 n |= n >> 8 497 n |= n >> 16 498 return n 499 } 500 501 func isPower2(n uint32) bool { 502 return (n & (n - 1)) == 0 503 } 504 505 // DiffNextPow2 checks the closest next pow2 and returns the necessary padding 506 // amount to get to the that 507 func DiffNextPow2(n uint32) uint32 { 508 pow2 := nextPow2(uint(n)) 509 padAmount := uint32(pow2) - n + 1 510 return padAmount 511 }