github.com/consensys/gnark@v0.11.0/test/assert.go (about)

     1  /*
     2  Copyright © 2021 ConsenSys Software Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package test
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"reflect"
    23  	"strings"
    24  	"testing"
    25  
    26  	"github.com/consensys/gnark-crypto/ecc"
    27  	"github.com/consensys/gnark/backend"
    28  	"github.com/consensys/gnark/backend/witness"
    29  	"github.com/consensys/gnark/constraint"
    30  	"github.com/consensys/gnark/frontend"
    31  	"github.com/consensys/gnark/frontend/cs/r1cs"
    32  	"github.com/consensys/gnark/frontend/cs/scs"
    33  	"github.com/consensys/gnark/frontend/schema"
    34  	gnarkio "github.com/consensys/gnark/io"
    35  	"github.com/stretchr/testify/require"
    36  )
    37  
    38  var (
    39  	ErrCompilationNotDeterministic = errors.New("compilation is not deterministic")
    40  	ErrInvalidWitnessSolvedCS      = errors.New("invalid witness solved the constraint system")
    41  	ErrInvalidWitnessVerified      = errors.New("invalid witness resulted in a valid proof")
    42  )
    43  
    44  // Assert is a helper to test circuits
    45  type Assert struct {
    46  	t *testing.T
    47  	*require.Assertions
    48  }
    49  
    50  // NewAssert returns an Assert helper embedding a testify/require object for convenience
    51  //
    52  // The Assert object caches the compiled circuit:
    53  //
    54  // the first call to assert.ProverSucceeded/Failed will compile the circuit for n curves, m backends
    55  // and subsequent calls will re-use the result of the compilation, if available.
    56  func NewAssert(t *testing.T) *Assert {
    57  	return &Assert{t: t, Assertions: require.New(t)}
    58  }
    59  
    60  // Run runs the test function fn as a subtest. The subtest is parametrized by
    61  // the description strings descs.
    62  func (a *Assert) Run(fn func(assert *Assert), descs ...string) {
    63  	desc := strings.Join(descs, "/")
    64  	a.t.Run(desc, func(t *testing.T) {
    65  		assert := &Assert{t, require.New(t)}
    66  		fn(assert)
    67  	})
    68  }
    69  
    70  // Log logs using the test instance logger.
    71  func (assert *Assert) Log(v ...interface{}) {
    72  	assert.t.Log(v...)
    73  }
    74  
    75  // ProverSucceeded is deprecated: use [Assert.CheckCircuit] instead
    76  func (assert *Assert) ProverSucceeded(circuit frontend.Circuit, validAssignment frontend.Circuit, opts ...TestingOption) {
    77  	// copy the options
    78  	newOpts := make([]TestingOption, len(opts), len(opts)+2)
    79  	copy(newOpts, opts)
    80  	newOpts = append(newOpts, WithValidAssignment(validAssignment))
    81  
    82  	assert.CheckCircuit(circuit, newOpts...)
    83  }
    84  
    85  // ProverSucceeded is deprecated use [Assert.CheckCircuit] instead
    86  func (assert *Assert) ProverFailed(circuit frontend.Circuit, invalidAssignment frontend.Circuit, opts ...TestingOption) {
    87  	// copy the options
    88  	newOpts := make([]TestingOption, len(opts), len(opts)+2)
    89  	copy(newOpts, opts)
    90  	newOpts = append(newOpts, WithInvalidAssignment(invalidAssignment))
    91  
    92  	assert.CheckCircuit(circuit, newOpts...)
    93  }
    94  
    95  // SolvingSucceeded is deprecated: use [Assert.CheckCircuit] instead
    96  func (assert *Assert) SolvingSucceeded(circuit frontend.Circuit, validWitness frontend.Circuit, opts ...TestingOption) {
    97  
    98  	// copy the options
    99  	newOpts := make([]TestingOption, len(opts), len(opts)+1)
   100  	copy(newOpts, opts)
   101  	newOpts = append(newOpts, WithValidAssignment(validWitness))
   102  
   103  	assert.CheckCircuit(circuit, newOpts...)
   104  }
   105  
   106  // SolvingFailed is deprecated: use CheckCircuit instead
   107  func (assert *Assert) SolvingFailed(circuit frontend.Circuit, invalidWitness frontend.Circuit, opts ...TestingOption) {
   108  	// copy the options
   109  	newOpts := make([]TestingOption, len(opts), len(opts)+1)
   110  	copy(newOpts, opts)
   111  	newOpts = append(newOpts, WithInvalidAssignment(invalidWitness))
   112  
   113  	assert.CheckCircuit(circuit, newOpts...)
   114  }
   115  
   116  func lazySchema(circuit frontend.Circuit) func() *schema.Schema {
   117  	return func() *schema.Schema {
   118  		// we only parse the schema if we need to display the witness in json.
   119  		s, err := schema.New(circuit, tVariable)
   120  		if err != nil {
   121  			panic("couldn't parse schema from circuit: " + err.Error())
   122  		}
   123  		return s
   124  	}
   125  }
   126  
   127  // compile the given circuit for given curve and backend, if not already present in cache
   128  func (assert *Assert) compile(circuit frontend.Circuit, curveID ecc.ID, backendID backend.ID, compileOpts []frontend.CompileOption) (constraint.ConstraintSystem, error) {
   129  	var newBuilder frontend.NewBuilder
   130  
   131  	switch backendID {
   132  	case backend.GROTH16:
   133  		newBuilder = r1cs.NewBuilder
   134  	case backend.PLONK:
   135  		newBuilder = scs.NewBuilder
   136  	default:
   137  		panic("not implemented")
   138  	}
   139  
   140  	// else compile it and ensure it is deterministic
   141  	ccs, err := frontend.Compile(curveID.ScalarField(), newBuilder, circuit, compileOpts...)
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  
   146  	_ccs, err := frontend.Compile(curveID.ScalarField(), newBuilder, circuit, compileOpts...)
   147  	if err != nil {
   148  		return nil, fmt.Errorf("%w: %v", ErrCompilationNotDeterministic, err)
   149  	}
   150  
   151  	if !reflect.DeepEqual(ccs, _ccs) {
   152  		return nil, ErrCompilationNotDeterministic
   153  	}
   154  
   155  	return ccs, nil
   156  }
   157  
   158  // error ensure the error is set, else fails the test
   159  // add a witness to the error message if provided
   160  func (assert *Assert) error(err error, w *_witness) {
   161  	if err != nil {
   162  		return
   163  	}
   164  	json := "<nil>"
   165  	if w != nil {
   166  		bjson, err := w.full.ToJSON(lazySchema(w.assignment)())
   167  		if err != nil {
   168  			json = err.Error()
   169  		} else {
   170  			json = string(bjson)
   171  		}
   172  	}
   173  
   174  	e := fmt.Errorf("did not error (but should have)\nwitness:%s", json)
   175  	assert.FailNow(e.Error())
   176  }
   177  
   178  // ensure the error is nil, else fails the test
   179  // add a witness to the error message if provided
   180  func (assert *Assert) noError(err error, w *_witness) {
   181  	if err == nil {
   182  		return
   183  	}
   184  
   185  	e := err
   186  
   187  	if w != nil {
   188  		var json string
   189  		bjson, err := w.full.ToJSON(lazySchema(w.assignment)())
   190  		if err != nil {
   191  			json = err.Error()
   192  		} else {
   193  			json = string(bjson)
   194  		}
   195  		e = fmt.Errorf("%w\nwitness:%s", e, json)
   196  	}
   197  
   198  	assert.FailNow(e.Error())
   199  }
   200  
   201  func (assert *Assert) marshalWitnessJSON(w witness.Witness, s *schema.Schema, curveID ecc.ID, publicOnly bool) {
   202  	var err error
   203  	if publicOnly {
   204  		w, err = w.Public()
   205  		assert.NoError(err)
   206  	}
   207  
   208  	// serialize the vector to binary
   209  	data, err := w.ToJSON(s)
   210  	assert.NoError(err)
   211  
   212  	// re-read
   213  	witness, err := witness.New(curveID.ScalarField())
   214  	assert.NoError(err)
   215  	err = witness.FromJSON(s, data)
   216  	assert.NoError(err)
   217  
   218  	witnessMatch := reflect.DeepEqual(w, witness)
   219  	assert.True(witnessMatch, "round trip marshaling failed")
   220  }
   221  
   222  func (assert *Assert) roundTripCheck(from any, builder func() any, descs ...string) {
   223  	assert.Run(func(assert *Assert) {
   224  		assert.NoError(gnarkio.RoundTripCheck(from, builder))
   225  	}, descs...)
   226  }