github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/compressionhelpers/kmeans.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package compressionhelpers
    13  
    14  import (
    15  	"encoding/binary"
    16  	"errors"
    17  	"fmt"
    18  	"math"
    19  	"math/rand"
    20  
    21  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    22  )
    23  
    // FilterFunc is a function type that takes a float32 vector and returns a
    // float32 vector.
    // NOTE(review): it is not referenced anywhere in this file; presumably callers
    // use it to select or transform part of a vector — confirm against usages.
    24  type FilterFunc func([]float32) []float32
    25  
    // KMeans is a k-means clusterer/encoder that works on one fixed-width segment
    // of the vectors it is given: the components in the range
    // [segment*dimensions, (segment+1)*dimensions).
    26  type KMeans struct {
    27  	K                  int     // Number of centroids
    28  	DeltaThreshold     float32 // Fitting stops once fewer than DeltaThreshold*len(data) points changed cluster in an iteration
    29  	IterationThreshold int     // Fitting stops once the zero-based iteration counter reaches this value
    30  	Distance           distancer.Provider // Its SingleDist is used to compare a point's segment with a centroid
    31  	centers            [][]float32 // The current centroids
    32  	dimensions         int         // Number of components in the segment this encoder works on
    33  	segment            int         // Index of the segment of the input vectors this encoder works on
    34  
    35  	data KMeansPartitionData // Scratch state used only while Fit runs; Fit drops points and cc before returning
    36  }
    37  
    38  // String prints some minimal information about the encoder. This can be
    39  // used for viability checks to see if the encoder was initialized
    40  // correctly – for example after a restart.
    41  func (k *KMeans) String() string {
    42  	maxElem := 5
    43  	var firstCenters []float32
    44  	i := 0
    45  	for _, center := range k.centers {
    46  		for _, centerVal := range center {
    47  			if i == maxElem {
    48  				break
    49  			}
    50  
    51  			firstCenters = append(firstCenters, centerVal)
    52  			i++
    53  		}
    54  		if i == maxElem {
    55  			break
    56  		}
    57  	}
    58  	return fmt.Sprintf("KMeans Encoder: K=%d, dim=%d, segment=%d first_center_truncated=%v", k.K, k.dimensions, k.segment, firstCenters)
    59  }
    60  
    // KMeansPartitionData is the per-fit bookkeeping of KMeans: the current
    // assignment of points to clusters and how much it changed in the last pass.
    61  type KMeansPartitionData struct {
    62  	changes int        // Number of points that moved to a different cluster in the current iteration (set to len(data) when an empty cluster had to be refilled)
    63  	points  []uint64   // points[p] is the cluster index currently assigned to data[p]
    64  	cc      [][]uint64 // cc[c] lists the indices into data of the points assigned to cluster c
    65  }
    66  
    67  func NewKMeans(k int, dimensions int, segment int) *KMeans {
    68  	kMeans := &KMeans{
    69  		K:                  k,
    70  		DeltaThreshold:     0.01,
    71  		IterationThreshold: 10,
    72  		Distance:           distancer.NewL2SquaredProvider(),
    73  		dimensions:         dimensions,
    74  		segment:            segment,
    75  	}
    76  	return kMeans
    77  }
    78  
    79  func NewKMeansWithCenters(k int, dimensions int, segment int, centers [][]float32) *KMeans {
    80  	kmeans := NewKMeans(k, dimensions, segment)
    81  	kmeans.centers = centers
    82  	return kmeans
    83  }
    84  
    85  func (m *KMeans) ExposeDataForRestore() []byte {
    86  	ds := len(m.centers[0])
    87  	len := 4 * m.K * ds
    88  	buffer := make([]byte, len)
    89  	for i := 0; i < len/4; i++ {
    90  		binary.LittleEndian.PutUint32(buffer[i*4:(i+1)*4], math.Float32bits(m.centers[i/ds][i%ds]))
    91  	}
    92  	return buffer
    93  }
    94  
    // Add is a no-op: the argument is ignored and no state is changed.
    // NOTE(review): presumably kept to satisfy an encoder interface declared
    // elsewhere — confirm against the callers.
    95  func (m *KMeans) Add(x []float32) {
    96  	// nothing to do here
    97  }
    98  
    // Centers returns the encoder's centroid slice. The internal slice itself is
    // returned, not a copy, so changes made by the caller are visible to the
    // encoder.
    99  func (m *KMeans) Centers() [][]float32 {
   100  	return m.centers
   101  }
   102  
   103  func (m *KMeans) Encode(point []float32) byte {
   104  	return byte(m.Nearest(point))
   105  }
   106  
   107  func (m *KMeans) Nearest(point []float32) uint64 {
   108  	return m.NNearest(point, 1)[0]
   109  }
   110  
   111  func (m *KMeans) nNearest(point []float32, n int) ([]uint64, []float32) {
   112  	mins := make([]uint64, n)
   113  	minD := make([]float32, n)
   114  	for i := range mins {
   115  		mins[i] = 0
   116  		minD[i] = math.MaxFloat32
   117  	}
   118  	filteredPoint := point[m.segment*m.dimensions : (m.segment+1)*m.dimensions]
   119  	for i, c := range m.centers {
   120  		distance, _, _ := m.Distance.SingleDist(filteredPoint, c)
   121  		j := 0
   122  		for (j < n) && minD[j] < distance {
   123  			j++
   124  		}
   125  		if j < n {
   126  			for l := n - 1; l >= j+1; l-- {
   127  				mins[l] = mins[l-1]
   128  				minD[l] = minD[l-1]
   129  			}
   130  			minD[j] = distance
   131  			mins[j] = uint64(i)
   132  		}
   133  	}
   134  	return mins, minD
   135  }
   136  
   137  func (m *KMeans) NNearest(point []float32, n int) []uint64 {
   138  	nearest, _ := m.nNearest(point, n)
   139  	return nearest
   140  }
   141  
   142  func (m *KMeans) initCenters(data [][]float32) {
   143  	if len(m.centers) == m.K {
   144  		return
   145  	}
   146  	m.centers = make([][]float32, 0, m.K)
   147  	for i := 0; i < m.K; i++ {
   148  		var vec []float32
   149  		for vec == nil {
   150  			vec = data[rand.Intn(len(data))]
   151  		}
   152  		vecCopy := make([]float32, m.dimensions)
   153  		copy(vecCopy, vec[m.segment*m.dimensions:(m.segment+1)*m.dimensions])
   154  		m.centers = append(m.centers, vecCopy)
   155  	}
   156  }
   157  
   158  func (m *KMeans) recluster(data [][]float32) {
   159  	for p := 0; p < len(data); p++ {
   160  		point := data[p]
   161  		if point == nil {
   162  			continue
   163  		}
   164  		cis, _ := m.nNearest(point, 1)
   165  		ci := cis[0]
   166  		m.data.cc[ci] = append(m.data.cc[ci], uint64(p))
   167  		if m.data.points[p] != ci {
   168  			m.data.points[p] = ci
   169  			m.data.changes++
   170  		}
   171  	}
   172  }
   173  
   174  func (m *KMeans) resortOnEmptySets(data [][]float32) {
   175  	k64 := uint64(m.K)
   176  	dataSize := len(data)
   177  	for ci := uint64(0); ci < k64; ci++ {
   178  		if len(m.data.cc[ci]) == 0 {
   179  			var ri int
   180  			for {
   181  				ri = rand.Intn(dataSize)
   182  				if data[ri] == nil {
   183  					continue
   184  				}
   185  				if len(m.data.cc[m.data.points[ri]]) > 1 {
   186  					break
   187  				}
   188  			}
   189  			m.data.cc[ci] = append(m.data.cc[ci], uint64(ri))
   190  			m.data.points[ri] = ci
   191  			m.data.changes = dataSize
   192  		}
   193  	}
   194  }
   195  
   196  func (m *KMeans) recalcCenters(data [][]float32) {
   197  	for index := 0; index < m.K; index++ {
   198  		for j := range m.centers[index] {
   199  			m.centers[index][j] = 0
   200  		}
   201  		size := len(m.data.cc[index])
   202  		for _, ci := range m.data.cc[index] {
   203  			vec := data[ci]
   204  			v := vec[m.segment*m.dimensions : (m.segment+1)*m.dimensions]
   205  			for j := 0; j < m.dimensions; j++ {
   206  				m.centers[index][j] += v[j]
   207  			}
   208  		}
   209  		for j := 0; j < m.dimensions; j++ {
   210  			m.centers[index][j] /= float32(size)
   211  		}
   212  	}
   213  }
   214  
   215  func (m *KMeans) stopCondition(iterations int, dataSize int) bool {
   216  	return iterations >= m.IterationThreshold ||
   217  		m.data.changes < int(float32(dataSize)*m.DeltaThreshold)
   218  }
   219  
   220  func (m *KMeans) Fit(data [][]float32) error { // init centers using min/max per dimension
   221  	dataSize := len(data)
   222  	if dataSize < m.K {
   223  		return errors.New("not enough data to fit kmeans")
   224  	}
   225  	m.initCenters(data)
   226  	m.data.points = make([]uint64, dataSize)
   227  	m.data.changes = 1
   228  
   229  	for i := 0; m.data.changes > 0; i++ {
   230  		m.data.changes = 0
   231  		m.data.cc = make([][]uint64, m.K)
   232  		for j := range m.data.cc {
   233  			m.data.cc[j] = make([]uint64, 0)
   234  		}
   235  
   236  		m.recluster(data)
   237  		m.resortOnEmptySets(data)
   238  		if m.data.changes > 0 {
   239  			m.recalcCenters(data)
   240  		}
   241  
   242  		if m.stopCondition(i, dataSize) {
   243  			break
   244  		}
   245  
   246  	}
   247  
   248  	m.clearData()
   249  	return nil
   250  }
   251  
   252  func (m *KMeans) clearData() {
   253  	m.data.points = nil
   254  	m.data.cc = nil
   255  }
   256  
   257  func (m *KMeans) Center(point []float32) []float32 {
   258  	return m.centers[m.Nearest(point)]
   259  }
   260  
   // Centroid returns the centroid stored at index i. The encoder's own slice is
   // returned, not a copy. Indexing panics if i is not a valid index into the
   // centroid list.
   261  func (m *KMeans) Centroid(i byte) []float32 {
   262  	return m.centers[i]
   263  }