github.com/consensys/gnark@v0.11.0/internal/generator/backend/template/representations/solver.go.tmpl (about) 1 import ( 2 "errors" 3 "fmt" 4 "math/big" 5 "sync/atomic" 6 "strings" 7 "strconv" 8 "sync" 9 "math" 10 "github.com/consensys/gnark/constraint" 11 csolver "github.com/consensys/gnark/constraint/solver" 12 "github.com/rs/zerolog" 13 "github.com/consensys/gnark-crypto/ecc" 14 "github.com/consensys/gnark-crypto/field/pool" 15 {{ template "import_fr" . }} 16 ) 17 18 // solver represent the state of the solver during a call to System.Solve(...) 19 type solver struct { 20 *system 21 22 // values and solved are index by the wire (variable) id 23 values []fr.Element 24 solved []bool 25 nbSolved uint64 26 27 // maps hintID to hint function 28 mHintsFunctions map[csolver.HintID]csolver.Hint 29 30 // used to out api.Println 31 logger zerolog.Logger 32 nbTasks int 33 34 a,b,c fr.Vector // R1CS solver will compute the a,b,c matrices 35 36 q *big.Int 37 } 38 39 func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { 40 {{ if not .NoGKR -}} 41 // add GKR options to overwrite the placeholder 42 if cs.GkrInfo.Is() { 43 var gkrData GkrSolvingData 44 opts = append(opts, 45 csolver.OverrideHint(cs.GkrInfo.SolveHintID, GkrSolveHint(cs.GkrInfo, &gkrData)), 46 csolver.OverrideHint(cs.GkrInfo.ProveHintID, GkrProveHint(cs.GkrInfo.HashName, &gkrData))) 47 } 48 {{ end -}} 49 50 // parse options 51 opt, err := csolver.NewConfig(opts...) 52 if err != nil { 53 return nil, err 54 } 55 56 // check witness size 57 witnessOffset := 0 58 if cs.Type == constraint.SystemR1CS { 59 witnessOffset++ 60 } 61 62 nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables 63 expectedWitnessSize := len(cs.Public)-witnessOffset+len(cs.Secret) 64 65 if len(witness) != expectedWitnessSize { 66 return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) 67 } 68 69 // check all hints are there 70 hintFunctions := opt.HintFunctions 71 72 // hintsDependencies is from compile time; it contains the list of hints the solver **needs** 73 var missing []string 74 for hintUUID, hintID := range cs.MHintsDependencies { 75 if _, ok := hintFunctions[hintUUID]; !ok { 76 missing = append(missing, hintID) 77 } 78 } 79 80 if len(missing) > 0 { 81 return nil, fmt.Errorf("solver missing hint(s): %v", missing) 82 } 83 84 s := solver{ 85 system: cs, 86 values: make([]fr.Element, nbWires), 87 solved: make([]bool, nbWires), 88 mHintsFunctions: hintFunctions, 89 logger: opt.Logger, 90 nbTasks: opt.NbTasks, 91 q: cs.Field(), 92 } 93 94 // set the witness indexes as solved 95 if witnessOffset == 1 { 96 s.solved[0] = true // ONE_WIRE 97 s.values[0].SetOne() 98 } 99 copy(s.values[witnessOffset:], witness) 100 for i := range witness { 101 s.solved[i+witnessOffset] = true 102 } 103 104 // keep track of the number of wire instantiations we do, for a post solve sanity check 105 // to ensure we instantiated all wires 106 s.nbSolved += uint64(len(witness) + witnessOffset) 107 108 109 110 if s.Type == constraint.SystemR1CS { 111 n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) 112 s.a = make(fr.Vector, cs.GetNbConstraints(), n) 113 s.b = make(fr.Vector, cs.GetNbConstraints(), n) 114 s.c = make(fr.Vector, cs.GetNbConstraints(), n) 115 } 116 117 return &s, nil 118 } 119 120 121 func (s *solver) set(id int, value fr.Element) { 122 if s.solved[id] { 123 panic("solving the same wire twice should never happen.") 124 } 125 s.values[id] = value 126 s.solved[id] = true 127 atomic.AddUint64(&s.nbSolved, 1) 128 } 129 130 131 // computeTerm computes coeff*variable 132 func (s *solver) computeTerm(t constraint.Term) fr.Element { 133 cID, vID := t.CoeffID(), t.WireID() 134 135 if t.IsConstant() { 136 return s.Coefficients[cID] 137 } 138 139 if cID != 0 && !s.solved[vID] { 140 panic("computing a term with an unsolved wire") 141 } 142 143 switch cID { 144 case constraint.CoeffIdZero: 145 return fr.Element{} 146 case constraint.CoeffIdOne: 147 return s.values[vID] 148 case constraint.CoeffIdTwo: 149 var res fr.Element 150 res.Double(&s.values[vID]) 151 return res 152 case constraint.CoeffIdMinusOne: 153 var res fr.Element 154 res.Neg(&s.values[vID]) 155 return res 156 default: 157 var res fr.Element 158 res.Mul(&s.Coefficients[cID], &s.values[vID]) 159 return res 160 } 161 } 162 163 // r += (t.coeff*t.value) 164 // TODO @gbotrel check t.IsConstant on the caller side when necessary 165 func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { 166 cID := t.CoeffID() 167 vID := t.WireID() 168 169 if t.IsConstant() { 170 r.Add(r, &s.Coefficients[cID]) 171 return 172 } 173 174 switch cID { 175 case constraint.CoeffIdZero: 176 return 177 case constraint.CoeffIdOne: 178 r.Add(r, &s.values[vID]) 179 case constraint.CoeffIdTwo: 180 var res fr.Element 181 res.Double(&s.values[vID]) 182 r.Add(r, &res) 183 case constraint.CoeffIdMinusOne: 184 r.Sub(r, &s.values[vID]) 185 default: 186 var res fr.Element 187 res.Mul(&s.Coefficients[cID], &s.values[vID]) 188 r.Add(r, &res) 189 } 190 } 191 192 // solveWithHint executes a hint and assign the result to its defined outputs. 193 func (s *solver) solveWithHint(h *constraint.HintMapping) error { 194 // ensure hint function was provided 195 f, ok := s.mHintsFunctions[h.HintID] 196 if !ok { 197 return errors.New("missing hint function") 198 } 199 200 // tmp IO big int memory 201 nbInputs := len(h.Inputs) 202 nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) 203 inputs := make([]*big.Int, nbInputs) 204 outputs := make([]*big.Int, nbOutputs) 205 for i :=0; i < nbOutputs; i++ { 206 outputs[i] = pool.BigInt.Get() 207 outputs[i].SetUint64(0) 208 } 209 210 q := pool.BigInt.Get() 211 q.Set(s.q) 212 213 for i := 0; i < nbInputs; i++ { 214 var v fr.Element 215 for _, term := range h.Inputs[i] { 216 if term.IsConstant() { 217 v.Add(&v, &s.Coefficients[term.CoeffID()]) 218 continue 219 } 220 s.accumulateInto(term, &v) 221 } 222 inputs[i] = pool.BigInt.Get() 223 v.BigInt(inputs[i]) 224 } 225 226 227 err := f(q, inputs, outputs) 228 229 var v fr.Element 230 for i := range outputs { 231 v.SetBigInt(outputs[i]) 232 s.set(int(h.OutputRange.Start) + i, v) 233 pool.BigInt.Put(outputs[i]) 234 } 235 236 for i := range inputs { 237 pool.BigInt.Put(inputs[i]) 238 } 239 240 pool.BigInt.Put(q) 241 242 return err 243 } 244 245 func (s *solver) printLogs(logs []constraint.LogEntry) { 246 if s.logger.GetLevel() == zerolog.Disabled { 247 return 248 } 249 250 for i := 0; i < len(logs); i++ { 251 logLine := s.logValue(logs[i]) 252 s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) 253 } 254 } 255 256 const unsolvedVariable = "<unsolved>" 257 258 func (s *solver) logValue(log constraint.LogEntry) string { 259 var toResolve []interface{} 260 var ( 261 eval fr.Element 262 missingValue bool 263 ) 264 for j := 0; j < len(log.ToResolve); j++ { 265 // before eval le 266 267 missingValue = false 268 eval.SetZero() 269 270 for _, t := range log.ToResolve[j] { 271 // for each term in the linear expression 272 273 cID, vID := t.CoeffID(), t.WireID() 274 if t.IsConstant() { 275 // just add the constant 276 eval.Add(&eval, &s.Coefficients[cID]) 277 continue 278 } 279 280 if !s.solved[vID] { 281 missingValue = true 282 break // stop the loop we can't evaluate. 283 } 284 285 tv := s.computeTerm(t) 286 eval.Add(&eval, &tv) 287 } 288 289 290 // after 291 if missingValue { 292 toResolve = append(toResolve, unsolvedVariable) 293 } else { 294 // we have to append our accumulator 295 toResolve = append(toResolve, eval.String()) 296 } 297 298 } 299 if len(log.Stack) > 0 { 300 var sbb strings.Builder 301 for _, lID := range log.Stack { 302 location := s.SymbolTable.Locations[lID] 303 function := s.SymbolTable.Functions[location.FunctionID] 304 305 sbb.WriteString(function.Name) 306 sbb.WriteByte('\n') 307 sbb.WriteByte('\t') 308 sbb.WriteString(function.Filename) 309 sbb.WriteByte(':') 310 sbb.WriteString(strconv.Itoa(int(location.Line))) 311 sbb.WriteByte('\n') 312 } 313 toResolve = append(toResolve, sbb.String()) 314 } 315 return fmt.Sprintf(log.Format, toResolve...) 316 } 317 318 319 // divByCoeff sets res = res / t.Coeff 320 func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { 321 switch cID { 322 case constraint.CoeffIdOne: 323 return 324 case constraint.CoeffIdMinusOne: 325 res.Neg(res) 326 case constraint.CoeffIdZero: 327 panic("division by 0") 328 default: 329 // this is slow, but shouldn't happen as divByCoeff is called to 330 // remove the coeff of an unsolved wire 331 // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 332 res.Div(res, &solver.Coefficients[cID]) 333 } 334 } 335 336 337 338 339 // Implement constraint.Solver 340 func (s *solver) GetValue(cID, vID uint32) constraint.Element { 341 var r constraint.Element 342 e := s.computeTerm(constraint.Term{CID:cID,VID: vID}) 343 copy(r[:], e[:]) 344 return r 345 } 346 func (s *solver) GetCoeff(cID uint32) constraint.Element { 347 var r constraint.Element 348 copy(r[:], s.Coefficients[cID][:]) 349 return r 350 } 351 func (s *solver) SetValue(vID uint32, f constraint.Element) { 352 s.set(int(vID), *(*fr.Element)(f[:])) 353 } 354 355 func (s *solver) IsSolved(vID uint32) bool { 356 return s.solved[vID] 357 } 358 359 // Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), 360 // evaluates it and return the result and the number of uint32 word read. 361 func (s *solver) Read(calldata []uint32) (constraint.Element, int) { 362 if s.Type == constraint.SystemSparseR1CS { 363 if calldata[0] != 1 { 364 panic("invalid calldata") 365 } 366 return s.GetValue(calldata[1], calldata[2]), 3 367 } 368 var r fr.Element 369 n := int(calldata[0]) 370 j := 1 371 for k:= 0; k < n; k++ { 372 // we read k Terms 373 s.accumulateInto(constraint.Term{CID:calldata[j],VID: calldata[j+1]} , &r) 374 j+=2 375 } 376 377 var ret constraint.Element 378 copy(ret[:], r[:]) 379 return ret, j 380 } 381 382 383 384 // processInstruction decodes the instruction and execute blueprint-defined logic. 385 // an instruction can encode a hint, a custom constraint or a generic constraint. 386 func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { 387 // fetch the blueprint 388 blueprint := solver.Blueprints[pi.BlueprintID] 389 inst := pi.Unpack(&solver.System) 390 cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only 391 392 if solver.Type == constraint.SystemR1CS { 393 if bc, ok := blueprint.(constraint.BlueprintR1C); ok { 394 // TODO @gbotrel we use the solveR1C method for now, having user-defined 395 // blueprint for R1CS would require constraint.Solver interface to add methods 396 // to set a,b,c since it's more efficient to compute these while we solve. 397 bc.DecompressR1C(&scratch.tR1C, inst) 398 return solver.solveR1C(cID, &scratch.tR1C) 399 } 400 } 401 402 // blueprint declared "I know how to solve this." 403 if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { 404 if err := bc.Solve(solver, inst); err != nil { 405 return solver.wrapErrWithDebugInfo(cID, err) 406 } 407 return nil 408 } 409 410 // blueprint encodes a hint, we execute. 411 // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" 412 if bc, ok := blueprint.(constraint.BlueprintHint); ok { 413 bc.DecompressHint(&scratch.tHint,inst) 414 return solver.solveWithHint(&scratch.tHint) 415 } 416 417 418 return nil 419 } 420 421 422 // run runs the solver. it return an error if a constraint is not satisfied or if not all wires 423 // were instantiated. 424 func (solver *solver) run() error { 425 // minWorkPerCPU is the minimum target number of constraint a task should hold 426 // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed 427 // sequentially without sync. 428 const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. 429 430 // cs.Levels has a list of levels, where all constraints in a level l(n) are independent 431 // and may only have dependencies on previous levels 432 // for each constraint 433 // we are guaranteed that each R1C contains at most one unsolved wire 434 // first we solve the unsolved wire (if any) 435 // then we check that the constraint is valid 436 // if a[i] * b[i] != c[i]; it means the constraint is not satisfied 437 var wg sync.WaitGroup 438 chTasks := make(chan []uint32, solver.nbTasks) 439 chError := make(chan error, solver.nbTasks) 440 441 // start a worker pool 442 // each worker wait on chTasks 443 // a task is a slice of constraint indexes to be solved 444 for i := 0; i < solver.nbTasks; i++ { 445 go func() { 446 var scratch scratch 447 for t := range chTasks { 448 for _, i := range t { 449 if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { 450 chError <- err 451 wg.Done() 452 return 453 } 454 } 455 wg.Done() 456 } 457 }() 458 } 459 460 // clean up pool go routines 461 defer func() { 462 close(chTasks) 463 close(chError) 464 }() 465 466 var scratch scratch 467 468 // for each level, we push the tasks 469 for _, level := range solver.Levels { 470 471 // max CPU to use 472 maxCPU := float64(len(level)) / minWorkPerCPU 473 474 if maxCPU <= 1.0 || solver.nbTasks == 1 { 475 // we do it sequentially 476 for _, i := range level { 477 if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { 478 return err 479 } 480 } 481 continue 482 } 483 484 // number of tasks for this level is set to number of CPU 485 // but if we don't have enough work for all our CPU, it can be lower. 486 nbTasks := solver.nbTasks 487 maxTasks := int(math.Ceil(maxCPU)) 488 if nbTasks > maxTasks { 489 nbTasks = maxTasks 490 } 491 nbIterationsPerCpus := len(level) / nbTasks 492 493 // more CPUs than tasks: a CPU will work on exactly one iteration 494 // note: this depends on minWorkPerCPU constant 495 if nbIterationsPerCpus < 1 { 496 nbIterationsPerCpus = 1 497 nbTasks = len(level) 498 } 499 500 501 extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) 502 extraTasksOffset := 0 503 504 for i := 0; i < nbTasks; i++ { 505 wg.Add(1) 506 _start := i*nbIterationsPerCpus + extraTasksOffset 507 _end := _start + nbIterationsPerCpus 508 if extraTasks > 0 { 509 _end++ 510 extraTasks-- 511 extraTasksOffset++ 512 } 513 // since we're never pushing more than num CPU tasks 514 // we will never be blocked here 515 chTasks <- level[_start:_end] 516 } 517 518 // wait for the level to be done 519 wg.Wait() 520 521 if len(chError) > 0 { 522 return <-chError 523 } 524 } 525 526 if int(solver.nbSolved) != len(solver.values) { 527 return errors.New("solver didn't assign a value to all wires") 528 } 529 530 return nil 531 } 532 533 534 535 // solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly 536 // 537 // returns an error if the solver called a hint function that errored 538 // returns false, nil if there was no wire to solve 539 // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that 540 // the constraint is satisfied later. 541 func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { 542 a, b, c := &solver.a[cID],&solver.b[cID], &solver.c[cID] 543 544 // the index of the non-zero entry shows if L, R or O has an uninstantiated wire 545 // the content is the ID of the wire non instantiated 546 var loc uint8 547 548 var termToCompute constraint.Term 549 550 processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { 551 for _, t := range l { 552 vID := t.WireID() 553 554 // wire is already computed, we just accumulate in val 555 if solver.solved[vID] { 556 solver.accumulateInto(t, val) 557 continue 558 } 559 560 if loc != 0 { 561 panic("found more than one wire to instantiate") 562 } 563 termToCompute = t 564 loc = locValue 565 } 566 } 567 568 processLExp(r.L, a, 1) 569 processLExp(r.R, b, 2) 570 processLExp(r.O, c, 3) 571 572 573 574 if loc == 0 { 575 // there is nothing to solve, may happen if we have an assertion 576 // (ie a constraints that doesn't yield any output) 577 // or if we solved the unsolved wires with hint functions 578 var check fr.Element 579 if !check.Mul(a, b).Equal(c) { 580 return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) 581 } 582 return nil 583 } 584 585 // we compute the wire value and instantiate it 586 wID := termToCompute.WireID() 587 588 // solver result 589 var wire fr.Element 590 591 592 switch loc { 593 case 1: 594 if !b.IsZero() { 595 wire.Div(c, b). 596 Sub(&wire, a) 597 a.Add(a, &wire) 598 } else { 599 // we didn't actually ensure that a * b == c 600 var check fr.Element 601 if !check.Mul(a, b).Equal(c) { 602 return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) 603 } 604 } 605 case 2: 606 if !a.IsZero() { 607 wire.Div(c, a). 608 Sub(&wire, b) 609 b.Add(b, &wire) 610 } else { 611 var check fr.Element 612 if !check.Mul(a, b).Equal(c) { 613 return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) 614 } 615 } 616 case 3: 617 wire.Mul(a, b). 618 Sub(&wire, c) 619 620 c.Add(c, &wire) 621 } 622 623 // wire is the term (coeff * value) 624 // but in the solver we want to store the value only 625 // note that in gnark frontend, coeff here is always 1 or -1 626 solver.divByCoeff(&wire, termToCompute.CID) 627 solver.set(wID, wire) 628 629 return nil 630 } 631 632 // UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint 633 type UnsatisfiedConstraintError struct { 634 Err error 635 CID int // constraint ID 636 DebugInfo *string // optional debug info 637 } 638 639 func (r *UnsatisfiedConstraintError) Error() string { 640 if r.DebugInfo != nil { 641 return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) 642 } 643 return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) 644 } 645 646 647 func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { 648 var debugInfo *string 649 if dID, ok := solver.MDebug[int(cID)]; ok { 650 debugInfo = new(string) 651 *debugInfo = solver.logValue(solver.DebugInfo[dID]) 652 } 653 return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} 654 } 655 656 // temporary variables to avoid memallocs in hotloop 657 type scratch struct { 658 tR1C constraint.R1C 659 tHint constraint.HintMapping 660 } 661