github.com/consensys/gnark@v0.11.0/debug_test.go (about)

     1  package gnark_test
     2  
     3  import (
     4  	"bytes"
     5  	"math/big"
     6  	"testing"
     7  
     8  	"github.com/consensys/gnark-crypto/ecc"
     9  	"github.com/consensys/gnark/backend"
    10  	"github.com/consensys/gnark/backend/groth16"
    11  	"github.com/consensys/gnark/backend/plonk"
    12  	"github.com/consensys/gnark/constraint/solver"
    13  	"github.com/consensys/gnark/debug"
    14  	"github.com/consensys/gnark/frontend"
    15  	"github.com/consensys/gnark/frontend/cs/r1cs"
    16  	"github.com/consensys/gnark/frontend/cs/scs"
    17  	"github.com/consensys/gnark/test/unsafekzg"
    18  	"github.com/rs/zerolog"
    19  	"github.com/stretchr/testify/require"
    20  )
    21  
    22  // -------------------------------------------------------------------------------------------------
    23  // test println (non regression)
type printlnCircuit struct { // circuit used by TestPrintln to exercise api.Println
	A, B frontend.Variable // operands combined and printed by Define
}
    27  
func (circuit *printlnCircuit) Define(api frontend.API) error { // NOTE: TestPrintln hard-codes the file line numbers of the Println calls below; keep the line layout unchanged.
	c := api.Add(circuit.A, circuit.B)
	api.Println(c, "is the addition") // a variable followed by a string
	d := api.Mul(circuit.A, c)
	api.Println(d, new(big.Int).SetInt64(42)) // a variable followed by a *big.Int constant
	bs := api.ToBinary(circuit.B, 10)
	api.Println("bits", bs[3])      // a string followed by one entry of the ToBinary result
	api.Println("circuit", circuit) // a string followed by the circuit struct itself
	nb := api.Mul(bs[1], 2)
	api.AssertIsBoolean(nb) // this will fail (for the TestPrintln witness nb is presumably 2, not a boolean)
	m := api.Mul(circuit.A, circuit.B)
	api.Println("m", m) // this should not be resolved; TestPrintln only matches "m .*" for this line
	return nil
}
    42  
    43  func TestPrintln(t *testing.T) {
    44  	assert := require.New(t)
    45  
    46  	var circuit, witness printlnCircuit
    47  	witness.A = 2
    48  	witness.B = 11
    49  
    50  	var expected bytes.Buffer
    51  	expected.WriteString("debug_test.go:30 > 13 is the addition\n")
    52  	expected.WriteString("debug_test.go:32 > 26 42\n")
    53  	expected.WriteString("debug_test.go:34 > bits 1\n")
    54  	expected.WriteString("debug_test.go:35 > circuit {A: 2, B: 11}\n")
    55  	expected.WriteString("debug_test.go:39 > m .*\n")
    56  
    57  	{
    58  		trace, _ := getGroth16Trace(&circuit, &witness)
    59  		assert.Regexp(expected.String(), trace)
    60  	}
    61  
    62  	{
    63  		trace, _ := getPlonkTrace(&circuit, &witness)
    64  		assert.Regexp(expected.String(), trace)
    65  	}
    66  }
    67  
    68  // -------------------------------------------------------------------------------------------------
    69  // Div by 0
    70  type divBy0Trace struct {
    71  	A, B, C frontend.Variable
    72  }
    73  
    74  func (circuit *divBy0Trace) Define(api frontend.API) error {
    75  	d := api.Add(circuit.B, circuit.C)
    76  	api.Div(circuit.A, d)
    77  	return nil
    78  }
    79  
    80  func TestTraceDivBy0(t *testing.T) {
    81  	if !debug.Debug {
    82  		t.Skip("skipping test in non debug mode")
    83  	}
    84  	assert := require.New(t)
    85  
    86  	var circuit, witness divBy0Trace
    87  	witness.A = 2
    88  	witness.B = -2
    89  	witness.C = 2
    90  
    91  	{
    92  		_, err := getGroth16Trace(&circuit, &witness)
    93  		assert.Error(err)
    94  		assert.Contains(err.Error(), "constraint #0 is not satisfied: [div] 2/0 == <unsolved>")
    95  		assert.Contains(err.Error(), "(*divBy0Trace).Define")
    96  		assert.Contains(err.Error(), "debug_test.go:")
    97  	}
    98  
    99  	{
   100  		_, err := getPlonkTrace(&circuit, &witness)
   101  		assert.Error(err)
   102  		if debug.Debug {
   103  			assert.Contains(err.Error(), "constraint #1 is not satisfied: [inverse] 1/0 < ∞")
   104  			assert.Contains(err.Error(), "(*divBy0Trace).Define")
   105  			assert.Contains(err.Error(), "debug_test.go:")
   106  		} else {
   107  			assert.Contains(err.Error(), "constraint #1 is not satisfied: division by 0")
   108  		}
   109  
   110  	}
   111  }
   112  
   113  // -------------------------------------------------------------------------------------------------
   114  // Not Equal
   115  type notEqualTrace struct {
   116  	A, B, C frontend.Variable
   117  }
   118  
   119  func (circuit *notEqualTrace) Define(api frontend.API) error {
   120  	d := api.Add(circuit.B, circuit.C)
   121  	api.AssertIsEqual(circuit.A, d)
   122  	return nil
   123  }
   124  
   125  func TestTraceNotEqual(t *testing.T) {
   126  	assert := require.New(t)
   127  
   128  	var circuit, witness notEqualTrace
   129  	witness.A = 1
   130  	witness.B = 24
   131  	witness.C = 42
   132  
   133  	{
   134  		_, err := getGroth16Trace(&circuit, &witness)
   135  		assert.Error(err)
   136  		if debug.Debug {
   137  			assert.Contains(err.Error(), "constraint #0 is not satisfied: [assertIsEqual] 1 == 66")
   138  			assert.Contains(err.Error(), "(*notEqualTrace).Define")
   139  			assert.Contains(err.Error(), "debug_test.go:")
   140  		} else {
   141  			assert.Contains(err.Error(), "constraint #0 is not satisfied: 1 ⋅ 1 != 66")
   142  		}
   143  	}
   144  
   145  	{
   146  		_, err := getPlonkTrace(&circuit, &witness)
   147  		assert.Error(err)
   148  		if debug.Debug {
   149  			assert.Contains(err.Error(), "constraint #1 is not satisfied: [assertIsEqual] 1 == 66")
   150  			assert.Contains(err.Error(), "(*notEqualTrace).Define")
   151  			assert.Contains(err.Error(), "debug_test.go:")
   152  		} else {
   153  			assert.Contains(err.Error(), "constraint #1 is not satisfied: qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → 1 + -66 + 0 + 0 + 0 != 0")
   154  		}
   155  
   156  	}
   157  }
   158  
   159  func getPlonkTrace(circuit, w frontend.Circuit) (string, error) {
   160  	ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit)
   161  	if err != nil {
   162  		return "", err
   163  	}
   164  
   165  	srs, srsLagrange, err := unsafekzg.NewSRS(ccs)
   166  	if err != nil {
   167  		return "", err
   168  	}
   169  
   170  	pk, _, err := plonk.Setup(ccs, srs, srsLagrange)
   171  	if err != nil {
   172  		return "", err
   173  	}
   174  
   175  	var buf bytes.Buffer
   176  	sw, err := frontend.NewWitness(w, ecc.BN254.ScalarField())
   177  	if err != nil {
   178  		return "", err
   179  	}
   180  	log := zerolog.New(&zerolog.ConsoleWriter{Out: &buf, NoColor: true, PartsExclude: []string{zerolog.LevelFieldName, zerolog.TimestampFieldName}})
   181  	_, err = plonk.Prove(ccs, pk, sw, backend.WithSolverOptions(solver.WithLogger(log)))
   182  	return buf.String(), err
   183  }
   184  
   185  func getGroth16Trace(circuit, w frontend.Circuit) (string, error) {
   186  	ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, circuit)
   187  	if err != nil {
   188  		return "", err
   189  	}
   190  
   191  	pk, err := groth16.DummySetup(ccs)
   192  	if err != nil {
   193  		return "", err
   194  	}
   195  
   196  	var buf bytes.Buffer
   197  	sw, err := frontend.NewWitness(w, ecc.BN254.ScalarField())
   198  	if err != nil {
   199  		return "", err
   200  	}
   201  	log := zerolog.New(&zerolog.ConsoleWriter{Out: &buf, NoColor: true, PartsExclude: []string{zerolog.LevelFieldName, zerolog.TimestampFieldName}})
   202  	_, err = groth16.Prove(ccs, pk, sw, backend.WithSolverOptions(solver.WithLogger(log)))
   203  	return buf.String(), err
   204  }