github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/vector/compare/amd64/main.go

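// This program generates amd64 assembly (via avo) for many variants of the
// scalar axpy kernel, ys[i*incy] += alpha * xs[i*incx], differing in
// loop-termination strategy (end pointer, up-counter, down-counter),
// addressing mode (pointer bumping vs. indexing), loop-body alignment, and
// unroll factor, so the variants can be benchmarked against each other.
//
// A typical invocation would look something like the following (the -out and
// -stubs flags come from avo's build package; the exact file names are
// illustrative, except that main expects the stubs in axpy_amd64.go when
// -testhelp is set):
//
//	go run . -out axpy_amd64.s -stubs axpy_amd64.go -testhelp axpy_decls.go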
package main

import (
	"bytes"
	"flag"
	"fmt"
	"go/format"
	"math/bits"
	"os"
	"regexp"
	"strings"

	. "github.com/mmcloughlin/avo/build"
	"github.com/mmcloughlin/avo/ir"
	. "github.com/mmcloughlin/avo/operand"
	"github.com/mmcloughlin/avo/reg"
)

var testhelp = flag.String("testhelp", "", "output path for the generated test helper file")

func main() {
	const variants = 6
	alignments := []int{0, 8, 9, 10, 11, 12, 13, 14, 15, 16}
	// const variants = 1
	// alignments := []int{0}

	emitAlignments := func(emit func(variant, align int)) {
		for _, align := range alignments {
			for v := 0; v < variants; v++ {
				emit(v, align)
			}
		}
	}

	emitAlignments(AxpyPointer)
	emitAlignments(AxpyPointerLoop)
	emitAlignments(AxpyPointerLoopX)
	emitAlignments(AxpyUnsafeX)
	emitAlignments(func(variant, align int) { AxpyUnsafeXUnroll(variant, align, 4) })
	emitAlignments(func(variant, align int) { AxpyUnsafeXUnroll(variant, align, 8) })
	emitAlignments(func(variant, align int) { AxpyUnsafeXInterleaveUnroll(variant, align, 4) })
	emitAlignments(func(variant, align int) { AxpyUnsafeXInterleaveUnroll(variant, align, 8) })
	emitAlignments(func(variant, align int) { AxpyPointerLoopXUnroll(variant, align, 4) })
	emitAlignments(func(variant, align int) { AxpyPointerLoopXUnroll(variant, align, 8) })
	emitAlignments(func(variant, align int) { AxpyPointerLoopXInterleaveUnroll(variant, align, 4) })
	emitAlignments(func(variant, align int) { AxpyPointerLoopXInterleaveUnroll(variant, align, 8) })

	Generate()

	if *testhelp != "" {
		generateTestHelp("axpy_amd64.go", *testhelp)
	}
}

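// generateTestHelp scans the avo-generated stub file for function
// declarations and writes a companion file that collects every generated
// variant into a slice, so tests and benchmarks can iterate over all of
// them. For a stub such as AmdAxpyPointer_V0A0 the emitted entry is:
//
//	{name: "AxpyPointer_V0A0", fn: AmdAxpyPointer_V0A0},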
func generateTestHelp(stubs, out string) {
	data, err := os.ReadFile(stubs)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	fns := []string{}

	rx := regexp.MustCompile(`func ([a-zA-Z0-9_]+)\(`)
	for _, match := range rx.FindAllStringSubmatch(string(data), -1) {
		fns = append(fns, match[1])
	}

	var b bytes.Buffer
	pf := func(format string, args ...interface{}) {
		fmt.Fprintf(&b, format, args...)
	}

	pf("// Code generated by command. DO NOT EDIT.\n\n")
	pf("package compare\n\n")
	pf("type amdAxpyDecl struct {\n")
	pf("	name string\n")
	pf("	fn   func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)\n")
	pf("}\n\n")

	pf("var amdAxpyDecls = []amdAxpyDecl{\n")
	for _, fn := range fns {
		pf("	{name: %q, fn: %v},\n", strings.TrimPrefix(fn, "Amd"), fn)
	}
	pf("}\n")

	formatted, err := format.Source(b.Bytes())
	if err != nil {
		fmt.Fprintln(os.Stderr, b.String())
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	if err := os.WriteFile(out, formatted, 0o644); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

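// AxpyPointer emits an axpy variant that advances the xs and ys pointers
// directly and stops once xs reaches a precomputed end pointer
// (end = xs + 4*incx*n, computed in place in n's register). MISALIGN pads
// the instruction stream before the "loop" label so the loop body starts at
// a controlled misalignment.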
func AxpyPointer(variant, align int) {
	TEXT(fmt.Sprintf("AmdAxpyPointer_V%vA%v", variant, align), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)")

	alpha := Load(Param("alpha"), XMM())

	xs := Mem{Base: Load(Param("xs"), GP64())}
	incx := Load(Param("incx"), GP64())

	ys := Mem{Base: Load(Param("ys"), GP64())}
	incy := Load(Param("incy"), GP64())

	n := Load(Param("n"), GP64())

	end := n
	SHLQ(U8(0x2), end)
	IMULQ(incx, end)
	ADDQ(xs.Base, end)
	JMP(LabelRef("check_limit"))

	MISALIGN(align)
	Label("loop")
	{
		tmp := XMM()
		MOVSS(xs, tmp)
		MULSS(alpha, tmp)
		ADDSS(ys, tmp)
		MOVSS(tmp, ys)

		LEAQ(xs.Idx(incx, 4), xs.Base)
		LEAQ(ys.Idx(incy, 4), ys.Base)

		Label("check_limit")

		CMPQ(end, xs.Base)
		JHI(LabelRef("loop"))
	}

	RET()
}

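// AxpyPointerLoop is like AxpyPointer, but instead of comparing against an
// end pointer it counts iterations upward in a separate register and loops
// while counter < n.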
func AxpyPointerLoop(variant, align int) {
	TEXT(fmt.Sprintf("AmdAxpyPointerLoop_V%vA%v", variant, align), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)")

	alpha := Load(Param("alpha"), XMM())

	xs := Mem{Base: Load(Param("xs"), GP64())}
	incx := Load(Param("incx"), GP64())

	ys := Mem{Base: Load(Param("ys"), GP64())}
	incy := Load(Param("incy"), GP64())

	n := Load(Param("n"), GP64())
	counter := GP64()
	XORQ(counter, counter)

	JMP(LabelRef("check_limit"))

	MISALIGN(align)
	Label("loop")
	{
		tmp := XMM()
		MOVSS(xs, tmp)
		MULSS(alpha, tmp)
		ADDSS(ys, tmp)
		MOVSS(tmp, ys)

		INCQ(counter)

		LEAQ(xs.Idx(incx, 4), xs.Base)
		LEAQ(ys.Idx(incy, 4), ys.Base)

		Label("check_limit")

		CMPQ(n, counter)
		JHI(LabelRef("loop"))
	}

	RET()
}

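// AxpyPointerLoopX is like AxpyPointerLoop, but counts n itself down to zero
// instead of maintaining a separate counter register.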
func AxpyPointerLoopX(variant, align int) {
	TEXT(fmt.Sprintf("AmdAxpyPointerLoopX_V%vA%v", variant, align), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)")

	alpha := Load(Param("alpha"), XMM())

	xs := Mem{Base: Load(Param("xs"), GP64())}
	incx := Load(Param("incx"), GP64())

	ys := Mem{Base: Load(Param("ys"), GP64())}
	incy := Load(Param("incy"), GP64())

	n := Load(Param("n"), GP64())

	JMP(LabelRef("check_limit"))

	MISALIGN(align)
	Label("loop")
	{
		tmp := XMM()
		MOVSS(xs, tmp)
		MULSS(alpha, tmp)
		ADDSS(ys, tmp)
		MOVSS(tmp, ys)

		DECQ(n)

		LEAQ(xs.Idx(incx, 4), xs.Base)
		LEAQ(ys.Idx(incy, 4), ys.Base)

		Label("check_limit")

		CMPQ(n, U8(0))
		JHI(LabelRef("loop"))
	}

	RET()
}

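// log2 returns the base-2 logarithm of v; AxpyPointerLoopXInterleaveUnroll
// uses it to turn the unroll-scaled stride multiplier into a SHLQ shift
// amount. It panics unless v is a positive power of two.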
func log2(v int) int {
	if v <= 0 || v&(v-1) != 0 {
		panic("not a power of two")
	}
	return bits.TrailingZeros(uint(v))
}

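// AxpyPointerLoopXUnroll is AxpyPointerLoopX with the loop body repeated
// `unroll` times per iteration; a scalar tail loop handles the remaining
// n % unroll elements.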
func AxpyPointerLoopXUnroll(variant, align, unroll int) {
	TEXT(fmt.Sprintf("AmdAxpyPointerLoopX_V%vA%vU%v", variant, align, unroll), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)")

	alpha := Load(Param("alpha"), XMM())

	xs := Mem{Base: Load(Param("xs"), GP64())}
	incx := Load(Param("incx"), GP64())

	ys := Mem{Base: Load(Param("ys"), GP64())}
	incy := Load(Param("incy"), GP64())

	n := Load(Param("n"), GP64())

	JMP(LabelRef("check_limit_unroll"))

	MISALIGN(align)
	Label("loop_unroll")
	{
		for u := 0; u < unroll; u++ {
			tmp := XMM()

			MOVSS(xs, tmp)
			MULSS(alpha, tmp)
			ADDSS(ys, tmp)
			MOVSS(tmp, ys)

			LEAQ(xs.Idx(incx, 4), xs.Base)
			LEAQ(ys.Idx(incy, 4), ys.Base)
		}

		SUBQ(Imm(uint64(unroll)), n)

		Label("check_limit_unroll")

		CMPQ(n, U8(uint8(unroll)))
		JHS(LabelRef("loop_unroll"))
	}

	JMP(LabelRef("check_limit"))
	Label("loop")
	{
		tmp := XMM()
		MOVSS(xs, tmp)
		MULSS(alpha, tmp)
		ADDSS(ys, tmp)
		MOVSS(tmp, ys)

		DECQ(n)

		LEAQ(xs.Idx(incx, 4), xs.Base)
		LEAQ(ys.Idx(incy, 4), ys.Base)

		Label("check_limit")

		CMPQ(n, U8(0))
		JHI(LabelRef("loop"))
	}

	RET()
}

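// AxpyPointerLoopXInterleaveUnroll processes `unroll` elements per iteration
// with the loads, multiplies, and add/stores grouped by operation, so
// independent work from different elements can overlap in the pipeline.
// Note that incxunroll and incyunroll (the strides scaled by the unroll
// factor) are computed up front but never used below: the loop still
// advances the pointers one element at a time.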
func AxpyPointerLoopXInterleaveUnroll(variant, align, unroll int) {
	TEXT(fmt.Sprintf("AmdAxpyPointerLoopXInterleave_V%vA%vU%v", variant, align, unroll), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)")

	alpha := Load(Param("alpha"), XMM())

	xs := Mem{Base: Load(Param("xs"), GP64())}
	incx := Load(Param("incx"), GP64())
	incxunroll := GP64()
	MOVQ(incx, incxunroll)
	SHLQ(U8(uint8(log2(4*unroll))), incxunroll)

	ys := Mem{Base: Load(Param("ys"), GP64())}
	incy := Load(Param("incy"), GP64())
	incyunroll := GP64()
	MOVQ(incy, incyunroll)
	SHLQ(U8(uint8(log2(4*unroll))), incyunroll)

	n := Load(Param("n"), GP64())

	JMP(LabelRef("check_limit_unroll"))

	MISALIGN(align)
	Label("loop_unroll")
	{
		tmp := make([]reg.VecVirtual, unroll)

		for u := range tmp {
			tmp[u] = XMM()
		}

		for u := 0; u < unroll; u++ {
			MOVSS(xs, tmp[u])
			LEAQ(xs.Idx(incx, 4), xs.Base)
		}
		for u := 0; u < unroll; u++ {
			MULSS(alpha, tmp[u])
		}
		for u := 0; u < unroll; u++ {
			ADDSS(ys, tmp[u])
			MOVSS(tmp[u], ys)
			LEAQ(ys.Idx(incy, 4), ys.Base)
		}

		SUBQ(Imm(uint64(unroll)), n)

		Label("check_limit_unroll")

		CMPQ(n, U8(uint8(unroll)))
		JHS(LabelRef("loop_unroll"))
	}

	JMP(LabelRef("check_limit"))
	Label("loop")
	{
		tmp := XMM()
		MOVSS(xs, tmp)
		MULSS(alpha, tmp)
		ADDSS(ys, tmp)
		MOVSS(tmp, ys)

		DECQ(n)

		LEAQ(xs.Idx(incx, 4), xs.Base)
		LEAQ(ys.Idx(incy, 4), ys.Base)

		Label("check_limit")

		CMPQ(n, U8(0))
		JHI(LabelRef("loop"))
	}

	RET()
}

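// AxpyUnsafeX keeps the base pointers fixed and instead advances separate
// element indices (xi, yi) by incx/incy, addressing memory as base + 4*index
// while counting n down to zero.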
func AxpyUnsafeX(variant, align int) {
	TEXT(fmt.Sprintf("AmdAxpyUnsafeX_V%vA%v", variant, align), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)")

	alpha := Load(Param("alpha"), XMM())

	xs := Mem{Base: Load(Param("xs"), GP64())}
	incx := Load(Param("incx"), GP64())

	ys := Mem{Base: Load(Param("ys"), GP64())}
	incy := Load(Param("incy"), GP64())

	n := Load(Param("n"), GP64())

	xi, yi := GP64(), GP64()
	XORQ(xi, xi)
	XORQ(yi, yi)

	JMP(LabelRef("check_limit"))

	MISALIGN(align)
	Label("loop")
	{
		tmp := XMM()
		MOVSS(xs.Idx(xi, 4), tmp)
		MULSS(alpha, tmp)
		ADDSS(ys.Idx(yi, 4), tmp)
		MOVSS(tmp, ys.Idx(yi, 4))

		DECQ(n)
		ADDQ(incx, xi)
		ADDQ(incy, yi)

		Label("check_limit")

		CMPQ(n, U8(0))
		JHI(LabelRef("loop"))
	}

	RET()
}

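// AxpyUnsafeXUnroll is AxpyUnsafeX with the body repeated `unroll` times.
// Unlike the other unrolled variants it guards the unrolled loop with JHI
// (n > unroll) rather than JHS (n >= unroll), so a final full block of
// `unroll` elements falls through to the scalar tail loop; the result is
// the same either way.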
func AxpyUnsafeXUnroll(variant, align, unroll int) {
	TEXT(fmt.Sprintf("AmdAxpyUnsafeX_V%vA%vR%v", variant, align, unroll), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)")

	alpha := Load(Param("alpha"), XMM())

	xs := Mem{Base: Load(Param("xs"), GP64())}
	incx := Load(Param("incx"), GP64())

	ys := Mem{Base: Load(Param("ys"), GP64())}
	incy := Load(Param("incy"), GP64())

	n := Load(Param("n"), GP64())

	xi, yi := GP64(), GP64()
	XORQ(xi, xi)
	XORQ(yi, yi)

	JMP(LabelRef("check_limit_unroll"))

	MISALIGN(align)
	Label("loop_unroll")
	{
		for u := 0; u < unroll; u++ {
			tmp := XMM()

			xat := Mem{Base: xs.Base, Index: xi, Scale: 4, Disp: 0}
			yat := Mem{Base: ys.Base, Index: yi, Scale: 4, Disp: 0}
			MOVSS(xat, tmp)
			MULSS(alpha, tmp)
			ADDSS(yat, tmp)
			MOVSS(tmp, yat)

			ADDQ(incx, xi)
			ADDQ(incy, yi)
		}

		SUBQ(Imm(uint64(unroll)), n)

		Label("check_limit_unroll")

		CMPQ(n, U8(uint8(unroll)))
		JHI(LabelRef("loop_unroll"))
	}

	JMP(LabelRef("check_limit"))
	Label("loop")
	{
		tmp := XMM()
		MOVSS(xs.Idx(xi, 4), tmp)
		MULSS(alpha, tmp)
		ADDSS(ys.Idx(yi, 4), tmp)
		MOVSS(tmp, ys.Idx(yi, 4))

		DECQ(n)
		ADDQ(incx, xi)
		ADDQ(incy, yi)

		Label("check_limit")

		CMPQ(n, U8(0))
		JHI(LabelRef("loop"))
	}

	RET()
}

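// AxpyUnsafeXInterleaveUnroll combines the index-based addressing of
// AxpyUnsafeX with the operation-grouped unrolling of
// AxpyPointerLoopXInterleaveUnroll.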
func AxpyUnsafeXInterleaveUnroll(variant, align, unroll int) {
	TEXT(fmt.Sprintf("AmdAxpyUnsafeXInterleave_V%vA%vR%v", variant, align, unroll), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)")

	alpha := Load(Param("alpha"), XMM())

	xs := Mem{Base: Load(Param("xs"), GP64())}
	incx := Load(Param("incx"), GP64())

	ys := Mem{Base: Load(Param("ys"), GP64())}
	incy := Load(Param("incy"), GP64())

	n := Load(Param("n"), GP64())

	xi, yi := GP64(), GP64()
	XORQ(xi, xi)
	XORQ(yi, yi)

	JMP(LabelRef("check_limit_unroll"))

	MISALIGN(align)
	Label("loop_unroll")
	{
		tmp := make([]reg.VecVirtual, unroll)
		for u := range tmp {
			tmp[u] = XMM()
		}

		for u := 0; u < unroll; u++ {
			MOVSS(xs.Idx(xi, 4), tmp[u])
			ADDQ(incx, xi)
		}
		for u := 0; u < unroll; u++ {
			MULSS(alpha, tmp[u])
		}
		for u := 0; u < unroll; u++ {
			ADDSS(ys.Idx(yi, 4), tmp[u])
			MOVSS(tmp[u], ys.Idx(yi, 4))
			ADDQ(incy, yi)
		}

		SUBQ(Imm(uint64(unroll)), n)

		Label("check_limit_unroll")

		CMPQ(n, U8(uint8(unroll)))
		JHS(LabelRef("loop_unroll"))
	}

	JMP(LabelRef("check_limit"))
	Label("loop")
	{
		tmp := XMM()
		MOVSS(xs.Idx(xi, 4), tmp)
		MULSS(alpha, tmp)
		ADDSS(ys.Idx(yi, 4), tmp)
		MOVSS(tmp, ys.Idx(yi, 4))

		DECQ(n)
		ADDQ(incx, xi)
		ADDQ(incy, yi)

		Label("check_limit")

		CMPQ(n, U8(0))
		JHI(LabelRef("loop"))
	}

	RET()
}

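// MISALIGN pads the instruction stream so the following label starts at a
// controlled offset: it aligns to the largest power of two p <= n (minimum
// 8) with PCALIGN, then emits n-p single-byte NOPs, leaving the label n-p
// bytes past a p-byte boundary. With the alignment values used in main
// (0 and 8 through 16), n-p is never negative.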
func MISALIGN(n int) {
	if n == 0 {
		return
	}

	nearestPowerOf2 := 8
	for n >= nearestPowerOf2*2 {
		nearestPowerOf2 *= 2
	}
	Instruction(&ir.Instruction{
		Opcode:   "PCALIGN",
		Operands: []Op{Imm(uint64(nearestPowerOf2))},
	})
	n -= nearestPowerOf2

	for i := 0; i < n; i++ {
		NOP()
	}
}