//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package compressionhelpers

import (
	"encoding/binary"
	"errors"
	"fmt"
	"math"
	"math/rand"

	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
)

// FilterFunc is a function type that maps a float32 vector to another float32
// vector. It is not used by KMeans itself.
// NOTE(review): confirm the callers elsewhere in the package.
type FilterFunc func([]float32) []float32

// KMeans clusters one segment of the input vectors into K centroids. The
// segment is the run of values [segment*dimensions, (segment+1)*dimensions).
// A segment is encoded as the index of its nearest centroid.
type KMeans struct {
	K                  int                // How many centroids
	DeltaThreshold     float32            // Fitting stops once fewer than DeltaThreshold*len(data) points change cluster in an iteration
	IterationThreshold int                // Fitting stops after the iteration whose zero-based index reaches this value
	Distance           distancer.Provider // Distance used to compare a vector segment with a centroid
	centers            [][]float32        // The current centroids
	dimensions         int                // Number of values per segment (width of each centroid)
	segment            int                // Index of the segment of the input vectors this instance operates on

	data KMeansPartitionData // Non-persistent data used only during the fitting process
}

// String returns a short description of the encoder: K, the segment width,
// the segment index and up to the first five centroid values. It can be used
// for viability checks to see whether the encoder was initialized correctly,
// for example after a restart.
41 func (k *KMeans) String() string { 42 maxElem := 5 43 var firstCenters []float32 44 i := 0 45 for _, center := range k.centers { 46 for _, centerVal := range center { 47 if i == maxElem { 48 break 49 } 50 51 firstCenters = append(firstCenters, centerVal) 52 i++ 53 } 54 if i == maxElem { 55 break 56 } 57 } 58 return fmt.Sprintf("KMeans Encoder: K=%d, dim=%d, segment=%d first_center_truncated=%v", k.K, k.dimensions, k.segment, firstCenters) 59 } 60 61 type KMeansPartitionData struct { 62 changes int // How many vectors has jumped to a new cluster 63 points []uint64 // Cluster assigned to each point 64 cc [][]uint64 // Partition of the data into the clusters 65 } 66 67 func NewKMeans(k int, dimensions int, segment int) *KMeans { 68 kMeans := &KMeans{ 69 K: k, 70 DeltaThreshold: 0.01, 71 IterationThreshold: 10, 72 Distance: distancer.NewL2SquaredProvider(), 73 dimensions: dimensions, 74 segment: segment, 75 } 76 return kMeans 77 } 78 79 func NewKMeansWithCenters(k int, dimensions int, segment int, centers [][]float32) *KMeans { 80 kmeans := NewKMeans(k, dimensions, segment) 81 kmeans.centers = centers 82 return kmeans 83 } 84 85 func (m *KMeans) ExposeDataForRestore() []byte { 86 ds := len(m.centers[0]) 87 len := 4 * m.K * ds 88 buffer := make([]byte, len) 89 for i := 0; i < len/4; i++ { 90 binary.LittleEndian.PutUint32(buffer[i*4:(i+1)*4], math.Float32bits(m.centers[i/ds][i%ds])) 91 } 92 return buffer 93 } 94 95 func (m *KMeans) Add(x []float32) { 96 // nothing to do here 97 } 98 99 func (m *KMeans) Centers() [][]float32 { 100 return m.centers 101 } 102 103 func (m *KMeans) Encode(point []float32) byte { 104 return byte(m.Nearest(point)) 105 } 106 107 func (m *KMeans) Nearest(point []float32) uint64 { 108 return m.NNearest(point, 1)[0] 109 } 110 111 func (m *KMeans) nNearest(point []float32, n int) ([]uint64, []float32) { 112 mins := make([]uint64, n) 113 minD := make([]float32, n) 114 for i := range mins { 115 mins[i] = 0 116 minD[i] = math.MaxFloat32 117 } 118 
filteredPoint := point[m.segment*m.dimensions : (m.segment+1)*m.dimensions] 119 for i, c := range m.centers { 120 distance, _, _ := m.Distance.SingleDist(filteredPoint, c) 121 j := 0 122 for (j < n) && minD[j] < distance { 123 j++ 124 } 125 if j < n { 126 for l := n - 1; l >= j+1; l-- { 127 mins[l] = mins[l-1] 128 minD[l] = minD[l-1] 129 } 130 minD[j] = distance 131 mins[j] = uint64(i) 132 } 133 } 134 return mins, minD 135 } 136 137 func (m *KMeans) NNearest(point []float32, n int) []uint64 { 138 nearest, _ := m.nNearest(point, n) 139 return nearest 140 } 141 142 func (m *KMeans) initCenters(data [][]float32) { 143 if len(m.centers) == m.K { 144 return 145 } 146 m.centers = make([][]float32, 0, m.K) 147 for i := 0; i < m.K; i++ { 148 var vec []float32 149 for vec == nil { 150 vec = data[rand.Intn(len(data))] 151 } 152 vecCopy := make([]float32, m.dimensions) 153 copy(vecCopy, vec[m.segment*m.dimensions:(m.segment+1)*m.dimensions]) 154 m.centers = append(m.centers, vecCopy) 155 } 156 } 157 158 func (m *KMeans) recluster(data [][]float32) { 159 for p := 0; p < len(data); p++ { 160 point := data[p] 161 if point == nil { 162 continue 163 } 164 cis, _ := m.nNearest(point, 1) 165 ci := cis[0] 166 m.data.cc[ci] = append(m.data.cc[ci], uint64(p)) 167 if m.data.points[p] != ci { 168 m.data.points[p] = ci 169 m.data.changes++ 170 } 171 } 172 } 173 174 func (m *KMeans) resortOnEmptySets(data [][]float32) { 175 k64 := uint64(m.K) 176 dataSize := len(data) 177 for ci := uint64(0); ci < k64; ci++ { 178 if len(m.data.cc[ci]) == 0 { 179 var ri int 180 for { 181 ri = rand.Intn(dataSize) 182 if data[ri] == nil { 183 continue 184 } 185 if len(m.data.cc[m.data.points[ri]]) > 1 { 186 break 187 } 188 } 189 m.data.cc[ci] = append(m.data.cc[ci], uint64(ri)) 190 m.data.points[ri] = ci 191 m.data.changes = dataSize 192 } 193 } 194 } 195 196 func (m *KMeans) recalcCenters(data [][]float32) { 197 for index := 0; index < m.K; index++ { 198 for j := range m.centers[index] { 199 
m.centers[index][j] = 0 200 } 201 size := len(m.data.cc[index]) 202 for _, ci := range m.data.cc[index] { 203 vec := data[ci] 204 v := vec[m.segment*m.dimensions : (m.segment+1)*m.dimensions] 205 for j := 0; j < m.dimensions; j++ { 206 m.centers[index][j] += v[j] 207 } 208 } 209 for j := 0; j < m.dimensions; j++ { 210 m.centers[index][j] /= float32(size) 211 } 212 } 213 } 214 215 func (m *KMeans) stopCondition(iterations int, dataSize int) bool { 216 return iterations >= m.IterationThreshold || 217 m.data.changes < int(float32(dataSize)*m.DeltaThreshold) 218 } 219 220 func (m *KMeans) Fit(data [][]float32) error { // init centers using min/max per dimension 221 dataSize := len(data) 222 if dataSize < m.K { 223 return errors.New("not enough data to fit kmeans") 224 } 225 m.initCenters(data) 226 m.data.points = make([]uint64, dataSize) 227 m.data.changes = 1 228 229 for i := 0; m.data.changes > 0; i++ { 230 m.data.changes = 0 231 m.data.cc = make([][]uint64, m.K) 232 for j := range m.data.cc { 233 m.data.cc[j] = make([]uint64, 0) 234 } 235 236 m.recluster(data) 237 m.resortOnEmptySets(data) 238 if m.data.changes > 0 { 239 m.recalcCenters(data) 240 } 241 242 if m.stopCondition(i, dataSize) { 243 break 244 } 245 246 } 247 248 m.clearData() 249 return nil 250 } 251 252 func (m *KMeans) clearData() { 253 m.data.points = nil 254 m.data.cc = nil 255 } 256 257 func (m *KMeans) Center(point []float32) []float32 { 258 return m.centers[m.Nearest(point)] 259 } 260 261 func (m *KMeans) Centroid(i byte) []float32 { 262 return m.centers[i] 263 }