github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/deserializer.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package hnsw
    13  
    14  import (
    15  	"bufio"
    16  	"encoding/binary"
    17  	"io"
    18  	"math"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/weaviate/weaviate/adapters/repos/db/vector/cache"
    23  	"github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers"
    24  )
    25  
// Deserializer reads HNSW commit-log entries from a stream and applies them
// to an in-memory DeserializationResult. The reusable buffer and connections
// slice are scratch space shared across reads to avoid per-entry allocations,
// which also means a single Deserializer must not be used concurrently.
type Deserializer struct {
	logger                   logrus.FieldLogger
	reusableBuffer           []byte
	reusableConnectionsSlice []uint64
}
    31  
// DeserializationResult holds the in-memory HNSW state accumulated while
// replaying one or more commit logs.
type DeserializationResult struct {
	// Nodes is the graph indexed by node id; a nil entry means the node was
	// deleted or never existed.
	Nodes []*vertex
	// Entrypoint is the id of the current entrypoint node.
	Entrypoint uint64
	// Level is the maximum level associated with the entrypoint.
	Level uint16
	// Tombstones is the set of node ids marked for deletion.
	Tombstones map[uint64]struct{}
	// EntrypointChanged is set once a SetEntryPointMaxLevel entry was applied.
	EntrypointChanged bool
	// PQData carries the product-quantization configuration; only meaningful
	// when Compressed is true.
	PQData compressionhelpers.PQData
	// Compressed is set once an AddPQ entry has been applied.
	Compressed bool

	// If there is no entry for the links at a level to be replaced, we must
	// assume that all links were appended and prior state must exist
	// Similarly if we run into a "Clear" we need to explicitly set the replace
	// flag, so that future appends aren't always appended and we run into a
	// situation where reading multiple condensed logs in succession leads to too
	// many connections as discovered in
	// https://github.com/weaviate/weaviate/issues/1868
	LinksReplaced map[uint64]map[uint16]struct{}
}
    50  
    51  func (dr DeserializationResult) ReplaceLinks(node uint64, level uint16) bool {
    52  	levels, ok := dr.LinksReplaced[node]
    53  	if !ok {
    54  		return false
    55  	}
    56  
    57  	_, ok = levels[level]
    58  	return ok
    59  }
    60  
    61  func NewDeserializer(logger logrus.FieldLogger) *Deserializer {
    62  	return &Deserializer{logger: logger}
    63  }
    64  
    65  func (d *Deserializer) resetResusableBuffer(size int) {
    66  	if size <= cap(d.reusableBuffer) {
    67  		d.reusableBuffer = d.reusableBuffer[:size]
    68  	} else {
    69  		d.reusableBuffer = make([]byte, size, size*2)
    70  	}
    71  }
    72  
    73  func (d *Deserializer) resetReusableConnectionsSlice(size int) {
    74  	if size <= cap(d.reusableConnectionsSlice) {
    75  		d.reusableConnectionsSlice = d.reusableConnectionsSlice[:size]
    76  	} else {
    77  		d.reusableConnectionsSlice = make([]uint64, size, size*2)
    78  	}
    79  }
    80  
    81  func (d *Deserializer) Do(fd *bufio.Reader,
    82  	initialState *DeserializationResult, keepLinkReplaceInformation bool,
    83  ) (*DeserializationResult, int, error) {
    84  	validLength := 0
    85  	out := initialState
    86  	if out == nil {
    87  		out = &DeserializationResult{
    88  			Nodes:         make([]*vertex, cache.InitialSize),
    89  			Tombstones:    make(map[uint64]struct{}),
    90  			LinksReplaced: make(map[uint64]map[uint16]struct{}),
    91  		}
    92  	}
    93  
    94  	for {
    95  		ct, err := d.ReadCommitType(fd)
    96  		if err != nil {
    97  			if errors.Is(err, io.EOF) {
    98  				break
    99  			}
   100  
   101  			return nil, validLength, err
   102  		}
   103  
   104  		var readThisRound int
   105  
   106  		switch ct {
   107  		case AddNode:
   108  			err = d.ReadNode(fd, out)
   109  			readThisRound = 10
   110  		case SetEntryPointMaxLevel:
   111  			var entrypoint uint64
   112  			var level uint16
   113  			entrypoint, level, err = d.ReadEP(fd)
   114  			out.Entrypoint = entrypoint
   115  			out.Level = level
   116  			out.EntrypointChanged = true
   117  			readThisRound = 10
   118  		case AddLinkAtLevel:
   119  			err = d.ReadLink(fd, out)
   120  			readThisRound = 18
   121  		case AddLinksAtLevel:
   122  			readThisRound, err = d.ReadAddLinks(fd, out)
   123  		case ReplaceLinksAtLevel:
   124  			readThisRound, err = d.ReadLinks(fd, out, keepLinkReplaceInformation)
   125  		case AddTombstone:
   126  			err = d.ReadAddTombstone(fd, out.Tombstones)
   127  			readThisRound = 8
   128  		case RemoveTombstone:
   129  			err = d.ReadRemoveTombstone(fd, out.Tombstones)
   130  			readThisRound = 8
   131  		case ClearLinks:
   132  			err = d.ReadClearLinks(fd, out, keepLinkReplaceInformation)
   133  			readThisRound = 8
   134  		case ClearLinksAtLevel:
   135  			err = d.ReadClearLinksAtLevel(fd, out, keepLinkReplaceInformation)
   136  			readThisRound = 10
   137  		case DeleteNode:
   138  			err = d.ReadDeleteNode(fd, out)
   139  			readThisRound = 8
   140  		case ResetIndex:
   141  			out.Entrypoint = 0
   142  			out.Level = 0
   143  			out.Nodes = make([]*vertex, cache.InitialSize)
   144  		case AddPQ:
   145  			err = d.ReadPQ(fd, out)
   146  			readThisRound = 9
   147  		default:
   148  			err = errors.Errorf("unrecognized commit type %d", ct)
   149  		}
   150  		if err != nil {
   151  			// do not return nil, err, because the err could be a recoverable one
   152  			return out, validLength, err
   153  		} else {
   154  			validLength += 1 + readThisRound // 1 byte for commit type
   155  		}
   156  	}
   157  
   158  	return out, validLength, nil
   159  }
   160  
   161  func (d *Deserializer) ReadNode(r io.Reader, res *DeserializationResult) error {
   162  	id, err := d.readUint64(r)
   163  	if err != nil {
   164  		return err
   165  	}
   166  
   167  	level, err := d.readUint16(r)
   168  	if err != nil {
   169  		return err
   170  	}
   171  
   172  	newNodes, changed, err := growIndexToAccomodateNode(res.Nodes, id, d.logger)
   173  	if err != nil {
   174  		return err
   175  	}
   176  
   177  	if changed {
   178  		res.Nodes = newNodes
   179  	}
   180  
   181  	if res.Nodes[id] == nil {
   182  		res.Nodes[id] = &vertex{level: int(level), id: id, connections: make([][]uint64, level+1)}
   183  	} else {
   184  		maybeGrowConnectionsForLevel(&res.Nodes[id].connections, level)
   185  		res.Nodes[id].level = int(level)
   186  	}
   187  	return nil
   188  }
   189  
   190  func (d *Deserializer) ReadEP(r io.Reader) (uint64, uint16, error) {
   191  	id, err := d.readUint64(r)
   192  	if err != nil {
   193  		return 0, 0, err
   194  	}
   195  
   196  	level, err := d.readUint16(r)
   197  	if err != nil {
   198  		return 0, 0, err
   199  	}
   200  
   201  	return id, level, nil
   202  }
   203  
   204  func (d *Deserializer) ReadLink(r io.Reader, res *DeserializationResult) error {
   205  	source, err := d.readUint64(r)
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	level, err := d.readUint16(r)
   211  	if err != nil {
   212  		return err
   213  	}
   214  
   215  	target, err := d.readUint64(r)
   216  	if err != nil {
   217  		return err
   218  	}
   219  
   220  	newNodes, changed, err := growIndexToAccomodateNode(res.Nodes, source, d.logger)
   221  	if err != nil {
   222  		return err
   223  	}
   224  
   225  	if changed {
   226  		res.Nodes = newNodes
   227  	}
   228  
   229  	if res.Nodes[int(source)] == nil {
   230  		res.Nodes[int(source)] = &vertex{id: source, connections: make([][]uint64, level+1)}
   231  	}
   232  
   233  	maybeGrowConnectionsForLevel(&res.Nodes[int(source)].connections, level)
   234  
   235  	res.Nodes[int(source)].connections[int(level)] = append(res.Nodes[int(source)].connections[int(level)], target)
   236  	return nil
   237  }
   238  
   239  func (d *Deserializer) ReadLinks(r io.Reader, res *DeserializationResult,
   240  	keepReplaceInfo bool,
   241  ) (int, error) {
   242  	d.resetResusableBuffer(12)
   243  	_, err := io.ReadFull(r, d.reusableBuffer)
   244  	if err != nil {
   245  		return 0, err
   246  	}
   247  
   248  	source := binary.LittleEndian.Uint64(d.reusableBuffer[0:8])
   249  	level := binary.LittleEndian.Uint16(d.reusableBuffer[8:10])
   250  	length := binary.LittleEndian.Uint16(d.reusableBuffer[10:12])
   251  
   252  	targets, err := d.readUint64Slice(r, int(length))
   253  	if err != nil {
   254  		return 0, err
   255  	}
   256  
   257  	newNodes, changed, err := growIndexToAccomodateNode(res.Nodes, source, d.logger)
   258  	if err != nil {
   259  		return 0, err
   260  	}
   261  
   262  	if changed {
   263  		res.Nodes = newNodes
   264  	}
   265  
   266  	if res.Nodes[int(source)] == nil {
   267  		res.Nodes[int(source)] = &vertex{id: source, connections: make([][]uint64, level+1)}
   268  	}
   269  
   270  	maybeGrowConnectionsForLevel(&res.Nodes[int(source)].connections, level)
   271  	res.Nodes[int(source)].connections[int(level)] = make([]uint64, len(targets))
   272  	copy(res.Nodes[int(source)].connections[int(level)], targets)
   273  
   274  	if keepReplaceInfo {
   275  		// mark the replace flag for this node and level, so that new commit logs
   276  		// generated on this result (condensing) do not lose information
   277  
   278  		if _, ok := res.LinksReplaced[source]; !ok {
   279  			res.LinksReplaced[source] = map[uint16]struct{}{}
   280  		}
   281  
   282  		res.LinksReplaced[source][level] = struct{}{}
   283  	}
   284  
   285  	return 12 + int(length)*8, nil
   286  }
   287  
   288  func (d *Deserializer) ReadAddLinks(r io.Reader,
   289  	res *DeserializationResult,
   290  ) (int, error) {
   291  	d.resetResusableBuffer(12)
   292  	_, err := io.ReadFull(r, d.reusableBuffer)
   293  	if err != nil {
   294  		return 0, err
   295  	}
   296  
   297  	source := binary.LittleEndian.Uint64(d.reusableBuffer[0:8])
   298  	level := binary.LittleEndian.Uint16(d.reusableBuffer[8:10])
   299  	length := binary.LittleEndian.Uint16(d.reusableBuffer[10:12])
   300  
   301  	targets, err := d.readUint64Slice(r, int(length))
   302  	if err != nil {
   303  		return 0, err
   304  	}
   305  
   306  	newNodes, changed, err := growIndexToAccomodateNode(res.Nodes, source, d.logger)
   307  	if err != nil {
   308  		return 0, err
   309  	}
   310  
   311  	if changed {
   312  		res.Nodes = newNodes
   313  	}
   314  
   315  	if res.Nodes[int(source)] == nil {
   316  		res.Nodes[int(source)] = &vertex{id: source, connections: make([][]uint64, level+1)}
   317  	}
   318  
   319  	maybeGrowConnectionsForLevel(&res.Nodes[int(source)].connections, level)
   320  
   321  	res.Nodes[int(source)].connections[int(level)] = append(
   322  		res.Nodes[int(source)].connections[int(level)], targets...)
   323  
   324  	return 12 + int(length)*8, nil
   325  }
   326  
   327  func (d *Deserializer) ReadAddTombstone(r io.Reader, tombstones map[uint64]struct{}) error {
   328  	id, err := d.readUint64(r)
   329  	if err != nil {
   330  		return err
   331  	}
   332  
   333  	tombstones[id] = struct{}{}
   334  
   335  	return nil
   336  }
   337  
   338  func (d *Deserializer) ReadRemoveTombstone(r io.Reader, tombstones map[uint64]struct{}) error {
   339  	id, err := d.readUint64(r)
   340  	if err != nil {
   341  		return err
   342  	}
   343  
   344  	delete(tombstones, id)
   345  
   346  	return nil
   347  }
   348  
   349  func (d *Deserializer) ReadClearLinks(r io.Reader, res *DeserializationResult,
   350  	keepReplaceInfo bool,
   351  ) error {
   352  	id, err := d.readUint64(r)
   353  	if err != nil {
   354  		return err
   355  	}
   356  
   357  	newNodes, changed, err := growIndexToAccomodateNode(res.Nodes, id, d.logger)
   358  	if err != nil {
   359  		return err
   360  	}
   361  
   362  	if changed {
   363  		res.Nodes = newNodes
   364  	}
   365  
   366  	if res.Nodes[id] == nil {
   367  		// node has been deleted or never existed, nothing to do
   368  		return nil
   369  	}
   370  
   371  	res.Nodes[id].connections = make([][]uint64, len(res.Nodes[id].connections))
   372  	return nil
   373  }
   374  
   375  func (d *Deserializer) ReadClearLinksAtLevel(r io.Reader, res *DeserializationResult,
   376  	keepReplaceInfo bool,
   377  ) error {
   378  	id, err := d.readUint64(r)
   379  	if err != nil {
   380  		return err
   381  	}
   382  
   383  	level, err := d.readUint16(r)
   384  	if err != nil {
   385  		return err
   386  	}
   387  
   388  	newNodes, changed, err := growIndexToAccomodateNode(res.Nodes, id, d.logger)
   389  	if err != nil {
   390  		return err
   391  	}
   392  
   393  	if changed {
   394  		res.Nodes = newNodes
   395  	}
   396  
   397  	if keepReplaceInfo {
   398  		// mark the replace flag for this node and level, so that new commit logs
   399  		// generated on this result (condensing) do not lose information
   400  
   401  		if _, ok := res.LinksReplaced[id]; !ok {
   402  			res.LinksReplaced[id] = map[uint16]struct{}{}
   403  		}
   404  
   405  		res.LinksReplaced[id][level] = struct{}{}
   406  	}
   407  
   408  	if res.Nodes[id] == nil {
   409  		if !keepReplaceInfo {
   410  			// node has been deleted or never existed and we are not looking at a
   411  			// single log in isolation, nothing to do
   412  			return nil
   413  		}
   414  
   415  		// we need to keep the replace info, meaning we have to explicitly create
   416  		// this node in order to be able to store the "clear links" information for
   417  		// it
   418  		res.Nodes[id] = &vertex{
   419  			id:          id,
   420  			connections: make([][]uint64, level+1),
   421  		}
   422  	}
   423  
   424  	if res.Nodes[id].connections == nil {
   425  		res.Nodes[id].connections = make([][]uint64, level+1)
   426  	} else {
   427  		maybeGrowConnectionsForLevel(&res.Nodes[id].connections, level)
   428  		res.Nodes[id].connections[int(level)] = []uint64{}
   429  	}
   430  
   431  	if keepReplaceInfo {
   432  		// mark the replace flag for this node and level, so that new commit logs
   433  		// generated on this result (condensing) do not lose information
   434  
   435  		if _, ok := res.LinksReplaced[id]; !ok {
   436  			res.LinksReplaced[id] = map[uint16]struct{}{}
   437  		}
   438  
   439  		res.LinksReplaced[id][level] = struct{}{}
   440  	}
   441  
   442  	return nil
   443  }
   444  
   445  func (d *Deserializer) ReadDeleteNode(r io.Reader, res *DeserializationResult) error {
   446  	id, err := d.readUint64(r)
   447  	if err != nil {
   448  		return err
   449  	}
   450  
   451  	newNodes, changed, err := growIndexToAccomodateNode(res.Nodes, id, d.logger)
   452  	if err != nil {
   453  		return err
   454  	}
   455  
   456  	if changed {
   457  		res.Nodes = newNodes
   458  	}
   459  
   460  	res.Nodes[id] = nil
   461  	return nil
   462  }
   463  
   464  func (d *Deserializer) ReadTileEncoder(r io.Reader, res *DeserializationResult, i uint16) (compressionhelpers.PQEncoder, error) {
   465  	bins, err := d.readFloat64(r)
   466  	if err != nil {
   467  		return nil, err
   468  	}
   469  	mean, err := d.readFloat64(r)
   470  	if err != nil {
   471  		return nil, err
   472  	}
   473  	stdDev, err := d.readFloat64(r)
   474  	if err != nil {
   475  		return nil, err
   476  	}
   477  	size, err := d.readFloat64(r)
   478  	if err != nil {
   479  		return nil, err
   480  	}
   481  	s1, err := d.readFloat64(r)
   482  	if err != nil {
   483  		return nil, err
   484  	}
   485  	s2, err := d.readFloat64(r)
   486  	if err != nil {
   487  		return nil, err
   488  	}
   489  	segment, err := d.readUint16(r)
   490  	if err != nil {
   491  		return nil, err
   492  	}
   493  	encDistribution, err := d.readByte(r)
   494  	if err != nil {
   495  		return nil, err
   496  	}
   497  	return compressionhelpers.RestoreTileEncoder(bins, mean, stdDev, size, s1, s2, segment, encDistribution), nil
   498  }
   499  
   500  func (d *Deserializer) ReadKMeansEncoder(r io.Reader, res *DeserializationResult, i uint16) (compressionhelpers.PQEncoder, error) {
   501  	ds := int(res.PQData.Dimensions / res.PQData.M)
   502  	centers := make([][]float32, 0, res.PQData.Ks)
   503  	for k := uint16(0); k < res.PQData.Ks; k++ {
   504  		center := make([]float32, 0, ds)
   505  		for i := 0; i < ds; i++ {
   506  			c, err := d.readFloat32(r)
   507  			if err != nil {
   508  				return nil, err
   509  			}
   510  			center = append(center, c)
   511  		}
   512  		centers = append(centers, center)
   513  	}
   514  	kms := compressionhelpers.NewKMeansWithCenters(
   515  		int(res.PQData.Ks),
   516  		ds,
   517  		int(i),
   518  		centers,
   519  	)
   520  	return kms, nil
   521  }
   522  
   523  func (d *Deserializer) ReadPQ(r io.Reader, res *DeserializationResult) error {
   524  	dims, err := d.readUint16(r)
   525  	if err != nil {
   526  		return err
   527  	}
   528  	enc, err := d.readByte(r)
   529  	if err != nil {
   530  		return err
   531  	}
   532  	ks, err := d.readUint16(r)
   533  	if err != nil {
   534  		return err
   535  	}
   536  	m, err := d.readUint16(r)
   537  	if err != nil {
   538  		return err
   539  	}
   540  	dist, err := d.readByte(r)
   541  	if err != nil {
   542  		return err
   543  	}
   544  	useBitsEncoding, err := d.readByte(r)
   545  	if err != nil {
   546  		return err
   547  	}
   548  	encoder := compressionhelpers.Encoder(enc)
   549  	res.PQData = compressionhelpers.PQData{
   550  		Dimensions:          dims,
   551  		EncoderType:         encoder,
   552  		Ks:                  ks,
   553  		M:                   m,
   554  		EncoderDistribution: byte(dist),
   555  		UseBitsEncoding:     useBitsEncoding != 0,
   556  	}
   557  	var encoderReader func(io.Reader, *DeserializationResult, uint16) (compressionhelpers.PQEncoder, error)
   558  	switch encoder {
   559  	case compressionhelpers.UseTileEncoder:
   560  		encoderReader = d.ReadTileEncoder
   561  	case compressionhelpers.UseKMeansEncoder:
   562  		encoderReader = d.ReadKMeansEncoder
   563  	default:
   564  		return errors.New("Unsuported encoder type")
   565  	}
   566  	for i := uint16(0); i < m; i++ {
   567  		encoder, err := encoderReader(r, res, i)
   568  		if err != nil {
   569  			return err
   570  		}
   571  		res.PQData.Encoders = append(res.PQData.Encoders, encoder)
   572  	}
   573  	res.Compressed = true
   574  
   575  	return nil
   576  }
   577  
   578  func (d *Deserializer) readUint64(r io.Reader) (uint64, error) {
   579  	var value uint64
   580  	d.resetResusableBuffer(8)
   581  	_, err := io.ReadFull(r, d.reusableBuffer)
   582  	if err != nil {
   583  		return 0, errors.Wrap(err, "failed to read uint64")
   584  	}
   585  
   586  	value = binary.LittleEndian.Uint64(d.reusableBuffer)
   587  
   588  	return value, nil
   589  }
   590  
   591  func (d *Deserializer) readFloat64(r io.Reader) (float64, error) {
   592  	var value float64
   593  	d.resetResusableBuffer(8)
   594  	_, err := io.ReadFull(r, d.reusableBuffer)
   595  	if err != nil {
   596  		return 0, errors.Wrap(err, "failed to read float64")
   597  	}
   598  
   599  	bits := binary.LittleEndian.Uint64(d.reusableBuffer)
   600  	value = math.Float64frombits(bits)
   601  
   602  	return value, nil
   603  }
   604  
   605  func (d *Deserializer) readFloat32(r io.Reader) (float32, error) {
   606  	var value float32
   607  	d.resetResusableBuffer(4)
   608  	_, err := io.ReadFull(r, d.reusableBuffer)
   609  	if err != nil {
   610  		return 0, errors.Wrap(err, "failed to read float32")
   611  	}
   612  
   613  	bits := binary.LittleEndian.Uint32(d.reusableBuffer)
   614  	value = math.Float32frombits(bits)
   615  
   616  	return value, nil
   617  }
   618  
   619  func (d *Deserializer) readUint16(r io.Reader) (uint16, error) {
   620  	var value uint16
   621  	d.resetResusableBuffer(2)
   622  	_, err := io.ReadFull(r, d.reusableBuffer)
   623  	if err != nil {
   624  		return 0, errors.Wrap(err, "failed to read uint16")
   625  	}
   626  
   627  	value = binary.LittleEndian.Uint16(d.reusableBuffer)
   628  
   629  	return value, nil
   630  }
   631  
   632  func (d *Deserializer) readByte(r io.Reader) (byte, error) {
   633  	d.resetResusableBuffer(1)
   634  	_, err := io.ReadFull(r, d.reusableBuffer)
   635  	if err != nil {
   636  		return 0, errors.Wrap(err, "failed to read byte")
   637  	}
   638  
   639  	return d.reusableBuffer[0], nil
   640  }
   641  
   642  func (d *Deserializer) ReadCommitType(r io.Reader) (HnswCommitType, error) {
   643  	d.resetResusableBuffer(1)
   644  	if _, err := io.ReadFull(r, d.reusableBuffer); err != nil {
   645  		return 0, errors.Wrap(err, "failed to read commit type")
   646  	}
   647  
   648  	return HnswCommitType(d.reusableBuffer[0]), nil
   649  }
   650  
   651  func (d *Deserializer) readUint64Slice(r io.Reader, length int) ([]uint64, error) {
   652  	d.resetResusableBuffer(length * 8)
   653  	d.resetReusableConnectionsSlice(length)
   654  	_, err := io.ReadFull(r, d.reusableBuffer)
   655  	if err != nil {
   656  		return nil, errors.Wrap(err, "failed to read uint64 slice")
   657  	}
   658  
   659  	for i := range d.reusableConnectionsSlice {
   660  		d.reusableConnectionsSlice[i] = binary.LittleEndian.Uint64(d.reusableBuffer[i*8 : (i+1)*8])
   661  	}
   662  
   663  	return d.reusableConnectionsSlice, nil
   664  }
   665  
   666  // If the connections array is to small to contain the current target-levelit
   667  // will be grown. Otherwise, nothing happens.
   668  func maybeGrowConnectionsForLevel(connsPtr *[][]uint64, level uint16) {
   669  	conns := *connsPtr
   670  	if len(conns) <= int(level) {
   671  		// we need to grow the connections slice
   672  		newConns := make([][]uint64, level+1)
   673  		copy(newConns, conns)
   674  		*connsPtr = newConns
   675  	}
   676  }