github.com/consensys/gnark-crypto@v0.14.0/internal/generator/polynomial/template/multilin.go.tmpl (about) 1 import ( 2 "{{.FieldPackagePath}}" 3 "math/bits" 4 "github.com/consensys/gnark-crypto/utils" 5 ) 6 7 // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial 8 // The variables are X₁ through Xₙ where n = log(len(.)) 9 // .[∑ᵢ 2ⁱ⁻¹ bₙ₋ᵢ] = the polynomial evaluated at (b₁, b₂, ..., bₙ) 10 // It is understood that any hypercube evaluation can be extrapolated to a multilinear polynomial 11 type MultiLin []{{.ElementType}} 12 13 // Fold is partial evaluation function k[X₁, X₂, ..., Xₙ] → k[X₂, ..., Xₙ] by setting X₁=r 14 func (m *MultiLin) Fold(r {{.ElementType}}) { 15 mid := len(*m) / 2 16 17 bottom, top := (*m)[:mid], (*m)[mid:] 18 19 var t {{.ElementType}} // no need to update the top part 20 21 // updating bookkeeping table 22 // knowing that the polynomial f ∈ (k[X₂, ..., Xₙ])[X₁] is linear, we would get f(r) = f(0) + r(f(1) - f(0)) 23 // the following loop computes the evaluations of f(r) accordingly: 24 // f(r, b₂, ..., bₙ) = f(0, b₂, ..., bₙ) + r(f(1, b₂, ..., bₙ) - f(0, b₂, ..., bₙ)) 25 for i := 0; i < mid; i++ { 26 // table[i] ← table[i] + r (table[i + mid] - table[i]) 27 t.Sub(&top[i], &bottom[i]) 28 t.Mul(&t, &r) 29 bottom[i].Add(&bottom[i], &t) 30 } 31 32 *m = (*m)[:mid] 33 } 34 35 func (m *MultiLin) FoldParallel(r {{.ElementType}}) utils.Task { 36 mid := len(*m) / 2 37 bottom, top := (*m)[:mid], (*m)[mid:] 38 39 *m = bottom 40 41 return func(start, end int) { 42 var t {{.ElementType}} // no need to update the top part 43 for i := start; i < end; i++ { 44 // table[i] ← table[i] + r (table[i + mid] - table[i]) 45 t.Sub(&top[i], &bottom[i]) 46 t.Mul(&t, &r) 47 bottom[i].Add(&bottom[i], &t) 48 } 49 } 50 } 51 52 func (m MultiLin) Sum() {{.ElementType}} { 53 s := m[0] 54 for i := 1; i < len(m); i++ { 55 s.Add(&s, &m[i]) 56 } 57 return s 58 } 59 60 func _clone(m MultiLin, p *Pool) MultiLin { 61 if p == nil { 62 return m.Clone() 63 } else { 64 return p.Clone(m) 65 } 66 } 67 68 func _dump(m MultiLin, p *Pool) { 69 if p != nil { 70 p.Dump(m) 71 } 72 } 73 74 // Evaluate extrapolate the value of the multilinear polynomial corresponding to m 75 // on the given coordinates 76 func (m MultiLin) Evaluate(coordinates []{{.ElementType}}, p *Pool) {{.ElementType}} { 77 // Folding is a mutating operation 78 bkCopy := _clone(m, p) 79 80 // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) 81 for _, r := range coordinates { 82 bkCopy.Fold(r) 83 } 84 85 result := bkCopy[0] 86 87 _dump(bkCopy, p) 88 return result 89 } 90 91 // Clone creates a deep copy of a bookkeeping table. 92 // Both multilinear interpolation and sumcheck require folding an underlying 93 // array, but folding changes the array. To do both one requires a deep copy 94 // of the bookkeeping table. 95 func (m MultiLin) Clone() MultiLin { 96 res := make(MultiLin, len(m)) 97 copy(res, m) 98 return res 99 } 100 101 // Add two bookKeepingTables 102 func (m *MultiLin) Add(left, right MultiLin) { 103 size := len(left) 104 // Check that left and right have the same size 105 if len(right) != size || len(*m) != size{ 106 panic("left, right and destination must have the right size") 107 } 108 109 // Add elementwise 110 for i := 0; i < size; i++ { 111 (*m)[i].Add(&left[i], &right[i]) 112 } 113 } 114 115 116 // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) 117 // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates 118 // _________________ 119 // | | | 120 // | 0 | 1 | 121 // |_______|_______| 122 // y | | | 123 // | 1 | 0 | 124 // |_______|_______| 125 // 126 // x 127 // In other words the polynomial evaluated here is the multilinear extrapolation of 128 // one that evaluates to q' == h' for vectors q', h' of binary values 129 func EvalEq(q, h []{{.ElementType}}) {{.ElementType}} { 130 var res, nxt, one, sum {{.ElementType}} 131 one.SetOne() 132 for i := 0; i < len(q); i++ { 133 nxt.Mul(&q[i], &h[i]) // nxt <- qᵢ * hᵢ 134 nxt.Double(&nxt) // nxt <- 2 * qᵢ * hᵢ 135 nxt.Add(&nxt, &one) // nxt <- 1 + 2 * qᵢ * hᵢ 136 sum.Add(&q[i], &h[i]) // sum <- qᵢ + hᵢ TODO: Why not subtract one by one from nxt? More parallel? 137 138 if i == 0 { 139 res.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ 140 } else { 141 nxt.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ 142 res.Mul(&res, &nxt) // res <- res * nxt 143 } 144 } 145 return res 146 } 147 148 // Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0] 149 func (m *MultiLin) Eq(q []{{.ElementType}}) { 150 n := len(q) 151 152 if len(*m) != 1 << n { 153 panic("destination must have size 2 raised to the size of source") 154 } 155 156 //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) 157 for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ 158 // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ 159 for j := 0; j < (1 << i); j++ { 160 j0 := j << (n - i) // bᵢ₊₁ = 0 161 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 162 (*m)[j1].Mul(&q[i], &(*m)[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ 163 (*m)[j0].Sub(&(*m)[j0], &(*m)[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) 164 } 165 } 166 } 167 168 func (m MultiLin) NumVars() int { 169 return bits.TrailingZeros(uint(len(m))) 170 } 171 172 func init() { 173 //TODO: Check for whether already computed in the Getter or this? 174 lagrangeBasis = make([][]Polynomial, maxLagrangeDomainSize+1) 175 176 //size = 0: Cannot extrapolate with no data points 177 178 //size = 1: Constant polynomial 179 lagrangeBasis[1] = []Polynomial{make(Polynomial, 1)} 180 lagrangeBasis[1][0][0].SetOne() 181 182 //for size ≥ 2, the function works 183 for size := uint8(2); size <= maxLagrangeDomainSize; size++ { 184 lagrangeBasis[size] = computeLagrangeBasis(size) 185 } 186 } 187 188 func getLagrangeBasis(domainSize int) []Polynomial { 189 //TODO: Precompute everything at init or this? 190 /*if lagrangeBasis[domainSize] == nil { 191 lagrangeBasis[domainSize] = computeLagrangeBasis(domainSize) 192 }*/ 193 return lagrangeBasis[domainSize] 194 } 195 196 const maxLagrangeDomainSize uint8 = 12 197 198 var lagrangeBasis [][]Polynomial 199 200 // computeLagrangeBasis precomputes in explicit coefficient form for each 0 ≤ l < domainSize the polynomial 201 // pₗ := X (X-1) ... (X-l-1) (X-l+1) ... (X - domainSize + 1) / ( l (l-1) ... 2 (-1) ... (l - domainSize +1) ) 202 // Note that pₗ(l) = 1 and pₗ(n) = 0 if 0 ≤ l < domainSize, n ≠ l 203 func computeLagrangeBasis(domainSize uint8) []Polynomial { 204 205 constTerms := make([]{{.ElementType}}, domainSize) 206 for i := uint8(0); i < domainSize; i++ { 207 constTerms[i].SetInt64(-int64(i)) 208 } 209 210 res := make([]Polynomial, domainSize) 211 multScratch := make(Polynomial, domainSize-1) 212 213 // compute pₗ 214 for l := uint8(0); l < domainSize; l++ { 215 216 // TODO: Optimize this with some trees? O(log(domainSize)) polynomial mults instead of O(domainSize)? Then again it would be fewer big poly mults vs many small poly mults 217 d := uint8(0) //d is the current degree of res 218 for i := uint8(0); i < domainSize; i++ { 219 if i == l { 220 continue 221 } 222 if d == 0 { 223 res[l] = make(Polynomial, domainSize) 224 res[l][domainSize-2] = constTerms[i] 225 res[l][domainSize-1].SetOne() 226 } else { 227 current := res[l][domainSize-d-2:] 228 timesConst := multScratch[domainSize-d-2:] 229 230 timesConst.Scale(&constTerms[i], current[1:]) //TODO: Directly double and add since constTerms are tiny? (even less than 4 bits) 231 nonLeading := current[0 : d+1] 232 233 nonLeading.Add(nonLeading, timesConst) 234 235 } 236 d++ 237 } 238 239 } 240 241 // We have pₗ(i≠l)=0. Now scale so that pₗ(l)=1 242 // Replace the constTerms with norms 243 for l := uint8(0); l < domainSize; l++ { 244 constTerms[l].Neg(&constTerms[l]) 245 constTerms[l] = res[l].Eval(&constTerms[l]) 246 } 247 constTerms = {{.FieldPackageName}}.BatchInvert(constTerms) 248 for l := uint8(0); l < domainSize; l++ { 249 res[l].ScaleInPlace(&constTerms[l]) 250 } 251 252 return res 253 } 254 255 // InterpolateOnRange performs the interpolation of the given list of elements 256 // On the range [0, 1,..., len(values) - 1] 257 func InterpolateOnRange(values []{{.ElementType}}) Polynomial { 258 nEvals := len(values) 259 lagrange := getLagrangeBasis(nEvals) 260 261 var res Polynomial 262 res.Scale(&values[0], lagrange[0]) 263 264 temp := make(Polynomial, nEvals) 265 266 for i := 1; i < nEvals; i++ { 267 temp.Scale(&values[i], lagrange[i]) 268 res.Add(res, temp) 269 } 270 271 return res 272 }