gitee.com/quant1x/num@v0.3.2/internal/functions/distance_test.go (about)

     1  package functions
     2  
     3  import (
     4  	"fmt"
     5  	"gitee.com/quant1x/num/internal/rand"
     6  	"gitee.com/quant1x/pkg/testify/require"
     7  	"testing"
     8  )
     9  
    10  func TestCosineSimilarity(t *testing.T) {
    11  	rand.Seed(2)
    12  	for i := 0; i < 1000; i++ {
    13  		size := 1 + (i / 5)
    14  		{
    15  			x := Random[float64](size)
    16  			y := Random[float64](size)
    17  			r1 := Dot_Go(x, y)
    18  			r2 := Dot_AVX2_F64(x, y)
    19  			require.InDelta(t, r1, r2, 0.001)
    20  		}
    21  		{
    22  			x := Random[float64](size)
    23  			y := Random[float64](size)
    24  			r1 := CosineSimilarity_AVX2_F64(x, y)
    25  			r2 := CosineSimilarity_Go_F64(x, y)
    26  			require.InDelta(t, r1, r2, 0.001)
    27  		}
    28  		{
    29  			x := Random[float32](size)
    30  			y := Random[float32](size)
    31  			r1 := CosineSimilarity_Go_F32(x, y)
    32  			r2 := CosineSimilarity_AVX2_F32(x, y)
    33  			require.InDelta(t, r1, r2, 0.001)
    34  		}
    35  	}
    36  }
    37  
    38  func BenchmarkCosineSimilarity(b *testing.B) {
    39  	for _, size := range sizes {
    40  		x := Random[float64](size)
    41  		y := Random[float64](size)
    42  		x32 := Random[float32](size)
    43  		y32 := Random[float32](size)
    44  
    45  		b.Run(fmt.Sprintf("dot_go_f64_%d", size), func(b *testing.B) {
    46  			for i := 0; i < b.N; i++ {
    47  				Dot_Go(x, y)
    48  			}
    49  		})
    50  		b.Run(fmt.Sprintf("dot_go_f32_%d", size), func(b *testing.B) {
    51  			for i := 0; i < b.N; i++ {
    52  				Dot_Go(x32, y32)
    53  			}
    54  		})
    55  		b.Run(fmt.Sprintf("dot_avx2_f64_%d", size), func(b *testing.B) {
    56  			for i := 0; i < b.N; i++ {
    57  				Dot_AVX2_F64(x, y)
    58  			}
    59  		})
    60  		b.Run(fmt.Sprintf("dot_avx2_f32_%d", size), func(b *testing.B) {
    61  			for i := 0; i < b.N; i++ {
    62  				Dot_AVX2_F32(x32, y32)
    63  			}
    64  		})
    65  		b.Run(fmt.Sprintf("cosim_go_f64_%d", size), func(b *testing.B) {
    66  			for i := 0; i < b.N; i++ {
    67  				CosineSimilarity_Go_F64(x, y)
    68  			}
    69  		})
    70  		b.Run(fmt.Sprintf("cosim_go_f32_%d", size), func(b *testing.B) {
    71  			for i := 0; i < b.N; i++ {
    72  				CosineSimilarity_Go_F32(x32, y32)
    73  			}
    74  		})
    75  		b.Run(fmt.Sprintf("cosim_avx2_f64_%d", size), func(b *testing.B) {
    76  			for i := 0; i < b.N; i++ {
    77  				CosineSimilarity_AVX2_F64(x, y)
    78  			}
    79  		})
    80  		b.Run(fmt.Sprintf("cosim_avx2_f32_%d", size), func(b *testing.B) {
    81  			for i := 0; i < b.N; i++ {
    82  				CosineSimilarity_AVX2_F32(x32, y32)
    83  			}
    84  		})
    85  	}
    86  }