github.com/consensys/gnark@v0.11.0/internal/generator/backend/template/representations/tests/r1cs.go.tmpl (about)

     1  
     2  import (
     3  	"bytes"
     4  	"testing"
     5  	"reflect"
     6  	"github.com/consensys/gnark/frontend"
     7  	"github.com/consensys/gnark/frontend/cs/r1cs"
     8  	"github.com/consensys/gnark/frontend/cs/scs"
     9  	"github.com/consensys/gnark/internal/backend/circuits"
    10  
    11  	"github.com/google/go-cmp/cmp"
    12  	"github.com/google/go-cmp/cmp/cmpopts"
    13  
    14  	{{ template "import_backend_cs" . }}
    15  	{{ template "import_fr" . }}
    16  )
    17  
    18  func TestSerialization(t *testing.T) {
    19  	
    20  	var buffer, buffer2 bytes.Buffer
    21  	
    22  	for name := range circuits.Circuits {
    23  		t.Run(name, func(t *testing.T) {
    24  		tc := circuits.Circuits[name]
    25  		{{- if eq .Curve "BW6-761"}}
    26  			if testing.Short() && name != "reference_small" {
    27  				return
    28  			}
    29  		{{- else if eq .Curve "tinyfield"}}
    30  			if name == "range_constant" {
    31  				return
    32  			}
    33  		{{- end}}
    34  
    35  		r1cs1, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder,	tc.Circuit)
    36  		if err != nil {
    37  			t.Fatal(err)
    38  		}
    39  		if testing.Short() && r1cs1.GetNbConstraints() > 50 {
    40  			return
    41  		}
    42  
    43  		// compile a second time to ensure determinism
    44  		r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit)
    45  		if err != nil {
    46  			t.Fatal(err)
    47  		}
    48  
    49  		{
    50  			buffer.Reset()
    51  			t.Log(name)
    52  			var err error
    53  			var written, read int64
    54  			written, err = r1cs1.WriteTo(&buffer)
    55  			if err != nil {
    56  				t.Fatal(err)
    57  			}
    58  			var reconstructed cs.R1CS
    59  			read, err = reconstructed.ReadFrom(&buffer)
    60  			if err != nil {
    61  				t.Fatal(err)
    62  			}
    63  			if written != read {
    64  				t.Fatal("didn't read same number of bytes we wrote")
    65  			}
    66  
    67  			// compare original and reconstructed
    68  			if diff := cmp.Diff(r1cs1, &reconstructed, 
    69  				cmpopts.IgnoreFields(cs.R1CS{},
    70  					 "System.q",
    71  					 "field",
    72  					 "CoeffTable.mCoeffs",
    73  					 "System.lbWireLevel",
    74  					 "System.genericHint",
    75  					 "System.SymbolTable",
    76  					 "System.bitLen")); diff != "" {
    77  				t.Fatalf("round trip mismatch (-want +got):\n%s", diff)
    78  			}
    79  		}
    80  
    81  		// ensure determinism in compilation / serialization / reconstruction
    82  		{
    83  			buffer.Reset()
    84  			n, err := r1cs1.WriteTo(&buffer)
    85  			if err != nil {
    86  				t.Fatal(err)
    87  			}
    88  			if n == 0 {
    89  				t.Fatal("No bytes are written")
    90  			}
    91  
    92  			buffer2.Reset()
    93  			_, err = r1cs2.WriteTo(&buffer2)
    94  			if err != nil {
    95  				t.Fatal(err)
    96  			}
    97  
    98  			if !bytes.Equal(buffer.Bytes(), buffer2.Bytes()) {
    99  				t.Fatal("compilation of R1CS is not deterministic")
   100  			}
   101  
   102  			var r, r2 cs.R1CS
   103  			n, err = r.ReadFrom(&buffer)
   104  			if err != nil {
   105  				t.Fatal(nil)
   106  			}
   107  			if n == 0 {
   108  				t.Fatal("No bytes are read")
   109  			}
   110  			_, err = r2.ReadFrom(&buffer2)
   111  			if err != nil {
   112  				t.Fatal(nil)
   113  			}
   114  
   115  			if !reflect.DeepEqual(r, r2) {
   116  				t.Fatal("compilation of R1CS is not deterministic (reconstruction)")
   117  			}
   118  		}
   119  		})
   120  
   121  	}
   122  }
   123  
   124  
   125  const n = 10000
   126  
// circuit is the test circuit compiled by BenchmarkSolve: Define repeatedly
// updates X and finally asserts that the result equals Y.
type circuit struct {
	// X carries no gnark visibility tag; Define overwrites this field with
	// each intermediate result.
	X frontend.Variable
	// Y is tagged public; the final value of X must equal it.
	Y frontend.Variable `gnark:",public"`
}
   131  
   132  func (circuit *circuit) Define(api frontend.API) error {
   133  	for i := 0; i < n; i++ {
   134  		circuit.X = api.Add(api.Mul(circuit.X, circuit.X), circuit.X, 42)
   135  	}
   136  	api.AssertIsEqual(circuit.X, circuit.Y)
   137  	return nil
   138  }
   139  
   140  func BenchmarkSolve(b *testing.B) {
   141  
   142  
   143  	var w circuit
   144  	w.X = 1 
   145  	w.Y = 1 
   146  	witness, err := frontend.NewWitness(&w, fr.Modulus())
   147  	if err != nil {
   148  		b.Fatal(err)
   149  	}
   150  
   151  	b.Run("scs", func(b *testing.B) {
   152  		var c circuit 
   153  		ccs, err := frontend.Compile(fr.Modulus(),scs.NewBuilder, &c)
   154  		if err != nil {
   155  			b.Fatal(err)
   156  		}
   157  		b.Log("scs nbConstraints", ccs.GetNbConstraints())
   158  	
   159  		b.ResetTimer()
   160  		for i := 0; i < b.N; i++ {
   161  			_ =  ccs.IsSolved(witness)
   162  		}
   163  	})
   164  
   165  	b.Run("r1cs", func(b *testing.B) {
   166  		var c circuit 
   167  		ccs, err := frontend.Compile(fr.Modulus(),r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10))
   168  		if err != nil {
   169  			b.Fatal(err)
   170  		}
   171  		b.Log("r1cs nbConstraints", ccs.GetNbConstraints())
   172  	
   173  		b.ResetTimer()
   174  		for i := 0; i < b.N; i++ {
   175  			_ =  ccs.IsSolved(witness)
   176  		}
   177  	})
   178  
   179  }