github.com/consensys/gnark@v0.11.0/constraint/blueprint_scs.go

     1  package constraint
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  )
     7  
     8  var (
     9  	errDivideByZero  = errors.New("division by 0")
    10  	errBoolConstrain = errors.New("boolean constraint doesn't hold")
    11  )
    12  
    13  // BlueprintGenericSparseR1C implements Blueprint and BlueprintSparseR1C.
    14  // Encodes
    15  //
    16  //	qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC == 0
    17  type BlueprintGenericSparseR1C struct {
    18  }
    19  
    20  func (b *BlueprintGenericSparseR1C) CalldataSize() int {
    21  	return 9 // number of fields in SparseR1C
    22  }
    23  func (b *BlueprintGenericSparseR1C) NbConstraints() int {
    24  	return 1
    25  }
    26  
    27  func (b *BlueprintGenericSparseR1C) NbOutputs(inst Instruction) int {
    28  	return 0
    29  }
    30  
    31  func (b *BlueprintGenericSparseR1C) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
    32  	return updateInstructionTree(inst.Calldata[0:3], tree)
    33  }
    34  
    35  func (b *BlueprintGenericSparseR1C) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
    36  	*to = append(*to, c.XA, c.XB, c.XC, c.QL, c.QR, c.QO, c.QM, c.QC, uint32(c.Commitment))
    37  }
    38  
    39  func (b *BlueprintGenericSparseR1C) DecompressSparseR1C(c *SparseR1C, inst Instruction) {
    40  	c.Clear()
    41  
    42  	c.XA = inst.Calldata[0]
    43  	c.XB = inst.Calldata[1]
    44  	c.XC = inst.Calldata[2]
    45  	c.QL = inst.Calldata[3]
    46  	c.QR = inst.Calldata[4]
    47  	c.QO = inst.Calldata[5]
    48  	c.QM = inst.Calldata[6]
    49  	c.QC = inst.Calldata[7]
    50  	c.Commitment = CommitmentConstraint(inst.Calldata[8])
    51  }
    52  
    53  func (b *BlueprintGenericSparseR1C) Solve(s Solver, inst Instruction) error {
    54  	var c SparseR1C
    55  	b.DecompressSparseR1C(&c, inst)
    56  	if c.Commitment != NOT {
    57  		// a constraint of the form f_L - PI_2 = 0 or f_L = Comm.
    58  		// these are there for enforcing the correctness of the commitment and can be skipped in solving time
    59  		return nil
    60  	}
    61  
    62  	var ok bool
    63  
    64  	// constraint has at most one unsolved wire.
    65  	if !s.IsSolved(c.XA) {
    66  		// we solve for L: u1L+u2R+u3LR+u4O+k=0 => L(u1+u3R)+u2R+u4O+k = 0
    67  		u1 := s.GetCoeff(c.QL)
    68  		den := s.GetValue(c.QM, c.XB)
    69  		den = s.Add(den, u1)
    70  		den, ok = s.Inverse(den)
    71  		if !ok {
    72  			return errDivideByZero
    73  		}
    74  		v1 := s.GetValue(c.QR, c.XB)
    75  		v2 := s.GetValue(c.QO, c.XC)
    76  		num := s.Add(v1, v2)
    77  		num = s.Add(num, s.GetCoeff(c.QC))
    78  		num = s.Mul(num, den)
    79  		num = s.Neg(num)
    80  		s.SetValue(c.XA, num)
    81  	} else if !s.IsSolved(c.XB) {
    82  		u2 := s.GetCoeff(c.QR)
    83  		den := s.GetValue(c.QM, c.XA)
    84  		den = s.Add(den, u2)
    85  		den, ok = s.Inverse(den)
    86  		if !ok {
    87  			return errDivideByZero
    88  		}
    89  
    90  		v1 := s.GetValue(c.QL, c.XA)
    91  		v2 := s.GetValue(c.QO, c.XC)
    92  
    93  		num := s.Add(v1, v2)
    94  		num = s.Add(num, s.GetCoeff(c.QC))
    95  		num = s.Mul(num, den)
    96  		num = s.Neg(num)
    97  		s.SetValue(c.XB, num)
    98  
    99  	} else if !s.IsSolved(c.XC) {
   100  		// O we solve for O
   101  		l := s.GetValue(c.QL, c.XA)
   102  		r := s.GetValue(c.QR, c.XB)
   103  		m0 := s.GetValue(c.QM, c.XA)
   104  		m1 := s.GetValue(CoeffIdOne, c.XB)
   105  
   106  		// o = - ((m0 * m1) + l + r + c.QC) / c.O
   107  		o := s.Mul(m0, m1)
   108  		o = s.Add(o, l)
   109  		o = s.Add(o, r)
   110  		o = s.Add(o, s.GetCoeff(c.QC))
   111  
   112  		den := s.GetCoeff(c.QO)
   113  		den, ok = s.Inverse(den)
   114  		if !ok {
   115  			return errDivideByZero
   116  		}
   117  		o = s.Mul(o, den)
   118  		o = s.Neg(o)
   119  
   120  		s.SetValue(c.XC, o)
   121  	} else {
   122  		// all wires are solved, we verify that the constraint hold.
   123  		// this can happen when all wires are from hints or if the constraint is an assertion.
   124  		return b.checkConstraint(&c, s)
   125  	}
   126  	return nil
   127  }
   128  
   129  func (b *BlueprintGenericSparseR1C) checkConstraint(c *SparseR1C, s Solver) error {
   130  	l := s.GetValue(c.QL, c.XA)
   131  	r := s.GetValue(c.QR, c.XB)
   132  	m0 := s.GetValue(c.QM, c.XA)
   133  	m1 := s.GetValue(CoeffIdOne, c.XB)
   134  	m0 = s.Mul(m0, m1)
   135  	o := s.GetValue(c.QO, c.XC)
   136  	qC := s.GetCoeff(c.QC)
   137  
   138  	t := s.Add(m0, l)
   139  	t = s.Add(t, r)
   140  	t = s.Add(t, o)
   141  	t = s.Add(t, qC)
   142  
   143  	if !t.IsZero() {
   144  		return fmt.Errorf("qL⋅xa + qR⋅xb + qO⋅xc + qM⋅(xaxb) + qC != 0 → %s + %s + %s + %s + %s != 0",
   145  			s.String(l),
   146  			s.String(r),
   147  			s.String(o),
   148  			s.String(m0),
   149  			s.String(qC),
   150  		)
   151  	}
   152  	return nil
   153  }
   154  
   155  // BlueprintSparseR1CMul implements Blueprint, BlueprintSolvable and BlueprintSparseR1C.
   156  // Encodes
   157  //
   158  //	qM⋅(xaxb)  == xc
   159  type BlueprintSparseR1CMul struct{}
   160  
   161  func (b *BlueprintSparseR1CMul) CalldataSize() int {
   162  	return 4
   163  }
   164  func (b *BlueprintSparseR1CMul) NbConstraints() int {
   165  	return 1
   166  }
   167  func (b *BlueprintSparseR1CMul) NbOutputs(inst Instruction) int {
   168  	return 0
   169  }
   170  
   171  func (b *BlueprintSparseR1CMul) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
   172  	return updateInstructionTree(inst.Calldata[0:3], tree)
   173  }
   174  
   175  func (b *BlueprintSparseR1CMul) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
   176  	*to = append(*to, c.XA, c.XB, c.XC, c.QM)
   177  }
   178  
   179  func (b *BlueprintSparseR1CMul) Solve(s Solver, inst Instruction) error {
   180  	// qM⋅(xaxb)  == xc
   181  	m0 := s.GetValue(inst.Calldata[3], inst.Calldata[0])
   182  	m1 := s.GetValue(CoeffIdOne, inst.Calldata[1])
   183  
   184  	m0 = s.Mul(m0, m1)
   185  
   186  	s.SetValue(inst.Calldata[2], m0)
   187  	return nil
   188  }
   189  
   190  func (b *BlueprintSparseR1CMul) DecompressSparseR1C(c *SparseR1C, inst Instruction) {
   191  	c.Clear()
   192  	c.XA = inst.Calldata[0]
   193  	c.XB = inst.Calldata[1]
   194  	c.XC = inst.Calldata[2]
   195  	c.QO = CoeffIdMinusOne
   196  	c.QM = inst.Calldata[3]
   197  }
   198  
   199  // BlueprintSparseR1CAdd implements Blueprint, BlueprintSolvable and BlueprintSparseR1C.
   200  // Encodes
   201  //
   202  //	qL⋅xa + qR⋅xb + qC == xc
   203  type BlueprintSparseR1CAdd struct{}
   204  
   205  func (b *BlueprintSparseR1CAdd) CalldataSize() int {
   206  	return 6
   207  }
   208  func (b *BlueprintSparseR1CAdd) NbConstraints() int {
   209  	return 1
   210  }
   211  func (b *BlueprintSparseR1CAdd) NbOutputs(inst Instruction) int {
   212  	return 0
   213  }
   214  
   215  func (b *BlueprintSparseR1CAdd) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
   216  	return updateInstructionTree(inst.Calldata[0:3], tree)
   217  }
   218  
   219  func (b *BlueprintSparseR1CAdd) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
   220  	*to = append(*to, c.XA, c.XB, c.XC, c.QL, c.QR, c.QC)
   221  }
   222  
   223  func (blueprint *BlueprintSparseR1CAdd) Solve(s Solver, inst Instruction) error {
   224  	// a + b + k == c
   225  	a := s.GetValue(inst.Calldata[3], inst.Calldata[0])
   226  	b := s.GetValue(inst.Calldata[4], inst.Calldata[1])
   227  	k := s.GetCoeff(inst.Calldata[5])
   228  
   229  	a = s.Add(a, b)
   230  	a = s.Add(a, k)
   231  
   232  	s.SetValue(inst.Calldata[2], a)
   233  	return nil
   234  }
   235  
   236  func (b *BlueprintSparseR1CAdd) DecompressSparseR1C(c *SparseR1C, inst Instruction) {
   237  	c.Clear()
   238  	c.XA = inst.Calldata[0]
   239  	c.XB = inst.Calldata[1]
   240  	c.XC = inst.Calldata[2]
   241  	c.QL = inst.Calldata[3]
   242  	c.QR = inst.Calldata[4]
   243  	c.QO = CoeffIdMinusOne
   244  	c.QC = inst.Calldata[5]
   245  }
   246  
   247  // BlueprintSparseR1CBool implements Blueprint, BlueprintSolvable and BlueprintSparseR1C.
   248  // Encodes
   249  //
   250  //	qL⋅xa + qM⋅(xa*xa)  == 0
   251  //	that is v + -v*v == 0
   252  type BlueprintSparseR1CBool struct{}
   253  
   254  func (b *BlueprintSparseR1CBool) CalldataSize() int {
   255  	return 3
   256  }
   257  func (b *BlueprintSparseR1CBool) NbConstraints() int {
   258  	return 1
   259  }
   260  func (b *BlueprintSparseR1CBool) NbOutputs(inst Instruction) int {
   261  	return 0
   262  }
   263  
   264  func (b *BlueprintSparseR1CBool) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
   265  	return updateInstructionTree(inst.Calldata[0:1], tree)
   266  }
   267  
   268  func (b *BlueprintSparseR1CBool) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
   269  	*to = append(*to, c.XA, c.QL, c.QM)
   270  }
   271  
   272  func (blueprint *BlueprintSparseR1CBool) Solve(s Solver, inst Instruction) error {
   273  	// all wires are already solved, we just check the constraint.
   274  	v1 := s.GetValue(inst.Calldata[1], inst.Calldata[0])
   275  	v2 := s.GetValue(inst.Calldata[2], inst.Calldata[0])
   276  	v := s.GetValue(CoeffIdOne, inst.Calldata[0])
   277  	v = s.Mul(v, v2)
   278  	v = s.Add(v1, v)
   279  	if !v.IsZero() {
   280  		return errBoolConstrain
   281  	}
   282  	return nil
   283  }
   284  
   285  func (b *BlueprintSparseR1CBool) DecompressSparseR1C(c *SparseR1C, inst Instruction) {
   286  	c.Clear()
   287  	c.XA = inst.Calldata[0]
   288  	c.XB = c.XA
   289  	c.QL = inst.Calldata[1]
   290  	c.QM = inst.Calldata[2]
   291  }
   292  
   293  func updateInstructionTree(wires []uint32, tree InstructionTree) Level {
   294  	// constraint has at most one unsolved wire.
   295  	var outputWire uint32
   296  	found := false
   297  	maxLevel := LevelUnset
   298  	for _, wireID := range wires {
   299  		if !tree.HasWire(wireID) {
   300  			continue
   301  		}
   302  		if level := tree.GetWireLevel(wireID); level == LevelUnset {
   303  			outputWire = wireID
   304  			found = true
   305  		} else if level > maxLevel {
   306  			maxLevel = level
   307  		}
   308  	}
   309  
   310  	maxLevel++
   311  	if found {
   312  		tree.InsertWire(outputWire, maxLevel)
   313  	}
   314  
   315  	return maxLevel
   316  }