github.com/consensys/gnark@v0.11.0/internal/generator/backend/template/representations/solver.go.tmpl (about)

     1  import (
     2  	"errors"
     3      "fmt"
     4  	"math/big"
     5  	"sync/atomic"
     6  	"strings"
     7  	"strconv"
     8  	"sync"
     9  	"math"
    10      "github.com/consensys/gnark/constraint"
    11  	csolver "github.com/consensys/gnark/constraint/solver"
    12      "github.com/rs/zerolog"
    13  	"github.com/consensys/gnark-crypto/ecc"
    14  	"github.com/consensys/gnark-crypto/field/pool"
    15  	{{ template "import_fr" . }}
    16  )
    17  
    18  // solver represent the state of the solver during a call to System.Solve(...)
    19  type solver struct {
    20  	*system
    21  
    22  	// values and solved are index by the wire (variable) id
    23  	values				 []fr.Element
    24  	solved               []bool
    25  	nbSolved             uint64
    26  
    27  	// maps hintID to hint function
    28  	mHintsFunctions      map[csolver.HintID]csolver.Hint
    29  
    30  	// used to out api.Println
    31  	logger        zerolog.Logger
    32  	nbTasks       int
    33  
    34  	a,b,c fr.Vector // R1CS solver will compute the a,b,c matrices 
    35  
    36  	q *big.Int 
    37  }
    38  
    39  func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) {
    40      {{ if not .NoGKR -}}
    41      // add GKR options to overwrite the placeholder
    42  	if cs.GkrInfo.Is() {
    43  		var gkrData GkrSolvingData
    44  		opts = append(opts,
    45  			csolver.OverrideHint(cs.GkrInfo.SolveHintID, GkrSolveHint(cs.GkrInfo, &gkrData)),
    46  			csolver.OverrideHint(cs.GkrInfo.ProveHintID, GkrProveHint(cs.GkrInfo.HashName, &gkrData)))
    47  	}
    48  	{{ end -}}
    49  
    50  	// parse options
    51  	opt, err := csolver.NewConfig(opts...)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	// check witness size
    57  	witnessOffset := 0
    58  	if cs.Type == constraint.SystemR1CS {
    59  		witnessOffset++
    60  	}
    61  
    62  	nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables
    63  	expectedWitnessSize := len(cs.Public)-witnessOffset+len(cs.Secret)
    64  
    65  	if len(witness) != expectedWitnessSize {
    66  		return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize)
    67  	}
    68  
    69  	// check all hints are there
    70  	hintFunctions := opt.HintFunctions
    71  
    72  	// hintsDependencies is from compile time; it contains the list of hints the solver **needs**
    73  	var missing []string
    74  	for hintUUID, hintID := range cs.MHintsDependencies {
    75  		if _, ok := hintFunctions[hintUUID]; !ok {
    76  			missing = append(missing, hintID)
    77  		}
    78  	}
    79  
    80  	if len(missing) > 0 {
    81  		return nil, fmt.Errorf("solver missing hint(s): %v", missing)
    82  	}
    83  
    84  	s := solver{
    85  			system: cs,
    86  			values: make([]fr.Element, nbWires),
    87  			solved: make([]bool, nbWires),
    88  			mHintsFunctions: hintFunctions,
    89  			logger: opt.Logger,
    90  			nbTasks: opt.NbTasks,
    91  			q: cs.Field(),
    92  	}
    93  
    94  	// set the witness indexes as solved
    95  	if witnessOffset == 1 {
    96  		s.solved[0] = true // ONE_WIRE
    97  		s.values[0].SetOne()
    98  	}
    99  	copy(s.values[witnessOffset:], witness) 
   100  	for i := range witness {
   101  		s.solved[i+witnessOffset] = true
   102  	}
   103  
   104  	// keep track of the number of wire instantiations we do, for a post solve sanity check
   105  	// to ensure we instantiated all wires
   106  	s.nbSolved += uint64(len(witness) + witnessOffset)
   107  
   108  
   109  
   110  	if s.Type == constraint.SystemR1CS {
   111  		n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints()))
   112  		s.a = make(fr.Vector, cs.GetNbConstraints(), n)
   113  		s.b = make(fr.Vector, cs.GetNbConstraints(), n)
   114  		s.c = make(fr.Vector, cs.GetNbConstraints(), n)
   115  	}
   116  
   117  	return &s, nil
   118  }
   119  
   120  
   121  func (s *solver) set(id int, value fr.Element) {
   122  	if s.solved[id] {
   123  		panic("solving the same wire twice should never happen.")
   124  	}
   125  	s.values[id] = value
   126  	s.solved[id] = true
   127  	atomic.AddUint64(&s.nbSolved, 1)
   128  }
   129  
   130  
   131  // computeTerm computes coeff*variable
   132  func (s *solver) computeTerm(t constraint.Term) fr.Element {
   133  	cID, vID := t.CoeffID(), t.WireID()
   134  
   135  	if t.IsConstant() {
   136  		return s.Coefficients[cID]
   137  	}
   138  
   139  	if cID != 0 && !s.solved[vID] {
   140  		panic("computing a term with an unsolved wire")
   141  	}
   142  
   143  	switch cID {
   144  	case constraint.CoeffIdZero:
   145  		return fr.Element{}
   146  	case constraint.CoeffIdOne:
   147  		return s.values[vID]
   148  	case constraint.CoeffIdTwo:
   149  		var res fr.Element
   150  		res.Double(&s.values[vID])
   151  		return res
   152  	case constraint.CoeffIdMinusOne:
   153  		var res fr.Element
   154  		res.Neg(&s.values[vID])
   155  		return res
   156  	default:
   157  		var res fr.Element
   158  		res.Mul(&s.Coefficients[cID], &s.values[vID])
   159  		return res
   160  	}
   161  }
   162  
   163  // r += (t.coeff*t.value)
   164  // TODO @gbotrel check t.IsConstant on the caller side when necessary
   165  func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) {
   166  	cID := t.CoeffID()
   167  	vID := t.WireID()
   168  
   169  	if t.IsConstant() {
   170  		r.Add(r, &s.Coefficients[cID])
   171  		return
   172  	}
   173  
   174  	switch cID {
   175  	case constraint.CoeffIdZero:
   176  		return 
   177  	case constraint.CoeffIdOne:
   178  		r.Add(r, &s.values[vID])
   179  	case constraint.CoeffIdTwo:
   180  		var res fr.Element
   181  		res.Double(&s.values[vID])
   182  		r.Add(r, &res)
   183  	case constraint.CoeffIdMinusOne:
   184  		r.Sub(r, &s.values[vID])
   185  	default:
   186  		var res fr.Element
   187  		res.Mul(&s.Coefficients[cID], &s.values[vID])
   188  		r.Add(r, &res)
   189  	}
   190  }
   191  
   192  // solveWithHint executes a hint and assign the result to its defined outputs.
   193  func (s *solver) solveWithHint(h *constraint.HintMapping) error {
   194  	// ensure hint function was provided
   195  	f, ok := s.mHintsFunctions[h.HintID]
   196  	if !ok {
   197  		return errors.New("missing hint function")
   198  	}
   199  
   200  	// tmp IO big int memory
   201  	nbInputs := len(h.Inputs)
   202  	nbOutputs := int(h.OutputRange.End - h.OutputRange.Start)
   203  	inputs := make([]*big.Int, nbInputs) 
   204  	outputs :=  make([]*big.Int, nbOutputs)
   205  	for i :=0; i < nbOutputs; i++ {
   206  		outputs[i] = pool.BigInt.Get()
   207  		outputs[i].SetUint64(0)
   208  	}
   209  
   210  	q := pool.BigInt.Get()
   211  	q.Set(s.q)
   212  
   213  	for i := 0; i < nbInputs; i++ {
   214  		var v fr.Element
   215  		for _, term := range h.Inputs[i] {
   216  			if term.IsConstant() {
   217  				v.Add(&v, &s.Coefficients[term.CoeffID()])
   218  				continue
   219  			}
   220  			s.accumulateInto(term, &v)
   221  		}
   222  		inputs[i] = pool.BigInt.Get()
   223  		v.BigInt(inputs[i])
   224  	}
   225  
   226  
   227  	err := f(q, inputs, outputs)
   228  
   229  	var v fr.Element
   230  	for i := range outputs {
   231  		v.SetBigInt(outputs[i])
   232  		s.set(int(h.OutputRange.Start) + i, v)
   233  		pool.BigInt.Put(outputs[i])
   234  	}
   235  
   236  	for i := range inputs {
   237  		pool.BigInt.Put(inputs[i])
   238  	}
   239  
   240  	pool.BigInt.Put(q)
   241  
   242  	return err 
   243  }
   244  
   245  func (s *solver) printLogs(logs []constraint.LogEntry) {
   246  	if s.logger.GetLevel() == zerolog.Disabled {
   247  		return
   248  	}
   249  
   250  	for i := 0; i < len(logs); i++ {
   251  		logLine := s.logValue(logs[i])
   252  		s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine)
   253  	}
   254  }
   255  
   256  const unsolvedVariable = "<unsolved>"
   257  
   258  func (s *solver) logValue(log constraint.LogEntry) string {
   259  	var toResolve []interface{}
   260  	var (
   261  		eval         fr.Element
   262  		missingValue bool
   263  	)
   264  	for j := 0; j < len(log.ToResolve); j++ {
   265  		// before eval le 
   266  
   267  		missingValue = false
   268  		eval.SetZero()
   269  
   270  		for _, t := range log.ToResolve[j] {
   271  			// for each term in the linear expression
   272  
   273  			cID, vID := t.CoeffID(), t.WireID()
   274  			if t.IsConstant() {
   275  				// just add the constant
   276  				eval.Add(&eval, &s.Coefficients[cID])
   277  				continue
   278  			}
   279  
   280  			if !s.solved[vID] {
   281  				missingValue = true
   282  				break // stop the loop we can't evaluate.
   283  			}
   284  
   285  			tv := s.computeTerm(t)
   286  			eval.Add(&eval, &tv)
   287  		}
   288  
   289  
   290  		// after
   291  		if missingValue {
   292  			toResolve = append(toResolve, unsolvedVariable)
   293  		} else {
   294  			// we have to append our accumulator
   295  			toResolve = append(toResolve, eval.String())
   296  		}
   297  
   298  	}
   299  	if len(log.Stack) > 0 {
   300  		var sbb strings.Builder 
   301  		for _, lID :=  range log.Stack {
   302  			location := s.SymbolTable.Locations[lID]
   303  			function := s.SymbolTable.Functions[location.FunctionID]
   304  
   305  			sbb.WriteString(function.Name)
   306  			sbb.WriteByte('\n')
   307  			sbb.WriteByte('\t')
   308  			sbb.WriteString(function.Filename)
   309  			sbb.WriteByte(':')
   310  			sbb.WriteString(strconv.Itoa(int(location.Line)))
   311  			sbb.WriteByte('\n')
   312  		}
   313  		toResolve = append(toResolve, sbb.String())
   314  	}
   315  	return fmt.Sprintf(log.Format, toResolve...)
   316  }
   317  
   318  
   319  // divByCoeff sets res = res / t.Coeff
   320  func (solver *solver) divByCoeff(res *fr.Element, cID uint32) {
   321  	switch cID {
   322  	case constraint.CoeffIdOne:
   323  		return
   324  	case constraint.CoeffIdMinusOne:
   325  		res.Neg(res)
   326  	case constraint.CoeffIdZero:
   327  		panic("division by 0")
   328  	default:
   329  		// this is slow, but shouldn't happen as divByCoeff is called to 
   330  		// remove the coeff of an unsolved wire
   331  		// but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1
   332  		res.Div(res, &solver.Coefficients[cID])
   333  	}
   334  }
   335  
   336  
   337  
   338  
   339  // Implement constraint.Solver
   340  func (s *solver) GetValue(cID, vID uint32) constraint.Element {
   341  	var r constraint.Element
   342  	e := s.computeTerm(constraint.Term{CID:cID,VID: vID})	
   343  	copy(r[:], e[:])
   344  	return r
   345  }
   346  func (s *solver) GetCoeff(cID uint32) constraint.Element {
   347  	var r constraint.Element
   348  	copy(r[:], s.Coefficients[cID][:])
   349  	return r
   350  }
   351  func (s *solver) SetValue(vID uint32, f constraint.Element) {
   352  	s.set(int(vID), *(*fr.Element)(f[:]))
   353  }
   354  
   355  func (s *solver) IsSolved(vID uint32) bool {
   356  	return s.solved[vID]
   357  }
   358  
   359  // Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish),
   360  // evaluates it and return the result and the number of uint32 word read.
   361  func (s *solver) Read(calldata []uint32) (constraint.Element, int) {
   362  	if s.Type == constraint.SystemSparseR1CS {
   363  		if calldata[0] != 1 {
   364  			panic("invalid calldata")
   365  		}
   366  		return s.GetValue(calldata[1], calldata[2]), 3
   367  	}
   368  	var r fr.Element
   369  	n := int(calldata[0])
   370  	j := 1
   371  	for k:= 0; k < n; k++ {
   372  		// we read k Terms
   373  		s.accumulateInto(constraint.Term{CID:calldata[j],VID: calldata[j+1]} , &r)
   374  		j+=2
   375  	}
   376  
   377  	var ret constraint.Element
   378  	copy(ret[:], r[:])
   379  	return ret, j
   380  }
   381  
   382  
   383  
   384  // processInstruction decodes the instruction and execute blueprint-defined logic.
   385  // an instruction can encode a hint, a custom constraint or a generic constraint.
   386  func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error {
   387  	// fetch the blueprint
   388  	blueprint := solver.Blueprints[pi.BlueprintID]
   389  	inst := pi.Unpack(&solver.System)
   390  	cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only
   391  
   392  	if solver.Type == constraint.SystemR1CS {
   393  		if bc, ok := blueprint.(constraint.BlueprintR1C); ok {
   394  			// TODO @gbotrel we use the solveR1C method for now, having user-defined
   395  			// blueprint for R1CS would require constraint.Solver interface to add methods
   396  			// to set a,b,c since it's more efficient to compute these while we solve.
   397  			bc.DecompressR1C(&scratch.tR1C, inst)
   398  			return solver.solveR1C(cID, &scratch.tR1C)
   399  		}
   400  	} 
   401  
   402  	// blueprint declared "I know how to solve this."
   403  	if bc, ok := blueprint.(constraint.BlueprintSolvable); ok {
   404  		if err := bc.Solve(solver, inst); err != nil {
   405  			return solver.wrapErrWithDebugInfo(cID, err)
   406  		}
   407  		return nil
   408  	}
   409  
   410  	// blueprint encodes a hint, we execute.
   411  	// TODO @gbotrel may be worth it to move hint logic in blueprint "solve"
   412  	if bc, ok := blueprint.(constraint.BlueprintHint); ok {
   413  		bc.DecompressHint(&scratch.tHint,inst)
   414  		return solver.solveWithHint(&scratch.tHint)
   415  	}
   416  
   417  
   418  	return nil 
   419  }
   420  
   421  
   422  // run runs the solver. it return an error if a constraint is not satisfied or if not all wires
   423  // were instantiated.
   424  func (solver *solver) run() error {
   425  	// minWorkPerCPU is the minimum target number of constraint a task should hold
   426  	// in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed
   427  	// sequentially without sync.  
   428  	const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks.
   429  
   430  	// cs.Levels has a list of levels, where all constraints in a level l(n) are independent
   431  	// and may only have dependencies on previous levels
   432  	// for each constraint
   433  	// we are guaranteed that each R1C contains at most one unsolved wire
   434  	// first we solve the unsolved wire (if any)
   435  	// then we check that the constraint is valid
   436  	// if a[i] * b[i] != c[i]; it means the constraint is not satisfied
   437  	var wg sync.WaitGroup 
   438  	chTasks := make(chan []uint32, solver.nbTasks)
   439  	chError := make(chan error, solver.nbTasks)
   440  
   441  	// start a worker pool
   442  	// each worker wait on chTasks
   443  	// a task is a slice of constraint indexes to be solved
   444  	for i := 0; i < solver.nbTasks; i++ {
   445  		go func() {
   446  			var scratch scratch
   447  			for t := range chTasks {
   448  				for _, i := range t {
   449  					if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil {
   450  						chError <- err 
   451  						wg.Done()
   452  						return 
   453  					}
   454  				}
   455  				wg.Done()
   456  			}
   457  		}()
   458  	}
   459  
   460  	// clean up pool go routines 
   461  	defer func() {
   462  		close(chTasks)
   463  		close(chError)
   464  	}()
   465  
   466  	var scratch scratch
   467  
   468  	// for each level, we push the tasks
   469  	for _, level := range solver.Levels {
   470  
   471  		// max CPU to use 
   472  		maxCPU := float64(len(level)) / minWorkPerCPU
   473  
   474  		if maxCPU <= 1.0 || solver.nbTasks == 1 {
   475  			// we do it sequentially 
   476  			for _, i := range level {
   477  				if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil {
   478  					return err 
   479  				}
   480  			}
   481  			continue 
   482  		}
   483  
   484  		// number of tasks for this level is set to number of CPU
   485  		// but if we don't have enough work for all our CPU, it can be lower. 
   486  		nbTasks := solver.nbTasks
   487  		maxTasks := int(math.Ceil(maxCPU))
   488  		if nbTasks > maxTasks {
   489  			nbTasks = maxTasks
   490  		}
   491  		nbIterationsPerCpus := len(level) / nbTasks
   492  	
   493  		// more CPUs than tasks: a CPU will work on exactly one iteration
   494  		// note: this depends on minWorkPerCPU constant
   495  		if nbIterationsPerCpus < 1 {
   496  			nbIterationsPerCpus = 1
   497  			nbTasks = len(level)
   498  		}
   499  	
   500  	
   501  		extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
   502  		extraTasksOffset := 0
   503  	
   504  		for i := 0; i < nbTasks; i++ {
   505  			wg.Add(1)
   506  			_start := i*nbIterationsPerCpus + extraTasksOffset
   507  			_end := _start + nbIterationsPerCpus
   508  			if extraTasks > 0 {
   509  				_end++
   510  				extraTasks--
   511  				extraTasksOffset++
   512  			}
   513  			// since we're never pushing more than num CPU tasks
   514  			// we will never be blocked here
   515  			chTasks <- level[_start:_end]
   516  		}
   517  	
   518  		// wait for the level to be done 
   519  		wg.Wait()
   520  
   521  		if len(chError) > 0 {
   522  			return <-chError
   523  		}
   524  	}
   525  
   526  	if int(solver.nbSolved) != len(solver.values) {
   527  		return errors.New("solver didn't assign a value to all wires")
   528  	}
   529  
   530  	return nil
   531  }
   532  
   533  
   534  
   535  // solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly
   536  // 
   537  // returns an error if the solver called a hint function that errored
   538  // returns false, nil if there was no wire to solve 
   539  // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that 
   540  // the constraint is satisfied later.
   541  func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error {
   542  	a, b, c := &solver.a[cID],&solver.b[cID], &solver.c[cID]
   543  
   544  	// the index of the non-zero entry shows if L, R or O has an uninstantiated wire
   545  	// the content is the ID of the wire non instantiated
   546  	var loc uint8
   547  
   548  	var termToCompute constraint.Term
   549  
   550  	processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8)  {
   551  		for _, t := range l {
   552  			vID := t.WireID()
   553  
   554  			// wire is already computed, we just accumulate in val
   555  			if solver.solved[vID] {
   556  				solver.accumulateInto(t, val)
   557  				continue
   558  			}
   559  
   560  			if loc != 0 {
   561  				panic("found more than one wire to instantiate")
   562  			}
   563  			termToCompute = t
   564  			loc = locValue
   565  		}
   566  	}
   567  
   568  	processLExp(r.L, a, 1)
   569  	processLExp(r.R, b, 2)
   570  	processLExp(r.O, c, 3)
   571  
   572  
   573  
   574  	if loc == 0 {
   575  		// there is nothing to solve, may happen if we have an assertion
   576  		// (ie a constraints that doesn't yield any output)
   577  		// or if we solved the unsolved wires with hint functions
   578  		var check fr.Element 
   579  		if !check.Mul(a, b).Equal(c) {
   580  			return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()))
   581  		}
   582  		return nil
   583  	}
   584  
   585  	// we compute the wire value and instantiate it
   586  	wID := termToCompute.WireID()
   587  
   588  	// solver result
   589  	var wire fr.Element
   590  	
   591  
   592  	switch loc {
   593  	case 1:
   594  		if !b.IsZero() {
   595  			wire.Div(c, b).
   596  				Sub(&wire, a)
   597  			a.Add(a, &wire)
   598  		} else {
   599  			// we didn't actually ensure that a * b == c
   600  			var check fr.Element
   601  			if !check.Mul(a, b).Equal(c) {
   602  				return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()))
   603  			}
   604  		}
   605  	case 2:
   606  		if !a.IsZero() {
   607  			wire.Div(c, a).
   608  				Sub(&wire, b)
   609  			b.Add(b, &wire)
   610  		} else {
   611  			var check fr.Element
   612  			if !check.Mul(a, b).Equal(c) {
   613  				return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()))
   614  			}
   615  		}
   616  	case 3:
   617  		wire.Mul(a, b).
   618  			Sub(&wire, c)
   619  		
   620  		c.Add(c, &wire)
   621  	}
   622  
   623  	// wire is the term (coeff * value)
   624  	// but in the solver we want to store the value only
   625  	// note that in gnark frontend, coeff here is always 1 or -1
   626  	solver.divByCoeff(&wire, termToCompute.CID)
   627  	solver.set(wID, wire)
   628  
   629  	return nil
   630  }
   631  
   632  // UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint
   633  type UnsatisfiedConstraintError struct {
   634  	Err error
   635  	CID int // constraint ID 
   636  	DebugInfo *string // optional debug info
   637  }
   638  
   639  func (r *UnsatisfiedConstraintError) Error() string {
   640  	if r.DebugInfo != nil {
   641  		return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo)
   642  	}
   643  	return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error())
   644  }
   645  
   646  
   647  func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError {
   648  	var debugInfo *string
   649  	if dID, ok := solver.MDebug[int(cID)]; ok {
   650  		debugInfo = new(string)
   651  		*debugInfo = solver.logValue(solver.DebugInfo[dID])
   652  	}
   653  	return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo}
   654  }
   655  
   656  // temporary variables to avoid memallocs in hotloop
   657  type scratch struct {
   658  	tR1C constraint.R1C
   659  	tHint constraint.HintMapping
   660  }
   661