github.com/consensys/gnark-crypto@v0.14.0/internal/generator/polynomial/template/multilin.go.tmpl (about)

     1  import (
     2      "{{.FieldPackagePath}}"
     3  	"math/bits"
     4  	"github.com/consensys/gnark-crypto/utils"
     5  )
     6  
     7  // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial
     8  // The variables are X₁ through Xₙ where n = log(len(.))
     9  // .[∑ᵢ 2ⁱ⁻¹ bₙ₋ᵢ] = the polynomial evaluated at (b₁, b₂, ..., bₙ)
    10  // It is understood that any hypercube evaluation can be extrapolated to a multilinear polynomial
    11  type MultiLin []{{.ElementType}}
    12  
    13  // Fold is partial evaluation function k[X₁, X₂, ..., Xₙ] → k[X₂, ..., Xₙ] by setting X₁=r
    14  func (m *MultiLin) Fold(r {{.ElementType}}) {
    15  	mid := len(*m) / 2
    16  
    17  	bottom, top := (*m)[:mid], (*m)[mid:]
    18  
    19  	var t {{.ElementType}} // no need to update the top part
    20  
    21  	// updating bookkeeping table
    22  	// knowing that the polynomial f ∈ (k[X₂, ..., Xₙ])[X₁] is linear, we would get f(r) = f(0) + r(f(1) - f(0))
    23  	// the following loop computes the evaluations of f(r) accordingly:
    24  	//		f(r, b₂, ..., bₙ) = f(0, b₂, ..., bₙ) + r(f(1, b₂, ..., bₙ) - f(0, b₂, ..., bₙ))
    25  	for i := 0; i < mid; i++ {
    26  		// table[i] ← table[i] + r (table[i + mid] - table[i])
    27  		t.Sub(&top[i], &bottom[i])
    28  		t.Mul(&t, &r)
    29  		bottom[i].Add(&bottom[i], &t)
    30  	}
    31  
    32  	*m = (*m)[:mid]
    33  }
    34  
    35  func (m *MultiLin) FoldParallel(r {{.ElementType}}) utils.Task {
    36  	mid := len(*m) / 2
    37  	bottom, top := (*m)[:mid], (*m)[mid:]
    38  
    39  	*m = bottom
    40  
    41  	return func(start, end int) {
    42  		var t {{.ElementType}} // no need to update the top part
    43  		for i := start; i < end; i++ {
    44  			// table[i] ← table[i]  + r (table[i + mid] - table[i])
    45  			t.Sub(&top[i], &bottom[i])
    46  			t.Mul(&t, &r)
    47  			bottom[i].Add(&bottom[i], &t)
    48  		}
    49  	}
    50  }
    51  
    52  func (m MultiLin) Sum() {{.ElementType}} {
    53  	s := m[0]
    54  	for i := 1; i < len(m); i++ {
    55  		s.Add(&s, &m[i])
    56  	}
    57  	return s
    58  }
    59  
    60  func _clone(m MultiLin, p *Pool) MultiLin {
    61  	if p == nil {
    62  		return m.Clone()
    63  	} else {
    64  		return p.Clone(m)
    65  	}
    66  }
    67  
    68  func _dump(m MultiLin, p *Pool) {
    69  	if p != nil {
    70  		p.Dump(m)
    71  	}
    72  }
    73  
    74  // Evaluate extrapolate the value of the multilinear polynomial corresponding to m
    75  // on the given coordinates
    76  func (m MultiLin) Evaluate(coordinates []{{.ElementType}}, p *Pool) {{.ElementType}} {
    77  	// Folding is a mutating operation
    78  	bkCopy := _clone(m, p)
    79  
    80  	// Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable)
    81  	for _, r := range coordinates {
    82  		bkCopy.Fold(r)
    83  	}
    84  
    85  	result := bkCopy[0]
    86  
    87  	_dump(bkCopy, p)
    88  	return result
    89  }
    90  
    91  // Clone creates a deep copy of a bookkeeping table.
    92  // Both multilinear interpolation and sumcheck require folding an underlying
    93  // array, but folding changes the array. To do both one requires a deep copy
    94  // of the bookkeeping table.
    95  func (m MultiLin) Clone() MultiLin {
    96  	res := make(MultiLin, len(m))
    97  	copy(res, m)
    98  	return res
    99  }
   100  
   101  // Add two bookKeepingTables
   102  func (m *MultiLin) Add(left, right MultiLin) {
   103  	size := len(left)
   104  	// Check that left and right have the same size
   105  	if len(right) != size || len(*m) != size{
   106  		panic("left, right and destination must have the right size")
   107  	}
   108  
   109  	// Add elementwise
   110  	for i := 0; i < size; i++ {
   111  		(*m)[i].Add(&left[i], &right[i])
   112  	}
   113  }
   114  
   115  
   116  // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ)
   117  // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates
   118  //      _________________
   119  //      |       |       |
   120  //      |   0   |   1   |
   121  //      |_______|_______|
   122  //  y   |       |       |
   123  //      |   1   |   0   |
   124  //      |_______|_______|
   125  //
   126  //              x
   127  // In other words the polynomial evaluated here is the multilinear extrapolation of
   128  // one that evaluates to q' == h' for vectors q', h' of binary values
   129  func EvalEq(q, h []{{.ElementType}}) {{.ElementType}} {
   130  	var res, nxt, one, sum {{.ElementType}}
   131  	one.SetOne()
   132  	for i := 0; i < len(q); i++ {
   133  		nxt.Mul(&q[i], &h[i]) // nxt <- qᵢ * hᵢ
   134  		nxt.Double(&nxt)      // nxt <- 2 * qᵢ * hᵢ
   135  		nxt.Add(&nxt, &one)   // nxt <- 1 + 2 * qᵢ * hᵢ
   136  		sum.Add(&q[i], &h[i]) // sum <- qᵢ + hᵢ	TODO: Why not subtract one by one from nxt? More parallel?
   137  
   138  		if i == 0 {
   139  			res.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ
   140  		} else {
   141  			nxt.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ
   142  			res.Mul(&res, &nxt) // res <- res * nxt
   143  		}
   144  	}
   145  	return res
   146  }
   147  
   148  // Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0]
   149  func (m *MultiLin) Eq(q []{{.ElementType}}) {
   150  	n := len(q)
   151  
   152  	if len(*m) != 1 << n {
   153  		panic("destination must have size 2 raised to the size of source")
   154  	}
   155  
   156  	//At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁)
   157  	for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁
   158  		// go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ
   159  		for j := 0; j < (1 << i); j++ {
   160  			j0 := j << (n - i)                 // bᵢ₊₁ = 0
   161  			j1 := j0 + 1<<(n-1-i)              // bᵢ₊₁ = 1
   162  			(*m)[j1].Mul(&q[i], &(*m)[j0])       // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁
   163  			(*m)[j0].Sub(&(*m)[j0], &(*m)[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁)
   164  		}
   165  	}
   166  }
   167  
   168  func (m MultiLin) NumVars() int {
   169  	return bits.TrailingZeros(uint(len(m)))
   170  }
   171  
   172  func init() {
   173  	//TODO: Check for whether already computed in the Getter or this?
   174  	lagrangeBasis = make([][]Polynomial, maxLagrangeDomainSize+1)
   175  
   176  	//size = 0: Cannot extrapolate with no data points
   177  
   178  	//size = 1: Constant polynomial
   179  	lagrangeBasis[1] = []Polynomial{make(Polynomial, 1)}
   180  	lagrangeBasis[1][0][0].SetOne()
   181  
   182  	//for size ≥ 2, the function works
   183  	for size := uint8(2); size <= maxLagrangeDomainSize; size++ {
   184  		lagrangeBasis[size] = computeLagrangeBasis(size)
   185  	}
   186  }
   187  
   188  func getLagrangeBasis(domainSize int) []Polynomial {
   189  	//TODO: Precompute everything at init or this?
   190  	/*if lagrangeBasis[domainSize] == nil {
   191  		lagrangeBasis[domainSize] = computeLagrangeBasis(domainSize)
   192  	}*/
   193  	return lagrangeBasis[domainSize]
   194  }
   195  
   196  const maxLagrangeDomainSize uint8 = 12
   197  
   198  var lagrangeBasis [][]Polynomial
   199  
   200  // computeLagrangeBasis precomputes in explicit coefficient form for each 0 ≤ l < domainSize the polynomial
   201  // pₗ := X (X-1) ... (X-l-1) (X-l+1) ... (X - domainSize + 1) / ( l (l-1) ... 2 (-1) ... (l - domainSize +1) )
   202  // Note that pₗ(l) = 1 and pₗ(n) = 0 if 0 ≤ l < domainSize, n ≠ l
   203  func computeLagrangeBasis(domainSize uint8) []Polynomial {
   204  
   205  	constTerms := make([]{{.ElementType}}, domainSize)
   206  	for i := uint8(0); i < domainSize; i++ {
   207  		constTerms[i].SetInt64(-int64(i))
   208  	}
   209  
   210  	res := make([]Polynomial, domainSize)
   211  	multScratch := make(Polynomial, domainSize-1)
   212  
   213  	// compute pₗ
   214  	for l := uint8(0); l < domainSize; l++ {
   215  
   216  		// TODO: Optimize this with some trees? O(log(domainSize)) polynomial mults instead of O(domainSize)? Then again it would be fewer big poly mults vs many small poly mults
   217  		d := uint8(0) //d is the current degree of res
   218  		for i := uint8(0); i < domainSize; i++ {
   219  			if i == l {
   220  				continue
   221  			}
   222  			if d == 0 {
   223  				res[l] = make(Polynomial, domainSize)
   224  				res[l][domainSize-2] = constTerms[i]
   225  				res[l][domainSize-1].SetOne()
   226  			} else {
   227  				current := res[l][domainSize-d-2:]
   228  				timesConst := multScratch[domainSize-d-2:]
   229  
   230  				timesConst.Scale(&constTerms[i], current[1:]) //TODO: Directly double and add since constTerms are tiny? (even less than 4 bits)
   231  				nonLeading := current[0 : d+1]
   232  
   233  				nonLeading.Add(nonLeading, timesConst)
   234  
   235  			}
   236  			d++
   237  		}
   238  
   239  	}
   240  
   241  	// We have pₗ(i≠l)=0. Now scale so that pₗ(l)=1
   242  	// Replace the constTerms with norms
   243  	for l := uint8(0); l < domainSize; l++ {
   244  		constTerms[l].Neg(&constTerms[l])
   245  		constTerms[l] = res[l].Eval(&constTerms[l])
   246  	}
   247  	constTerms = {{.FieldPackageName}}.BatchInvert(constTerms)
   248  	for l := uint8(0); l < domainSize; l++ {
   249  		res[l].ScaleInPlace(&constTerms[l])
   250  	}
   251  
   252  	return res
   253  }
   254  
   255  // InterpolateOnRange performs the interpolation of the given list of elements
   256  // On the range [0, 1,..., len(values) - 1]
   257  func InterpolateOnRange(values []{{.ElementType}}) Polynomial {
   258  	nEvals := len(values)
   259  	lagrange := getLagrangeBasis(nEvals)
   260  
   261  	var res Polynomial
   262  	res.Scale(&values[0], lagrange[0])
   263  
   264  	temp := make(Polynomial, nEvals)
   265  
   266  	for i := 1; i < nEvals; i++ {
   267  		temp.Scale(&values[i], lagrange[i])
   268  		res.Add(res, temp)
   269  	}
   270  
   271  	return res
   272  }