github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/gnum/vecs.go (about) 1 package gnum 2 3 import ( 4 "fmt" 5 "math" 6 7 "golang.org/x/exp/constraints" 8 ) 9 10 // L1 returns the L1 (Manhattan) distance between a and b. 11 // Equivalent to Lp(1) but returns the same type. 12 func L1[S ~[]N, N Number](a, b S) N { 13 assertMatchingLengths(a, b) 14 var sum N 15 for i := range a { 16 sum += Diff(a[i], b[i]) 17 } 18 return sum 19 } 20 21 // L2 returns the L2 (Euclidean) distance between a and b. 22 // Equivalent to Lp(2). 23 func L2[S ~[]N, N Number](a, b S) float64 { 24 assertMatchingLengths(a, b) 25 var sum N 26 for i := range a { 27 d := (a[i] - b[i]) 28 sum += d * d 29 } 30 return math.Sqrt(float64(sum)) 31 } 32 33 // Lp returns an Lp distance function. Lp is calculated as follows: 34 // 35 // Lp(v) = (sum_i(v[i]^p))^(1/p) 36 func Lp[S ~[]N, N Number](p int) func(S, S) float64 { 37 if p < 1 { 38 panic(fmt.Sprintf("invalid p: %d", p)) 39 } 40 41 if p == 1 { 42 return func(a, b S) float64 { 43 return float64(L1(a, b)) 44 } 45 } 46 if p == 2 { 47 return L2[S, N] 48 } 49 50 return func(a, b S) float64 { 51 assertMatchingLengths(a, b) 52 fp := float64(p) 53 var sum float64 54 for i := range a { 55 sum += math.Pow(float64(Diff(a[i], b[i])), fp) 56 } 57 return math.Pow(sum, 1/fp) 58 } 59 } 60 61 // Add adds b to a and returns a. b is unchanged. If a is nil, creates a new 62 // vector. 63 func Add[S ~[]N, N Number](a S, b ...S) S { 64 if a == nil { 65 if len(b) == 0 { 66 return nil 67 } 68 a = make(S, len(b[0])) 69 } 70 for i := range b { 71 assertMatchingLengths(a, b[i]) 72 for j := range a { 73 a[j] += b[i][j] 74 } 75 } 76 return a 77 } 78 79 // Sub subtracts b from a and returns a. b is unchanged. If a is nil, creates a 80 // new vector. 81 func Sub[S ~[]N, N Number](a S, b ...S) S { 82 if a == nil { 83 if len(b) == 0 { 84 return nil 85 } 86 a = make(S, len(b[0])) 87 } 88 for i := range b { 89 assertMatchingLengths(a, b[i]) 90 for j := range a { 91 a[j] -= b[i][j] 92 } 93 } 94 return a 95 } 96 97 // Mul multiplies a by b and returns a. b is unchanged. If a is nil, creates a 98 // new vector. 99 func Mul[S ~[]N, N Number](a S, b ...S) S { 100 if a == nil { 101 if len(b) == 0 { 102 return nil 103 } 104 a = Ones[S](len(b[0])) 105 } 106 for i := range b { 107 assertMatchingLengths(a, b[i]) 108 for j := range a { 109 a[j] -= b[i][j] 110 } 111 } 112 return a 113 } 114 115 // Add1 adds m to a and returns a. 116 func Add1[S ~[]N, N Number](a S, m N) S { 117 for i := range a { 118 a[i] += m 119 } 120 return a 121 } 122 123 // Sub1 subtracts m from a and returns a. 124 func Sub1[S ~[]N, N Number](a S, m N) S { 125 for i := range a { 126 a[i] -= m 127 } 128 return a 129 } 130 131 // Mul1 multiplies the values of a by m and returns a. 132 func Mul1[S ~[]N, N Number](a S, m N) S { 133 for i := range a { 134 a[i] *= m 135 } 136 return a 137 } 138 139 // Dot returns the dot product of the input vectors. 140 func Dot[S ~[]N, N Number](a, b S) N { 141 assertMatchingLengths(a, b) 142 var sum N 143 for i := range a { 144 sum += a[i] * b[i] 145 } 146 return sum 147 } 148 149 // Norm returns the L2 norm of the vector. 150 func Norm[S ~[]N, N constraints.Float](a S) float64 { 151 var norm N 152 for _, v := range a { 153 norm += v * v 154 } 155 return math.Sqrt(float64(norm)) 156 } 157 158 // Ones returns a slice of n ones. Panics if n is negative. 159 func Ones[S ~[]N, N Number](n int) S { 160 if n < 0 { 161 panic(fmt.Sprintf("bad vector length: %d", n)) 162 } 163 a := make(S, n) 164 for i := range a { 165 a[i] = 1 166 } 167 return a 168 } 169 170 // Copy returns a copy of the given slice. 171 func Copy[S ~[]N, N any](a S) S { 172 result := make(S, len(a)) 173 copy(result, a) 174 return result 175 } 176 177 // Cast casts the values of a and places them in a new slice. 178 func Cast[S ~[]N, T ~[]M, N Number, M Number](a S) T { 179 t := make(T, len(a)) 180 for i, s := range a { 181 t[i] = M(s) 182 } 183 return t 184 } 185 186 // Panics if the input vectors are of different lengths. 187 func assertMatchingLengths[S ~[]N, N any](a, b S) { 188 if len(a) != len(b) { 189 panic(fmt.Sprintf("mismatching lengths: %d, %d", len(a), len(b))) 190 } 191 }