github.com/consensys/gnark@v0.11.0/test/solver_test.go (about)

     1  package test
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"math/big"
     7  	"reflect"
     8  	"strconv"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/consensys/gnark/backend/witness"
    13  	"github.com/consensys/gnark/constraint"
    14  	"github.com/consensys/gnark/constraint/solver"
    15  	"github.com/consensys/gnark/debug"
    16  	"github.com/consensys/gnark/frontend"
    17  	"github.com/consensys/gnark/frontend/cs/r1cs"
    18  	"github.com/consensys/gnark/frontend/cs/scs"
    19  	"github.com/consensys/gnark/frontend/schema"
    20  	"github.com/consensys/gnark/internal/backend/circuits"
    21  	"github.com/consensys/gnark/internal/kvstore"
    22  	"github.com/consensys/gnark/internal/tinyfield"
    23  	"github.com/consensys/gnark/internal/utils"
    24  )
    25  
    26  // ignore witness size larger than this bound
    27  const permutterBound = 3
    28  
    29  // r1cs + sparser1cs
    30  const nbSystems = 2
    31  
    32  var builders [2]frontend.NewBuilder
    33  
    34  func TestSolverConsistency(t *testing.T) {
    35  	if testing.Short() {
    36  		t.Skip("skipping R1CS solver test with testing.Short() flag set")
    37  		return
    38  	}
    39  
    40  	// idea is test circuits, we are going to test all possible values of the witness.
    41  	// (hence the choice of a small modulus for the field size)
    42  	//
    43  	// we generate witnesses and compare with the output of big.Int test engine against
    44  	// R1CS and SparseR1CS solvers
    45  
    46  	for name := range circuits.Circuits {
    47  		t.Run(name, func(t *testing.T) {
    48  			tc := circuits.Circuits[name]
    49  			t.Parallel()
    50  			err := consistentSolver(tc.Circuit, tc.HintFunctions)
    51  			if err != nil {
    52  				t.Fatal(err)
    53  			}
    54  		})
    55  	}
    56  }
    57  
    58  // witness used for the permutter. It implements the Witness interface
    59  // using mock methods (only the underlying vector is required).
    60  type permutterWitness struct {
    61  	vector any
    62  }
    63  
    64  func (pw *permutterWitness) WriteTo(w io.Writer) (int64, error) {
    65  	return 0, nil
    66  }
    67  
    68  func (pw *permutterWitness) ReadFrom(r io.Reader) (int64, error) {
    69  	return 0, nil
    70  }
    71  
    72  func (pw *permutterWitness) MarshalBinary() ([]byte, error) {
    73  	return nil, nil
    74  }
    75  
    76  func (pw *permutterWitness) UnmarshalBinary([]byte) error {
    77  	return nil
    78  }
    79  
    80  func (pw *permutterWitness) Public() (witness.Witness, error) {
    81  	return pw, nil
    82  }
    83  
    84  func (pw *permutterWitness) Vector() any {
    85  	return pw.vector
    86  }
    87  
    88  func (pw *permutterWitness) ToJSON(s *schema.Schema) ([]byte, error) {
    89  	return nil, nil
    90  }
    91  
    92  func (pw *permutterWitness) FromJSON(s *schema.Schema, data []byte) error {
    93  	return nil
    94  }
    95  
    96  func (pw *permutterWitness) Fill(nbPublic, nbSecret int, values <-chan any) error {
    97  	return nil
    98  }
    99  
   100  func newPermutterWitness(pv tinyfield.Vector) witness.Witness {
   101  	return &permutterWitness{
   102  		vector: pv,
   103  	}
   104  }
   105  
   106  type permutter struct {
   107  	circuit           frontend.Circuit
   108  	constraintSystems [2]constraint.ConstraintSystem
   109  	witness           []tinyfield.Element
   110  	hints             []solver.Hint
   111  }
   112  
   113  // note that circuit will be mutated and this is not thread safe
   114  func (p *permutter) permuteAndTest(index int) error {
   115  
   116  	for i := 0; i < len(tinyfieldElements); i++ {
   117  		p.witness[index].SetUint64(tinyfieldElements[i])
   118  		if index == len(p.witness)-1 {
   119  
   120  			// we have a unique permutation
   121  			var errorSystems [2]error
   122  			var errorEngines [2]error
   123  
   124  			// 2 constraints systems
   125  			for k := 0; k < nbSystems; k++ {
   126  
   127  				errorSystems[k] = p.solve(k)
   128  
   129  				// solve the cs using test engine
   130  				// first copy the witness in the circuit
   131  				copyWitnessFromVector(p.circuit, p.witness)
   132  				errorEngines[0] = isSolvedEngine(p.circuit, tinyfield.Modulus())
   133  
   134  				copyWitnessFromVector(p.circuit, p.witness)
   135  				errorEngines[1] = isSolvedEngine(p.circuit, tinyfield.Modulus(), SetAllVariablesAsConstants())
   136  
   137  			}
   138  			if (errorSystems[0] == nil) != (errorEngines[0] == nil) ||
   139  				(errorSystems[1] == nil) != (errorEngines[0] == nil) ||
   140  				(errorEngines[0] == nil) != (errorEngines[1] == nil) {
   141  				return fmt.Errorf("errSCS :%s\nerrR1CS :%s\nerrEngine(const=false): %s\nerrEngine(const=true): %s\nwitness: %s",
   142  					formatError(errorSystems[0]),
   143  					formatError(errorSystems[1]),
   144  					formatError(errorEngines[0]),
   145  					formatError(errorEngines[1]),
   146  					formatWitness(p.witness))
   147  			}
   148  
   149  		} else {
   150  			// recurse
   151  			if err := p.permuteAndTest(index + 1); err != nil {
   152  				return err
   153  			}
   154  		}
   155  	}
   156  	return nil
   157  }
   158  
   159  func formatError(err error) string {
   160  	if err == nil {
   161  		return "<nil>"
   162  	}
   163  	return err.Error()
   164  }
   165  
   166  func formatWitness(witness []tinyfield.Element) string {
   167  	var sbb strings.Builder
   168  	sbb.WriteByte('[')
   169  
   170  	for i := 0; i < len(witness); i++ {
   171  		sbb.WriteString(strconv.Itoa(int(witness[i].Uint64())))
   172  		if i != len(witness)-1 {
   173  			sbb.WriteString(", ")
   174  		}
   175  	}
   176  
   177  	sbb.WriteByte(']')
   178  
   179  	return sbb.String()
   180  }
   181  
   182  func (p *permutter) solve(i int) error {
   183  	pw := newPermutterWitness(p.witness)
   184  	_, err := p.constraintSystems[i].Solve(pw, solver.WithHints(p.hints...))
   185  	return err
   186  }
   187  
   188  // isSolvedEngine behaves like test.IsSolved except it doesn't clone the circuit
   189  func isSolvedEngine(c frontend.Circuit, field *big.Int, opts ...TestEngineOption) (err error) {
   190  	e := &engine{
   191  		curveID:   utils.FieldToCurve(field),
   192  		q:         new(big.Int).Set(field),
   193  		constVars: false,
   194  		Store:     kvstore.New(),
   195  	}
   196  	for _, opt := range opts {
   197  		if err := opt(e); err != nil {
   198  			return fmt.Errorf("apply option: %w", err)
   199  		}
   200  	}
   201  
   202  	defer func() {
   203  		if r := recover(); r != nil {
   204  			err = fmt.Errorf("%v\n%s", r, string(debug.Stack()))
   205  		}
   206  	}()
   207  
   208  	if err = c.Define(e); err != nil {
   209  		return fmt.Errorf("define: %w", err)
   210  	}
   211  	if err = callDeferred(e); err != nil {
   212  		return fmt.Errorf("")
   213  	}
   214  
   215  	return
   216  }
   217  
   218  // fill the "to" frontend.Circuit with values from the provided vector
   219  // values are assumed to be ordered [public | secret]
   220  func copyWitnessFromVector(to frontend.Circuit, from []tinyfield.Element) {
   221  	i := 0
   222  	schema.Walk(to, tVariable, func(f schema.LeafInfo, tInput reflect.Value) error {
   223  		if f.Visibility == schema.Public {
   224  			tInput.Set(reflect.ValueOf(from[i]))
   225  			i++
   226  		}
   227  		return nil
   228  	})
   229  
   230  	schema.Walk(to, tVariable, func(f schema.LeafInfo, tInput reflect.Value) error {
   231  		if f.Visibility == schema.Secret {
   232  			tInput.Set(reflect.ValueOf(from[i]))
   233  			i++
   234  		}
   235  		return nil
   236  	})
   237  }
   238  
   239  // ConsistentSolver solves given circuit with all possible witness combinations using internal/tinyfield
   240  //
   241  // Since the goal of this method is to flag potential solver issues, it is not exposed as an API for now
   242  func consistentSolver(circuit frontend.Circuit, hintFunctions []solver.Hint) error {
   243  
   244  	p := permutter{
   245  		circuit: circuit,
   246  		hints:   hintFunctions,
   247  	}
   248  
   249  	// compile the systems
   250  	for i := 0; i < nbSystems; i++ {
   251  
   252  		ccs, err := frontend.Compile(tinyfield.Modulus(), builders[i], circuit)
   253  		if err != nil {
   254  			return err
   255  		}
   256  		p.constraintSystems[i] = ccs
   257  
   258  		if i == 0 { // the -1 is only for r1cs...
   259  			n := ccs.GetNbPublicVariables() - 1 + ccs.GetNbSecretVariables()
   260  			if n > permutterBound {
   261  				return nil
   262  			}
   263  			p.witness = make([]tinyfield.Element, n)
   264  		}
   265  
   266  	}
   267  
   268  	return p.permuteAndTest(0)
   269  }
   270  
   271  // [0, 1, ..., q - 1], with q == tinyfield.Modulus()
   272  var tinyfieldElements []uint64
   273  
   274  func init() {
   275  	n := tinyfield.Modulus().Uint64()
   276  	tinyfieldElements = make([]uint64, n)
   277  	for i := uint64(0); i < n; i++ {
   278  		tinyfieldElements[i] = i
   279  	}
   280  
   281  	builders[0] = r1cs.NewBuilder
   282  	builders[1] = scs.NewBuilder
   283  }