github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/vector/compare/amd64/main.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "flag" 6 "fmt" 7 "go/format" 8 "math/bits" 9 "os" 10 "regexp" 11 "strings" 12 13 . "github.com/mmcloughlin/avo/build" 14 "github.com/mmcloughlin/avo/ir" 15 . "github.com/mmcloughlin/avo/operand" 16 "github.com/mmcloughlin/avo/reg" 17 ) 18 19 var testhelp = flag.String("testhelp", "", "test helpers") 20 21 func main() { 22 const variants = 6 23 alignments := []int{0, 8, 9, 10, 11, 12, 13, 14, 15, 16} 24 // const variants = 1 25 // alignments := []int{0} 26 27 emitAlignments := func(emit func(variant, align int)) { 28 for _, align := range alignments { 29 for v := 0; v < variants; v++ { 30 emit(v, align) 31 } 32 } 33 } 34 35 emitAlignments(AxpyPointer) 36 emitAlignments(AxpyPointerLoop) 37 emitAlignments(AxpyPointerLoopX) 38 emitAlignments(AxpyUnsafeX) 39 emitAlignments(func(variant, align int) { AxpyUnsafeXUnroll(variant, align, 4) }) 40 emitAlignments(func(variant, align int) { AxpyUnsafeXUnroll(variant, align, 8) }) 41 emitAlignments(func(variant, align int) { AxpyUnsafeXInterleaveUnroll(variant, align, 4) }) 42 emitAlignments(func(variant, align int) { AxpyUnsafeXInterleaveUnroll(variant, align, 8) }) 43 emitAlignments(func(variant, align int) { AxpyPointerLoopXUnroll(variant, align, 4) }) 44 emitAlignments(func(variant, align int) { AxpyPointerLoopXUnroll(variant, align, 8) }) 45 emitAlignments(func(variant, align int) { AxpyPointerLoopXInterleaveUnroll(variant, align, 4) }) 46 emitAlignments(func(variant, align int) { AxpyPointerLoopXInterleaveUnroll(variant, align, 8) }) 47 48 Generate() 49 50 if *testhelp != "" { 51 generateTestHelp("axpy_amd64.go", *testhelp) 52 } 53 } 54 55 func generateTestHelp(stubs, out string) { 56 data, err := os.ReadFile(stubs) 57 if err != nil { 58 fmt.Fprintln(os.Stderr, err) 59 } 60 61 fns := []string{} 62 63 rx := regexp.MustCompile("func ([a-zA-Z0-9_]+)\\(") 64 for _, match := range rx.FindAllStringSubmatch(string(data), -1) { 65 fns = append(fns, match[1]) 66 } 67 68 var b bytes.Buffer 69 pf := func(format string, args ...interface{}) { 70 fmt.Fprintf(&b, format, args...) 71 } 72 73 pf("// Code generated by command. DO NOT EDIT.\n\n") 74 pf("package compare\n\n") 75 pf("type amdAxpyDecl struct {\n") 76 pf(" name string\n") 77 pf(" fn func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)\n") 78 pf("}\n\n") 79 80 pf("var amdAxpyDecls = []amdAxpyDecl{\n") 81 for _, fn := range fns { 82 pf(" {name: %q, fn: %v},\n", strings.TrimPrefix(fn, "Amd"), fn) 83 } 84 pf("}\n") 85 86 formatted, err := format.Source(b.Bytes()) 87 if err != nil { 88 fmt.Fprintln(os.Stderr, b.Bytes()) 89 fmt.Fprintln(os.Stderr, err) 90 os.Exit(1) 91 } 92 93 os.WriteFile(out, formatted, 0755) 94 } 95 96 func AxpyPointer(variant, align int) { 97 TEXT(fmt.Sprintf("AmdAxpyPointer_V%vA%v", variant, align), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)") 98 99 alpha := Load(Param("alpha"), XMM()) 100 101 xs := Mem{Base: Load(Param("xs"), GP64())} 102 incx := Load(Param("incx"), GP64()) 103 104 ys := Mem{Base: Load(Param("ys"), GP64())} 105 incy := Load(Param("incy"), GP64()) 106 107 n := Load(Param("n"), GP64()) 108 109 end := n 110 SHLQ(U8(0x2), end) 111 IMULQ(incx, end) 112 ADDQ(xs.Base, end) 113 JMP(LabelRef("check_limit")) 114 115 MISALIGN(align) 116 Label("loop") 117 { 118 tmp := XMM() 119 MOVSS(xs, tmp) 120 MULSS(alpha, tmp) 121 ADDSS(ys, tmp) 122 MOVSS(tmp, ys) 123 124 LEAQ(xs.Idx(incx, 4), xs.Base) 125 LEAQ(ys.Idx(incy, 4), ys.Base) 126 127 Label("check_limit") 128 129 CMPQ(end, xs.Base) 130 JHI(LabelRef("loop")) 131 } 132 133 RET() 134 } 135 136 func AxpyPointerLoop(variant, align int) { 137 TEXT(fmt.Sprintf("AmdAxpyPointerLoop_V%vA%v", variant, align), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)") 138 139 alpha := Load(Param("alpha"), XMM()) 140 141 xs := Mem{Base: Load(Param("xs"), GP64())} 142 incx := Load(Param("incx"), GP64()) 143 144 ys := Mem{Base: Load(Param("ys"), GP64())} 145 incy := Load(Param("incy"), GP64()) 146 147 n := Load(Param("n"), GP64()) 148 counter := GP64() 149 XORQ(counter, counter) 150 151 JMP(LabelRef("check_limit")) 152 153 MISALIGN(align) 154 Label("loop") 155 { 156 tmp := XMM() 157 MOVSS(xs, tmp) 158 MULSS(alpha, tmp) 159 ADDSS(ys, tmp) 160 MOVSS(tmp, ys) 161 162 INCQ(counter) 163 164 LEAQ(xs.Idx(incx, 4), xs.Base) 165 LEAQ(ys.Idx(incy, 4), ys.Base) 166 167 Label("check_limit") 168 169 CMPQ(n, counter) 170 JHI(LabelRef("loop")) 171 } 172 173 RET() 174 } 175 176 func AxpyPointerLoopX(variant, align int) { 177 TEXT(fmt.Sprintf("AmdAxpyPointerLoopX_V%vA%v", variant, align), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)") 178 179 alpha := Load(Param("alpha"), XMM()) 180 181 xs := Mem{Base: Load(Param("xs"), GP64())} 182 incx := Load(Param("incx"), GP64()) 183 184 ys := Mem{Base: Load(Param("ys"), GP64())} 185 incy := Load(Param("incy"), GP64()) 186 187 n := Load(Param("n"), GP64()) 188 189 JMP(LabelRef("check_limit")) 190 191 MISALIGN(align) 192 Label("loop") 193 { 194 tmp := XMM() 195 MOVSS(xs, tmp) 196 MULSS(alpha, tmp) 197 ADDSS(ys, tmp) 198 MOVSS(tmp, ys) 199 200 DECQ(n) 201 202 LEAQ(xs.Idx(incx, 4), xs.Base) 203 LEAQ(ys.Idx(incy, 4), ys.Base) 204 205 Label("check_limit") 206 207 CMPQ(n, U8(0)) 208 JHI(LabelRef("loop")) 209 } 210 211 RET() 212 } 213 214 func log2(v int) int { 215 if v&(v-1) != 0 { 216 panic("not a power of two") 217 } 218 return bits.TrailingZeros(uint(v)) 219 } 220 221 func AxpyPointerLoopXUnroll(variant, align, unroll int) { 222 TEXT(fmt.Sprintf("AmdAxpyPointerLoopX_V%vA%vU%v", variant, align, unroll), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)") 223 224 alpha := Load(Param("alpha"), XMM()) 225 226 xs := Mem{Base: Load(Param("xs"), GP64())} 227 incx := Load(Param("incx"), GP64()) 228 229 ys := Mem{Base: Load(Param("ys"), GP64())} 230 incy := Load(Param("incy"), GP64()) 231 232 n := Load(Param("n"), GP64()) 233 234 JMP(LabelRef("check_limit_unroll")) 235 236 MISALIGN(align) 237 Label("loop_unroll") 238 { 239 for u := 0; u < unroll; u++ { 240 tmp := XMM() 241 242 MOVSS(xs, tmp) 243 MULSS(alpha, tmp) 244 ADDSS(ys, tmp) 245 MOVSS(tmp, ys) 246 247 LEAQ(xs.Idx(incx, 4), xs.Base) 248 LEAQ(ys.Idx(incy, 4), ys.Base) 249 } 250 251 SUBQ(Imm(uint64(unroll)), n) 252 253 Label("check_limit_unroll") 254 255 CMPQ(n, U8(unroll)) 256 JHS(LabelRef("loop_unroll")) 257 } 258 259 JMP(LabelRef("check_limit")) 260 Label("loop") 261 { 262 tmp := XMM() 263 MOVSS(xs, tmp) 264 MULSS(alpha, tmp) 265 ADDSS(ys, tmp) 266 MOVSS(tmp, ys) 267 268 DECQ(n) 269 270 LEAQ(xs.Idx(incx, 4), xs.Base) 271 LEAQ(ys.Idx(incy, 4), ys.Base) 272 273 Label("check_limit") 274 275 CMPQ(n, U8(0)) 276 JHI(LabelRef("loop")) 277 } 278 279 RET() 280 } 281 282 func AxpyPointerLoopXInterleaveUnroll(variant, align, unroll int) { 283 TEXT(fmt.Sprintf("AmdAxpyPointerLoopXInterleave_V%vA%vU%v", variant, align, unroll), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)") 284 285 alpha := Load(Param("alpha"), XMM()) 286 287 xs := Mem{Base: Load(Param("xs"), GP64())} 288 incx := Load(Param("incx"), GP64()) 289 incxunroll := GP64() 290 MOVQ(incx, incxunroll) 291 SHLQ(U8(log2(4*unroll)), incxunroll) 292 293 ys := Mem{Base: Load(Param("ys"), GP64())} 294 incy := Load(Param("incy"), GP64()) 295 incyunroll := GP64() 296 MOVQ(incy, incyunroll) 297 SHLQ(U8(log2(4*unroll)), incyunroll) 298 299 n := Load(Param("n"), GP64()) 300 301 JMP(LabelRef("check_limit_unroll")) 302 303 MISALIGN(align) 304 Label("loop_unroll") 305 { 306 tmp := make([]reg.VecVirtual, unroll) 307 308 for u := range tmp { 309 tmp[u] = XMM() 310 } 311 312 for u := 0; u < unroll; u++ { 313 MOVSS(xs, tmp[u]) 314 LEAQ(xs.Idx(incx, 4), xs.Base) 315 } 316 for u := 0; u < unroll; u++ { 317 MULSS(alpha, tmp[u]) 318 } 319 for u := 0; u < unroll; u++ { 320 ADDSS(ys, tmp[u]) 321 MOVSS(tmp[u], ys) 322 LEAQ(ys.Idx(incy, 4), ys.Base) 323 } 324 325 SUBQ(Imm(uint64(unroll)), n) 326 327 Label("check_limit_unroll") 328 329 CMPQ(n, U8(unroll)) 330 JHS(LabelRef("loop_unroll")) 331 } 332 333 JMP(LabelRef("check_limit")) 334 Label("loop") 335 { 336 tmp := XMM() 337 MOVSS(xs, tmp) 338 MULSS(alpha, tmp) 339 ADDSS(ys, tmp) 340 MOVSS(tmp, ys) 341 342 DECQ(n) 343 344 LEAQ(xs.Idx(incx, 4), xs.Base) 345 LEAQ(ys.Idx(incy, 4), ys.Base) 346 347 Label("check_limit") 348 349 CMPQ(n, U8(0)) 350 JHI(LabelRef("loop")) 351 } 352 353 RET() 354 } 355 356 func AxpyUnsafeX(variant, align int) { 357 TEXT(fmt.Sprintf("AmdAxpyUnsafeX_V%vA%v", variant, align), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)") 358 359 alpha := Load(Param("alpha"), XMM()) 360 361 xs := Mem{Base: Load(Param("xs"), GP64())} 362 incx := Load(Param("incx"), GP64()) 363 364 ys := Mem{Base: Load(Param("ys"), GP64())} 365 incy := Load(Param("incy"), GP64()) 366 367 n := Load(Param("n"), GP64()) 368 369 xi, yi := GP64(), GP64() 370 XORQ(xi, xi) 371 XORQ(yi, yi) 372 373 JMP(LabelRef("check_limit")) 374 375 MISALIGN(align) 376 Label("loop") 377 { 378 tmp := XMM() 379 MOVSS(xs.Idx(xi, 4), tmp) 380 MULSS(alpha, tmp) 381 ADDSS(ys.Idx(yi, 4), tmp) 382 MOVSS(tmp, ys.Idx(yi, 4)) 383 384 DECQ(n) 385 ADDQ(incx, xi) 386 ADDQ(incy, yi) 387 388 Label("check_limit") 389 390 CMPQ(n, U8(0)) 391 JHI(LabelRef("loop")) 392 } 393 394 RET() 395 } 396 397 func AxpyUnsafeXUnroll(variant, align, unroll int) { 398 TEXT(fmt.Sprintf("AmdAxpyUnsafeX_V%vA%vR%v", variant, align, unroll), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)") 399 400 alpha := Load(Param("alpha"), XMM()) 401 402 xs := Mem{Base: Load(Param("xs"), GP64())} 403 incx := Load(Param("incx"), GP64()) 404 405 ys := Mem{Base: Load(Param("ys"), GP64())} 406 incy := Load(Param("incy"), GP64()) 407 408 n := Load(Param("n"), GP64()) 409 410 xi, yi := GP64(), GP64() 411 XORQ(xi, xi) 412 XORQ(yi, yi) 413 414 JMP(LabelRef("check_limit_unroll")) 415 416 MISALIGN(align) 417 Label("loop_unroll") 418 { 419 for u := 0; u < unroll; u++ { 420 tmp := XMM() 421 422 xat := Mem{Base: xs.Base, Index: xi, Scale: 4, Disp: 0} 423 yat := Mem{Base: ys.Base, Index: yi, Scale: 4, Disp: 0} 424 MOVSS(xat, tmp) 425 MULSS(alpha, tmp) 426 ADDSS(yat, tmp) 427 MOVSS(tmp, yat) 428 429 ADDQ(incx, xi) 430 ADDQ(incy, yi) 431 } 432 433 SUBQ(Imm(uint64(unroll)), n) 434 435 Label("check_limit_unroll") 436 437 CMPQ(n, U8(unroll)) 438 JHI(LabelRef("loop_unroll")) 439 } 440 441 JMP(LabelRef("check_limit")) 442 Label("loop") 443 { 444 tmp := XMM() 445 MOVSS(xs.Idx(xi, 4), tmp) 446 MULSS(alpha, tmp) 447 ADDSS(ys.Idx(yi, 4), tmp) 448 MOVSS(tmp, ys.Idx(yi, 4)) 449 450 DECQ(n) 451 ADDQ(incx, xi) 452 ADDQ(incy, yi) 453 454 Label("check_limit") 455 456 CMPQ(n, U8(0)) 457 JHI(LabelRef("loop")) 458 } 459 460 RET() 461 } 462 463 func AxpyUnsafeXInterleaveUnroll(variant, align, unroll int) { 464 TEXT(fmt.Sprintf("AmdAxpyUnsafeXInterleave_V%vA%vR%v", variant, align, unroll), NOSPLIT, "func(alpha float32, xs *float32, incx uintptr, ys *float32, incy uintptr, n uintptr)") 465 466 alpha := Load(Param("alpha"), XMM()) 467 468 xs := Mem{Base: Load(Param("xs"), GP64())} 469 incx := Load(Param("incx"), GP64()) 470 471 ys := Mem{Base: Load(Param("ys"), GP64())} 472 incy := Load(Param("incy"), GP64()) 473 474 n := Load(Param("n"), GP64()) 475 476 xi, yi := GP64(), GP64() 477 XORQ(xi, xi) 478 XORQ(yi, yi) 479 480 JMP(LabelRef("check_limit_unroll")) 481 482 MISALIGN(align) 483 Label("loop_unroll") 484 { 485 tmp := make([]reg.VecVirtual, unroll) 486 for u := range tmp { 487 tmp[u] = XMM() 488 } 489 490 for u := 0; u < unroll; u++ { 491 MOVSS(xs.Idx(xi, 4), tmp[u]) 492 ADDQ(incx, xi) 493 } 494 for u := 0; u < unroll; u++ { 495 MULSS(alpha, tmp[u]) 496 } 497 for u := 0; u < unroll; u++ { 498 ADDSS(ys.Idx(yi, 4), tmp[u]) 499 MOVSS(tmp[u], ys.Idx(yi, 4)) 500 ADDQ(incy, yi) 501 } 502 503 SUBQ(Imm(uint64(unroll)), n) 504 505 Label("check_limit_unroll") 506 507 CMPQ(n, U8(unroll)) 508 JHS(LabelRef("loop_unroll")) 509 } 510 511 JMP(LabelRef("check_limit")) 512 Label("loop") 513 { 514 tmp := XMM() 515 MOVSS(xs.Idx(xi, 4), tmp) 516 MULSS(alpha, tmp) 517 ADDSS(ys.Idx(yi, 4), tmp) 518 MOVSS(tmp, ys.Idx(yi, 4)) 519 520 DECQ(n) 521 ADDQ(incx, xi) 522 ADDQ(incy, yi) 523 524 Label("check_limit") 525 526 CMPQ(n, U8(0)) 527 JHI(LabelRef("loop")) 528 } 529 530 RET() 531 } 532 533 func MISALIGN(n int) { 534 if n == 0 { 535 return 536 } 537 538 nearestPowerOf2 := 8 539 for n >= nearestPowerOf2*2 { 540 nearestPowerOf2 *= 2 541 } 542 if nearestPowerOf2 >= 8 { 543 Instruction(&ir.Instruction{ 544 Opcode: "PCALIGN", 545 Operands: []Op{Imm(uint64(nearestPowerOf2))}, 546 }) 547 n -= nearestPowerOf2 548 } 549 550 for i := 0; i < n; i++ { 551 NOP() 552 } 553 }