github.com/consensys/gnark-crypto@v0.14.0/ecc/bw6-756/fr/polynomial/multilin.go (about)

     1  // Copyright 2020 Consensys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Code generated by consensys/gnark-crypto DO NOT EDIT
    16  
    17  package polynomial
    18  
    19  import (
    20  	"github.com/consensys/gnark-crypto/ecc/bw6-756/fr"
    21  	"github.com/consensys/gnark-crypto/utils"
    22  	"math/bits"
    23  )
    24  
    25  // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial
    26  // The variables are X₁ through Xₙ where n = log(len(.))
    27  // .[∑ᵢ 2ⁱ⁻¹ bₙ₋ᵢ] = the polynomial evaluated at (b₁, b₂, ..., bₙ)
    28  // It is understood that any hypercube evaluation can be extrapolated to a multilinear polynomial
    29  type MultiLin []fr.Element
    30  
    31  // Fold is partial evaluation function k[X₁, X₂, ..., Xₙ] → k[X₂, ..., Xₙ] by setting X₁=r
    32  func (m *MultiLin) Fold(r fr.Element) {
    33  	mid := len(*m) / 2
    34  
    35  	bottom, top := (*m)[:mid], (*m)[mid:]
    36  
    37  	var t fr.Element // no need to update the top part
    38  
    39  	// updating bookkeeping table
    40  	// knowing that the polynomial f ∈ (k[X₂, ..., Xₙ])[X₁] is linear, we would get f(r) = f(0) + r(f(1) - f(0))
    41  	// the following loop computes the evaluations of f(r) accordingly:
    42  	//		f(r, b₂, ..., bₙ) = f(0, b₂, ..., bₙ) + r(f(1, b₂, ..., bₙ) - f(0, b₂, ..., bₙ))
    43  	for i := 0; i < mid; i++ {
    44  		// table[i] ← table[i] + r (table[i + mid] - table[i])
    45  		t.Sub(&top[i], &bottom[i])
    46  		t.Mul(&t, &r)
    47  		bottom[i].Add(&bottom[i], &t)
    48  	}
    49  
    50  	*m = (*m)[:mid]
    51  }
    52  
    53  func (m *MultiLin) FoldParallel(r fr.Element) utils.Task {
    54  	mid := len(*m) / 2
    55  	bottom, top := (*m)[:mid], (*m)[mid:]
    56  
    57  	*m = bottom
    58  
    59  	return func(start, end int) {
    60  		var t fr.Element // no need to update the top part
    61  		for i := start; i < end; i++ {
    62  			// table[i] ← table[i]  + r (table[i + mid] - table[i])
    63  			t.Sub(&top[i], &bottom[i])
    64  			t.Mul(&t, &r)
    65  			bottom[i].Add(&bottom[i], &t)
    66  		}
    67  	}
    68  }
    69  
    70  func (m MultiLin) Sum() fr.Element {
    71  	s := m[0]
    72  	for i := 1; i < len(m); i++ {
    73  		s.Add(&s, &m[i])
    74  	}
    75  	return s
    76  }
    77  
    78  func _clone(m MultiLin, p *Pool) MultiLin {
    79  	if p == nil {
    80  		return m.Clone()
    81  	} else {
    82  		return p.Clone(m)
    83  	}
    84  }
    85  
    86  func _dump(m MultiLin, p *Pool) {
    87  	if p != nil {
    88  		p.Dump(m)
    89  	}
    90  }
    91  
    92  // Evaluate extrapolate the value of the multilinear polynomial corresponding to m
    93  // on the given coordinates
    94  func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element {
    95  	// Folding is a mutating operation
    96  	bkCopy := _clone(m, p)
    97  
    98  	// Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable)
    99  	for _, r := range coordinates {
   100  		bkCopy.Fold(r)
   101  	}
   102  
   103  	result := bkCopy[0]
   104  
   105  	_dump(bkCopy, p)
   106  	return result
   107  }
   108  
   109  // Clone creates a deep copy of a bookkeeping table.
   110  // Both multilinear interpolation and sumcheck require folding an underlying
   111  // array, but folding changes the array. To do both one requires a deep copy
   112  // of the bookkeeping table.
   113  func (m MultiLin) Clone() MultiLin {
   114  	res := make(MultiLin, len(m))
   115  	copy(res, m)
   116  	return res
   117  }
   118  
   119  // Add two bookKeepingTables
   120  func (m *MultiLin) Add(left, right MultiLin) {
   121  	size := len(left)
   122  	// Check that left and right have the same size
   123  	if len(right) != size || len(*m) != size {
   124  		panic("left, right and destination must have the right size")
   125  	}
   126  
   127  	// Add elementwise
   128  	for i := 0; i < size; i++ {
   129  		(*m)[i].Add(&left[i], &right[i])
   130  	}
   131  }
   132  
   133  // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ)
   134  // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates
   135  //
   136  //	    _________________
   137  //	    |       |       |
   138  //	    |   0   |   1   |
   139  //	    |_______|_______|
   140  //	y   |       |       |
   141  //	    |   1   |   0   |
   142  //	    |_______|_______|
   143  //
   144  //	            x
   145  //
   146  // In other words the polynomial evaluated here is the multilinear extrapolation of
   147  // one that evaluates to q' == h' for vectors q', h' of binary values
   148  func EvalEq(q, h []fr.Element) fr.Element {
   149  	var res, nxt, one, sum fr.Element
   150  	one.SetOne()
   151  	for i := 0; i < len(q); i++ {
   152  		nxt.Mul(&q[i], &h[i]) // nxt <- qᵢ * hᵢ
   153  		nxt.Double(&nxt)      // nxt <- 2 * qᵢ * hᵢ
   154  		nxt.Add(&nxt, &one)   // nxt <- 1 + 2 * qᵢ * hᵢ
   155  		sum.Add(&q[i], &h[i]) // sum <- qᵢ + hᵢ	TODO: Why not subtract one by one from nxt? More parallel?
   156  
   157  		if i == 0 {
   158  			res.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ
   159  		} else {
   160  			nxt.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ
   161  			res.Mul(&res, &nxt) // res <- res * nxt
   162  		}
   163  	}
   164  	return res
   165  }
   166  
   167  // Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0]
   168  func (m *MultiLin) Eq(q []fr.Element) {
   169  	n := len(q)
   170  
   171  	if len(*m) != 1<<n {
   172  		panic("destination must have size 2 raised to the size of source")
   173  	}
   174  
   175  	//At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁)
   176  	for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁
   177  		// go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ
   178  		for j := 0; j < (1 << i); j++ {
   179  			j0 := j << (n - i)                 // bᵢ₊₁ = 0
   180  			j1 := j0 + 1<<(n-1-i)              // bᵢ₊₁ = 1
   181  			(*m)[j1].Mul(&q[i], &(*m)[j0])     // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁
   182  			(*m)[j0].Sub(&(*m)[j0], &(*m)[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁)
   183  		}
   184  	}
   185  }
   186  
   187  func (m MultiLin) NumVars() int {
   188  	return bits.TrailingZeros(uint(len(m)))
   189  }
   190  
   191  func init() {
   192  	//TODO: Check for whether already computed in the Getter or this?
   193  	lagrangeBasis = make([][]Polynomial, maxLagrangeDomainSize+1)
   194  
   195  	//size = 0: Cannot extrapolate with no data points
   196  
   197  	//size = 1: Constant polynomial
   198  	lagrangeBasis[1] = []Polynomial{make(Polynomial, 1)}
   199  	lagrangeBasis[1][0][0].SetOne()
   200  
   201  	//for size ≥ 2, the function works
   202  	for size := uint8(2); size <= maxLagrangeDomainSize; size++ {
   203  		lagrangeBasis[size] = computeLagrangeBasis(size)
   204  	}
   205  }
   206  
   207  func getLagrangeBasis(domainSize int) []Polynomial {
   208  	//TODO: Precompute everything at init or this?
   209  	/*if lagrangeBasis[domainSize] == nil {
   210  		lagrangeBasis[domainSize] = computeLagrangeBasis(domainSize)
   211  	}*/
   212  	return lagrangeBasis[domainSize]
   213  }
   214  
   215  const maxLagrangeDomainSize uint8 = 12
   216  
   217  var lagrangeBasis [][]Polynomial
   218  
   219  // computeLagrangeBasis precomputes in explicit coefficient form for each 0 ≤ l < domainSize the polynomial
   220  // pₗ := X (X-1) ... (X-l-1) (X-l+1) ... (X - domainSize + 1) / ( l (l-1) ... 2 (-1) ... (l - domainSize +1) )
   221  // Note that pₗ(l) = 1 and pₗ(n) = 0 if 0 ≤ l < domainSize, n ≠ l
   222  func computeLagrangeBasis(domainSize uint8) []Polynomial {
   223  
   224  	constTerms := make([]fr.Element, domainSize)
   225  	for i := uint8(0); i < domainSize; i++ {
   226  		constTerms[i].SetInt64(-int64(i))
   227  	}
   228  
   229  	res := make([]Polynomial, domainSize)
   230  	multScratch := make(Polynomial, domainSize-1)
   231  
   232  	// compute pₗ
   233  	for l := uint8(0); l < domainSize; l++ {
   234  
   235  		// TODO: Optimize this with some trees? O(log(domainSize)) polynomial mults instead of O(domainSize)? Then again it would be fewer big poly mults vs many small poly mults
   236  		d := uint8(0) //d is the current degree of res
   237  		for i := uint8(0); i < domainSize; i++ {
   238  			if i == l {
   239  				continue
   240  			}
   241  			if d == 0 {
   242  				res[l] = make(Polynomial, domainSize)
   243  				res[l][domainSize-2] = constTerms[i]
   244  				res[l][domainSize-1].SetOne()
   245  			} else {
   246  				current := res[l][domainSize-d-2:]
   247  				timesConst := multScratch[domainSize-d-2:]
   248  
   249  				timesConst.Scale(&constTerms[i], current[1:]) //TODO: Directly double and add since constTerms are tiny? (even less than 4 bits)
   250  				nonLeading := current[0 : d+1]
   251  
   252  				nonLeading.Add(nonLeading, timesConst)
   253  
   254  			}
   255  			d++
   256  		}
   257  
   258  	}
   259  
   260  	// We have pₗ(i≠l)=0. Now scale so that pₗ(l)=1
   261  	// Replace the constTerms with norms
   262  	for l := uint8(0); l < domainSize; l++ {
   263  		constTerms[l].Neg(&constTerms[l])
   264  		constTerms[l] = res[l].Eval(&constTerms[l])
   265  	}
   266  	constTerms = fr.BatchInvert(constTerms)
   267  	for l := uint8(0); l < domainSize; l++ {
   268  		res[l].ScaleInPlace(&constTerms[l])
   269  	}
   270  
   271  	return res
   272  }
   273  
   274  // InterpolateOnRange performs the interpolation of the given list of elements
   275  // On the range [0, 1,..., len(values) - 1]
   276  func InterpolateOnRange(values []fr.Element) Polynomial {
   277  	nEvals := len(values)
   278  	lagrange := getLagrangeBasis(nEvals)
   279  
   280  	var res Polynomial
   281  	res.Scale(&values[0], lagrange[0])
   282  
   283  	temp := make(Polynomial, nEvals)
   284  
   285  	for i := 1; i < nEvals; i++ {
   286  		temp.Scale(&values[i], lagrange[i])
   287  		res.Add(res, temp)
   288  	}
   289  
   290  	return res
   291  }