github.com/consensys/gnark-crypto@v0.14.0/field/generator/asm/amd64/element_mul.go (about)

     1  // Copyright 2020 ConsenSys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package amd64
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/consensys/bavard/amd64"
    21  )
    22  
    23  // MulADX uses AX, DX and BP
    24  // sets x * y into t, without modular reduction
    25  // x() will have more accesses than y()
    26  // (caller should store x in registers, if possible)
    27  // if no (tmp) register is available, this uses one PUSH/POP on the stack in the hot loop.
    28  func (f *FFAmd64) MulADX(registers *amd64.Registers, x, y func(int) string, t []amd64.Register) []amd64.Register {
    29  	// registers
    30  	var tr amd64.Register // temporary register
    31  	A := amd64.BP
    32  
    33  	hasFreeRegister := registers.Available() > 0
    34  	if hasFreeRegister {
    35  		tr = registers.Pop()
    36  	} else {
    37  		tr = A
    38  	}
    39  
    40  	f.LabelRegisters("A", A)
    41  	f.LabelRegisters("t", t...)
    42  
    43  	for i := 0; i < f.NbWords; i++ {
    44  		f.Comment("clear the flags")
    45  		f.XORQ(amd64.AX, amd64.AX)
    46  
    47  		f.MOVQ(y(i), amd64.DX)
    48  
    49  		// for j=0 to N-1
    50  		//    (A,t[j])  := t[j] + x[j]*y[i] + A
    51  		if i == 0 {
    52  			for j := 0; j < f.NbWords; j++ {
    53  				f.Comment(fmt.Sprintf("(A,t[%[1]d])  := x[%[1]d]*y[%[2]d] + A", j, i))
    54  
    55  				if j == 0 && f.NbWords == 1 {
    56  					f.MULXQ(x(j), t[j], A)
    57  				} else if j == 0 {
    58  					f.MULXQ(x(j), t[j], t[j+1])
    59  				} else {
    60  					highBits := A
    61  					if j != f.NbWordsLastIndex {
    62  						highBits = t[j+1]
    63  					}
    64  					f.MULXQ(x(j), amd64.AX, highBits)
    65  					f.ADOXQ(amd64.AX, t[j])
    66  				}
    67  			}
    68  		} else {
    69  			for j := 0; j < f.NbWords; j++ {
    70  				f.Comment(fmt.Sprintf("(A,t[%[1]d])  := t[%[1]d] + x[%[1]d]*y[%[2]d] + A", j, i))
    71  
    72  				if j != 0 {
    73  					f.ADCXQ(A, t[j])
    74  				}
    75  				f.MULXQ(x(j), amd64.AX, A)
    76  				f.ADOXQ(amd64.AX, t[j])
    77  			}
    78  		}
    79  
    80  		f.Comment("A += carries from ADCXQ and ADOXQ")
    81  		f.MOVQ(0, amd64.AX)
    82  		if i != 0 {
    83  			f.ADCXQ(amd64.AX, A)
    84  		}
    85  		f.ADOXQ(amd64.AX, A)
    86  
    87  		if !hasFreeRegister {
    88  			f.PUSHQ(A)
    89  		}
    90  
    91  		// m := t[0]*q'[0] mod W
    92  		f.Comment("m := t[0]*q'[0] mod W")
    93  		m := amd64.DX
    94  		// f.MOVQ(t[0], m)
    95  		// f.MULXQ(f.qInv0(), m, amd64.AX)
    96  		f.MOVQ(f.qInv0(), m)
    97  		f.IMULQ(t[0], m)
    98  
    99  		// clear the carry flags
   100  		f.Comment("clear the flags")
   101  		f.XORQ(amd64.AX, amd64.AX)
   102  
   103  		// C,_ := t[0] + m*q[0]
   104  		f.Comment("C,_ := t[0] + m*q[0]")
   105  
   106  		f.MULXQ(f.qAt(0), amd64.AX, tr)
   107  		f.ADCXQ(t[0], amd64.AX)
   108  		f.MOVQ(tr, t[0])
   109  
   110  		if !hasFreeRegister {
   111  			f.POPQ(A)
   112  		}
   113  		// for j=1 to N-1
   114  		//    (C,t[j-1]) := t[j] + m*q[j] + C
   115  		for j := 1; j < f.NbWords; j++ {
   116  			f.Comment(fmt.Sprintf("(C,t[%[1]d]) := t[%[2]d] + m*q[%[2]d] + C", j-1, j))
   117  			f.ADCXQ(t[j], t[j-1])
   118  			f.MULXQ(f.qAt(j), amd64.AX, t[j])
   119  			f.ADOXQ(amd64.AX, t[j-1])
   120  		}
   121  
   122  		f.Comment(fmt.Sprintf("t[%d] = C + A", f.NbWordsLastIndex))
   123  		f.MOVQ(0, amd64.AX)
   124  		f.ADCXQ(amd64.AX, t[f.NbWordsLastIndex])
   125  		f.ADOXQ(A, t[f.NbWordsLastIndex])
   126  
   127  	}
   128  
   129  	if hasFreeRegister {
   130  		registers.Push(tr)
   131  	}
   132  
   133  	return t
   134  }
   135  
   136  func (f *FFAmd64) generateMul(forceADX bool) {
   137  	f.Comment("mul(res, x, y *Element)")
   138  
   139  	const argSize = 3 * 8
   140  	minStackSize := argSize
   141  	if forceADX {
   142  		minStackSize = 0
   143  	}
   144  	stackSize := f.StackSize(f.NbWords*2, 2, minStackSize)
   145  	reserved := []amd64.Register{amd64.DX, amd64.AX}
   146  	if f.NbWords <= 5 {
   147  		// when dynamic linking, R15 is clobbered by a global variable access
   148  		// this is a temporary workaround --> don't use R15 when we can avoid it.
   149  		// see https://github.com/ConsenSys/gnark-crypto/issues/113
   150  		reserved = append(reserved, amd64.R15)
   151  	}
   152  	registers := f.FnHeader("mul", stackSize, argSize, reserved...)
   153  	defer f.AssertCleanStack(stackSize, minStackSize)
   154  
   155  	f.WriteLn(fmt.Sprintf(`
   156  	// the algorithm is described in the %s.Mul declaration (.go)
   157  	// however, to benefit from the ADCX and ADOX carry chains
   158  	// we split the inner loops in 2:
   159  	// for i=0 to N-1
   160  	// 		for j=0 to N-1
   161  	// 		    (A,t[j])  := t[j] + x[j]*y[i] + A
   162  	// 		m := t[0]*q'[0] mod W
   163  	// 		C,_ := t[0] + m*q[0]
   164  	// 		for j=1 to N-1
   165  	// 		    (C,t[j-1]) := t[j] + m*q[j] + C
   166  	// 		t[N-1] = C + A
   167  	`, f.ElementName))
   168  	if stackSize > 0 {
   169  		f.WriteLn("NO_LOCAL_POINTERS")
   170  	}
   171  
   172  	noAdx := f.NewLabel()
   173  
   174  	if !forceADX {
   175  		// check ADX instruction support
   176  		f.CMPB("·supportAdx(SB)", 1)
   177  		f.JNE(noAdx)
   178  	}
   179  
   180  	{
   181  		// we need to access x and y words, per index
   182  		var xat, yat func(int) string
   183  		var gc func()
   184  
   185  		// we need NbWords registers for t, plus optionally one for tmp register in mulADX if we want to avoid PUSH/POP
   186  		nbRegisters := registers.Available()
   187  		if nbRegisters < f.NbWords {
   188  			panic("not enough registers, not supported.")
   189  		}
   190  
   191  		t := registers.PopN(f.NbWords)
   192  		nbRegisters = registers.Available()
   193  		switch nbRegisters {
   194  		case 0:
   195  			// y is access through use of AX/DX
   196  			yat = func(i int) string {
   197  				y := amd64.AX
   198  				f.MOVQ("y+16(FP)", y)
   199  				return y.At(i)
   200  			}
   201  
   202  			// we move x on the stack.
   203  			f.MOVQ("x+8(FP)", amd64.AX)
   204  			_x := f.PopN(&registers, true)
   205  			f.LabelRegisters("x", _x...)
   206  			f.Mov(amd64.AX, t)
   207  			f.Mov(t, _x)
   208  			xat = func(i int) string {
   209  				return string(_x[i])
   210  			}
   211  			gc = func() {
   212  				f.Push(&registers, _x...)
   213  			}
   214  		case 1:
   215  			// y is access through use of AX/DX
   216  			yat = func(i int) string {
   217  				y := amd64.AX
   218  				f.MOVQ("y+16(FP)", y)
   219  				return y.At(i)
   220  			}
   221  			// x uses the register
   222  			x := registers.Pop()
   223  			f.MOVQ("x+8(FP)", x)
   224  			xat = func(i int) string {
   225  				return x.At(i)
   226  			}
   227  			gc = func() {
   228  				registers.Push(x)
   229  			}
   230  		case 2, 3:
   231  			// x, y uses registers
   232  			x := registers.Pop()
   233  			y := registers.Pop()
   234  
   235  			f.MOVQ("x+8(FP)", x)
   236  			f.MOVQ("y+16(FP)", y)
   237  
   238  			xat = func(i int) string {
   239  				return x.At(i)
   240  			}
   241  
   242  			yat = func(i int) string {
   243  				return y.At(i)
   244  			}
   245  			gc = func() {
   246  				registers.Push(x, y)
   247  			}
   248  		default:
   249  			// we have a least 4 registers.
   250  			// 1 for tmp.
   251  			nbRegisters--
   252  			// 1 for y
   253  			nbRegisters--
   254  			var y amd64.Register
   255  
   256  			if nbRegisters >= f.NbWords {
   257  				// we store x fully in registers
   258  				x := registers.Pop()
   259  				f.MOVQ("x+8(FP)", x)
   260  				_x := registers.PopN(f.NbWords)
   261  				f.LabelRegisters("x", _x...)
   262  				f.Mov(x, _x)
   263  
   264  				xat = func(i int) string {
   265  					return string(_x[i])
   266  				}
   267  				registers.Push(x)
   268  				gc = func() {
   269  					registers.Push(y)
   270  					registers.Push(_x...)
   271  				}
   272  			} else {
   273  				// we take at least 1 register for x addr
   274  				nbRegisters--
   275  				x := registers.Pop()
   276  				y = registers.Pop() // temporary lock 1 for y
   277  				f.MOVQ("x+8(FP)", x)
   278  
   279  				// and use the rest for x0...xn
   280  				_x := registers.PopN(nbRegisters)
   281  				f.LabelRegisters("x", _x...)
   282  				for i := 0; i < len(_x); i++ {
   283  					f.MOVQ(x.At(i), _x[i])
   284  				}
   285  				xat = func(i int) string {
   286  					if i < len(_x) {
   287  						return string(_x[i])
   288  					}
   289  					return x.At(i)
   290  				}
   291  				registers.Push(y)
   292  
   293  				gc = func() {
   294  					registers.Push(x, y)
   295  					registers.Push(_x...)
   296  				}
   297  
   298  			}
   299  			y = registers.Pop()
   300  
   301  			f.MOVQ("y+16(FP)", y)
   302  			yat = func(i int) string {
   303  				return y.At(i)
   304  			}
   305  
   306  		}
   307  
   308  		f.MulADX(&registers, xat, yat, t)
   309  		gc()
   310  
   311  		// ---------------------------------------------------------------------------------------------
   312  		// reduce
   313  		f.Reduce(&registers, t)
   314  
   315  		f.MOVQ("res+0(FP)", amd64.AX)
   316  		f.Mov(t, amd64.AX)
   317  		f.RET()
   318  	}
   319  
   320  	// ---------------------------------------------------------------------------------------------
   321  	// no MULX, ADX instructions
   322  	if !forceADX {
   323  		f.LABEL(noAdx)
   324  
   325  		f.MOVQ("res+0(FP)", amd64.AX)
   326  		f.MOVQ(amd64.AX, "(SP)")
   327  		f.MOVQ("x+8(FP)", amd64.AX)
   328  		f.MOVQ(amd64.AX, "8(SP)")
   329  		f.MOVQ("y+16(FP)", amd64.AX)
   330  		f.MOVQ(amd64.AX, "16(SP)")
   331  		f.WriteLn("CALL ·_mulGeneric(SB)")
   332  		f.RET()
   333  
   334  	}
   335  }