github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/vector/compare/axpy_arm64_test.go (about)

     1  package compare
     2  
     3  import (
     4  	"math/rand"
     5  	"slices"
     6  	"strconv"
     7  	"testing"
     8  	"unsafe"
     9  )
    10  
    11  func BenchmarkArm64(b *testing.B) {
    12  	for _, decl := range armAxpyDecls {
    13  		b.Run(decl.name, func(b *testing.B) {
    14  			fn := decl.fn
    15  			for i := 0; i < b.N; i++ {
    16  				fn(alpha, &xs[i&3], 2, &ys[0], 4, K)
    17  			}
    18  		})
    19  	}
    20  }
    21  
    22  func FuzzArm64(f *testing.F) {
    23  	f.Add(int64(0), uint8(0))
    24  	f.Fuzz(func(t *testing.T, seed1 int64, bn byte) {
    25  		n := int(bn) % 8
    26  		xs, ys := make([]float32, n), make([]float32, n)
    27  
    28  		const Scale = 10000
    29  		rng := rand.New(rand.NewSource(seed1))
    30  		for i := range xs {
    31  			xs[i] = rng.Float32()*Scale - Scale/2
    32  			ys[i] = rng.Float32()*Scale - Scale/2
    33  		}
    34  
    35  		expectys := slices.Clone(ys)
    36  		alpha := float32(2.3)
    37  		AxpyBasic(alpha, xs, 1, expectys, 1, uintptr(len(xs)))
    38  
    39  		for _, axpy := range armAxpyDecls {
    40  			lxs := slices.Clone(xs)
    41  			lys := slices.Clone(ys)
    42  			axpy.fn(alpha, unsafe.SliceData(lxs), 1, unsafe.SliceData(lys), 1, uintptr(len(xs)))
    43  			if !equalFloats(lys, expectys) {
    44  				t.Errorf("%q wrong result\n\tgot=%v\n\texp=%v\n\tal=%v\n\txs=%v\n\tys=%v", axpy.name, lys, expectys, alpha, xs, ys)
    45  			}
    46  			if !equalFloats(lxs, xs) {
    47  				t.Errorf("%q xs modified\n\tgot=%v\n\texp=%v\n\tal=%v\n\txs=%v\n\tys=%v", axpy.name, lxs, xs, alpha, xs, ys)
    48  			}
    49  		}
    50  	})
    51  }
    52  
    53  func TestArm64(t *testing.T) {
    54  	for _, axpy := range armAxpyDecls {
    55  		t.Run(axpy.name, func(t *testing.T) {
    56  			for i, test := range axpyTestCases {
    57  				t.Run(strconv.Itoa(i), func(t *testing.T) {
    58  					lxs := slices.Clone(test.xs)
    59  					lys := slices.Clone(test.ys)
    60  
    61  					axpy.fn(test.alpha, unsafe.SliceData(lxs), test.incx, unsafe.SliceData(lys), test.incy, test.N())
    62  
    63  					if !equalFloats(lys, test.expect) {
    64  						t.Errorf("wrong result\n\tgot=%v\n\texp=%v\n\tal=%v\n\txs=%v\n\tys=%v", lys, test.expect, test.alpha, test.xs, test.ys)
    65  					}
    66  					if !equalFloats(lxs, test.xs) {
    67  						t.Errorf("xs modified")
    68  					}
    69  				})
    70  			}
    71  		})
    72  	}
    73  }