github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/gnum/vecs.go (about)

     1  package gnum
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  
     7  	"golang.org/x/exp/constraints"
     8  )
     9  
    10  // L1 returns the L1 (Manhattan) distance between a and b.
    11  // Equivalent to Lp(1) but returns the same type.
    12  func L1[S ~[]N, N Number](a, b S) N {
    13  	assertMatchingLengths(a, b)
    14  	var sum N
    15  	for i := range a {
    16  		sum += Diff(a[i], b[i])
    17  	}
    18  	return sum
    19  }
    20  
    21  // L2 returns the L2 (Euclidean) distance between a and b.
    22  // Equivalent to Lp(2).
    23  func L2[S ~[]N, N Number](a, b S) float64 {
    24  	assertMatchingLengths(a, b)
    25  	var sum N
    26  	for i := range a {
    27  		d := (a[i] - b[i])
    28  		sum += d * d
    29  	}
    30  	return math.Sqrt(float64(sum))
    31  }
    32  
    33  // Lp returns an Lp distance function. Lp is calculated as follows:
    34  //
    35  //	Lp(v) = (sum_i(v[i]^p))^(1/p)
    36  func Lp[S ~[]N, N Number](p int) func(S, S) float64 {
    37  	if p < 1 {
    38  		panic(fmt.Sprintf("invalid p: %d", p))
    39  	}
    40  
    41  	if p == 1 {
    42  		return func(a, b S) float64 {
    43  			return float64(L1(a, b))
    44  		}
    45  	}
    46  	if p == 2 {
    47  		return L2[S, N]
    48  	}
    49  
    50  	return func(a, b S) float64 {
    51  		assertMatchingLengths(a, b)
    52  		fp := float64(p)
    53  		var sum float64
    54  		for i := range a {
    55  			sum += math.Pow(float64(Diff(a[i], b[i])), fp)
    56  		}
    57  		return math.Pow(sum, 1/fp)
    58  	}
    59  }
    60  
    61  // Add adds b to a and returns a. b is unchanged. If a is nil, creates a new
    62  // vector.
    63  func Add[S ~[]N, N Number](a S, b ...S) S {
    64  	if a == nil {
    65  		if len(b) == 0 {
    66  			return nil
    67  		}
    68  		a = make(S, len(b[0]))
    69  	}
    70  	for i := range b {
    71  		assertMatchingLengths(a, b[i])
    72  		for j := range a {
    73  			a[j] += b[i][j]
    74  		}
    75  	}
    76  	return a
    77  }
    78  
    79  // Sub subtracts b from a and returns a. b is unchanged. If a is nil, creates a
    80  // new vector.
    81  func Sub[S ~[]N, N Number](a S, b ...S) S {
    82  	if a == nil {
    83  		if len(b) == 0 {
    84  			return nil
    85  		}
    86  		a = make(S, len(b[0]))
    87  	}
    88  	for i := range b {
    89  		assertMatchingLengths(a, b[i])
    90  		for j := range a {
    91  			a[j] -= b[i][j]
    92  		}
    93  	}
    94  	return a
    95  }
    96  
    97  // Mul multiplies a by b and returns a. b is unchanged. If a is nil, creates a
    98  // new vector.
    99  func Mul[S ~[]N, N Number](a S, b ...S) S {
   100  	if a == nil {
   101  		if len(b) == 0 {
   102  			return nil
   103  		}
   104  		a = Ones[S](len(b[0]))
   105  	}
   106  	for i := range b {
   107  		assertMatchingLengths(a, b[i])
   108  		for j := range a {
   109  			a[j] -= b[i][j]
   110  		}
   111  	}
   112  	return a
   113  }
   114  
   115  // Add1 adds m to a and returns a.
   116  func Add1[S ~[]N, N Number](a S, m N) S {
   117  	for i := range a {
   118  		a[i] += m
   119  	}
   120  	return a
   121  }
   122  
   123  // Sub1 subtracts m from a and returns a.
   124  func Sub1[S ~[]N, N Number](a S, m N) S {
   125  	for i := range a {
   126  		a[i] -= m
   127  	}
   128  	return a
   129  }
   130  
   131  // Mul1 multiplies the values of a by m and returns a.
   132  func Mul1[S ~[]N, N Number](a S, m N) S {
   133  	for i := range a {
   134  		a[i] *= m
   135  	}
   136  	return a
   137  }
   138  
   139  // Dot returns the dot product of the input vectors.
   140  func Dot[S ~[]N, N Number](a, b S) N {
   141  	assertMatchingLengths(a, b)
   142  	var sum N
   143  	for i := range a {
   144  		sum += a[i] * b[i]
   145  	}
   146  	return sum
   147  }
   148  
   149  // Norm returns the L2 norm of the vector.
   150  func Norm[S ~[]N, N constraints.Float](a S) float64 {
   151  	var norm N
   152  	for _, v := range a {
   153  		norm += v * v
   154  	}
   155  	return math.Sqrt(float64(norm))
   156  }
   157  
   158  // Ones returns a slice of n ones. Panics if n is negative.
   159  func Ones[S ~[]N, N Number](n int) S {
   160  	if n < 0 {
   161  		panic(fmt.Sprintf("bad vector length: %d", n))
   162  	}
   163  	a := make(S, n)
   164  	for i := range a {
   165  		a[i] = 1
   166  	}
   167  	return a
   168  }
   169  
   170  // Copy returns a copy of the given slice.
   171  func Copy[S ~[]N, N any](a S) S {
   172  	result := make(S, len(a))
   173  	copy(result, a)
   174  	return result
   175  }
   176  
   177  // Cast casts the values of a and places them in a new slice.
   178  func Cast[S ~[]N, T ~[]M, N Number, M Number](a S) T {
   179  	t := make(T, len(a))
   180  	for i, s := range a {
   181  		t[i] = M(s)
   182  	}
   183  	return t
   184  }
   185  
   186  // Panics if the input vectors are of different lengths.
   187  func assertMatchingLengths[S ~[]N, N any](a, b S) {
   188  	if len(a) != len(b) {
   189  		panic(fmt.Sprintf("mismatching lengths: %d, %d", len(a), len(b)))
   190  	}
   191  }