github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/additional/sempath/builder.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package sempath
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"math"
    18  	"sort"
    19  	"time"
    20  
    21  	"github.com/danaugrs/go-tsne/tsne"
    22  	"github.com/pkg/errors"
    23  	"github.com/tailor-inc/graphql/language/ast"
    24  	"github.com/weaviate/weaviate/entities/models"
    25  	"github.com/weaviate/weaviate/entities/moduletools"
    26  	"github.com/weaviate/weaviate/entities/search"
    27  	txt2vecmodels "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/models"
    28  	"gonum.org/v1/gonum/mat"
    29  )
    30  
    31  func New(c11y Remote) *PathBuilder {
    32  	return &PathBuilder{
    33  		fixedSeed: time.Now().UnixNano(),
    34  		c11y:      c11y,
    35  	}
    36  }
    37  
    38  type PathBuilder struct {
    39  	fixedSeed int64
    40  	c11y      Remote
    41  }
    42  
    43  type Remote interface {
    44  	MultiNearestWordsByVector(ctx context.Context, vectors [][]float32, k, n int) ([]*txt2vecmodels.NearestNeighbors, error)
    45  }
    46  
    47  func (pb *PathBuilder) AdditionalPropertyDefaultValue() interface{} {
    48  	return &Params{}
    49  }
    50  
    51  func (pb *PathBuilder) AdditionalPropertyFn(ctx context.Context,
    52  	in []search.Result, params interface{}, limit *int,
    53  	argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig,
    54  ) ([]search.Result, error) {
    55  	if parameters, ok := params.(*Params); ok {
    56  		return pb.CalculatePath(in, parameters)
    57  	}
    58  	return nil, errors.New("unknown params")
    59  }
    60  
    61  func (pb *PathBuilder) ExtractAdditionalFn(param []*ast.Argument) interface{} {
    62  	return &Params{}
    63  }
    64  
    65  func (pb *PathBuilder) CalculatePath(in []search.Result, params *Params) ([]search.Result, error) {
    66  	if len(in) == 0 {
    67  		return nil, nil
    68  	}
    69  
    70  	if params == nil {
    71  		return nil, fmt.Errorf("no params provided")
    72  	}
    73  
    74  	dims := len(in[0].Vector)
    75  	if err := params.SetDefaultsAndValidate(len(in), dims); err != nil {
    76  		return nil, errors.Wrap(err, "invalid params")
    77  	}
    78  
    79  	searchNeighbors, err := pb.addSearchNeighbors(params)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	for i, obj := range in {
    85  		path, err := pb.calculatePathPerObject(obj, in, params, searchNeighbors)
    86  		if err != nil {
    87  			return nil, fmt.Errorf("object %d: %v", i, err)
    88  		}
    89  
    90  		if in[i].AdditionalProperties == nil {
    91  			in[i].AdditionalProperties = models.AdditionalProperties{}
    92  		}
    93  
    94  		in[i].AdditionalProperties["semanticPath"] = path
    95  	}
    96  
    97  	return in, nil
    98  }
    99  
   100  func (pb *PathBuilder) calculatePathPerObject(obj search.Result, allObjects []search.Result, params *Params,
   101  	searchNeighbors []*txt2vecmodels.NearestNeighbor,
   102  ) (*txt2vecmodels.SemanticPath, error) {
   103  	dims := len(obj.Vector)
   104  	matrix, neighbors, err := pb.vectorsToMatrix(obj, allObjects, dims, params, searchNeighbors)
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  
   109  	inputRows := matrix.RawMatrix().Rows
   110  	t := tsne.NewTSNE(2, float64(inputRows/2), 100, 100, false)
   111  	res := t.EmbedData(matrix, nil)
   112  	rows, cols := res.Dims()
   113  	if rows != inputRows {
   114  		return nil, fmt.Errorf("have different output results than input %d != %d", inputRows, rows)
   115  	}
   116  
   117  	// create an explicit copy of the neighbors, so we don't mutate them.
   118  	// Otherwise the 2nd round will have been influenced by the first
   119  	projectedNeighbors := copyNeighbors(neighbors)
   120  	var projectedSearchVector []float32
   121  	var projectedTargetVector []float32
   122  	for i := 0; i < rows; i++ {
   123  		vector := make([]float32, cols)
   124  		for j := range vector {
   125  			vector[j] = float32(res.At(i, j))
   126  		}
   127  		if i == 0 { // the input object
   128  			projectedTargetVector = vector
   129  		} else if i < 1+len(neighbors) {
   130  			// these must be neighbor props
   131  			projectedNeighbors[i-1].Vector = vector
   132  		} else {
   133  			// is now the very last element which is the search vector
   134  			projectedSearchVector = vector
   135  		}
   136  	}
   137  
   138  	path := pb.buildPath(projectedNeighbors, projectedSearchVector, projectedTargetVector)
   139  	return pb.addDistancesToPath(path, neighbors, params.SearchVector, obj.Vector)
   140  }
   141  
   142  func (pb *PathBuilder) addSearchNeighbors(params *Params) ([]*txt2vecmodels.NearestNeighbor, error) {
   143  	nn, err := pb.c11y.MultiNearestWordsByVector(context.TODO(), [][]float32{params.SearchVector}, 36, 50)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  
   148  	return nn[0].Neighbors, nil
   149  }
   150  
   151  // TODO: document behavior if it actually stays like this
   152  func (pb *PathBuilder) vectorsToMatrix(obj search.Result, allObjects []search.Result, dims int,
   153  	params *Params, searchNeighbors []*txt2vecmodels.NearestNeighbor,
   154  ) (*mat.Dense, []*txt2vecmodels.NearestNeighbor, error) {
   155  	items := 1 // the initial object
   156  	var neighbors []*txt2vecmodels.NearestNeighbor
   157  	neighbors = pb.extractNeighbors(allObjects)
   158  	neighbors = append(neighbors, searchNeighbors...)
   159  	neighbors = pb.removeDuplicateNeighborsAndDollarNeighbors(neighbors)
   160  	items += len(neighbors) + 1 // The +1 is for the search vector which we append last
   161  
   162  	// concat all vectors to build gonum dense matrix
   163  	mergedVectors := make([]float64, items*dims)
   164  	if l := len(obj.Vector); l != dims {
   165  		return nil, nil, fmt.Errorf("object: inconsistent vector lengths found: dimensions=%d and object=%d", dims, l)
   166  	}
   167  
   168  	for j, dim := range obj.Vector {
   169  		mergedVectors[j] = float64(dim)
   170  	}
   171  
   172  	withoutNeighbors := 1 * dims
   173  	for i, neighbor := range neighbors {
   174  		neighborVector := neighbor.Vector
   175  
   176  		if l := len(neighborVector); l != dims {
   177  			return nil, nil, fmt.Errorf("neighbor: inconsistent vector lengths found: dimensions=%d and object=%d", dims, l)
   178  		}
   179  
   180  		for j, dim := range neighborVector {
   181  			mergedVectors[withoutNeighbors+i*dims+j] = float64(dim)
   182  		}
   183  	}
   184  
   185  	for i, dim := range params.SearchVector {
   186  		mergedVectors[len(mergedVectors)-dims+i] = float64(dim)
   187  	}
   188  
   189  	return mat.NewDense(items, dims, mergedVectors), neighbors, nil
   190  }
   191  
   192  func (pb *PathBuilder) extractNeighbors(in []search.Result) []*txt2vecmodels.NearestNeighbor {
   193  	var out []*txt2vecmodels.NearestNeighbor
   194  
   195  	for _, obj := range in {
   196  		if obj.AdditionalProperties == nil || obj.AdditionalProperties["nearestNeighbors"] == nil {
   197  			continue
   198  		}
   199  
   200  		if neighbors, ok := obj.AdditionalProperties["nearestNeighbors"]; ok {
   201  			if nearestNeighbors, ok := neighbors.(*txt2vecmodels.NearestNeighbors); ok {
   202  				out = append(out, nearestNeighbors.Neighbors...)
   203  			}
   204  		}
   205  	}
   206  
   207  	return out
   208  }
   209  
   210  func (pb *PathBuilder) removeDuplicateNeighborsAndDollarNeighbors(in []*txt2vecmodels.NearestNeighbor) []*txt2vecmodels.NearestNeighbor {
   211  	seen := map[string]struct{}{}
   212  	out := make([]*txt2vecmodels.NearestNeighbor, len(in))
   213  
   214  	i := 0
   215  	for _, candidate := range in {
   216  		if _, ok := seen[candidate.Concept]; ok {
   217  			continue
   218  		}
   219  
   220  		if candidate.Concept[0] == '$' {
   221  			continue
   222  		}
   223  
   224  		out[i] = candidate
   225  		i++
   226  		seen[candidate.Concept] = struct{}{}
   227  	}
   228  
   229  	return out[:i]
   230  }
   231  
   232  func (pb *PathBuilder) buildPath(neighbors []*txt2vecmodels.NearestNeighbor, searchVector []float32,
   233  	target []float32,
   234  ) *txt2vecmodels.SemanticPath {
   235  	var path []*txt2vecmodels.SemanticPathElement
   236  
   237  	minDist := float32(math.MaxFloat32)
   238  
   239  	current := searchVector // initial search point
   240  
   241  	for {
   242  		nn := pb.nearestNeighbors(current, neighbors, 10)
   243  		nn = pb.discardFurtherThan(nn, minDist, target)
   244  		if len(nn) == 0 {
   245  			break
   246  		}
   247  		nn = pb.nearestNeighbors(current, nn, 1)
   248  		current = nn[0].Vector
   249  		minDist = pb.distance(current, target)
   250  
   251  		path = append(path, &txt2vecmodels.SemanticPathElement{
   252  			Concept: nn[0].Concept,
   253  		})
   254  	}
   255  
   256  	return &txt2vecmodels.SemanticPath{
   257  		Path: path,
   258  	}
   259  }
   260  
   261  func (pb *PathBuilder) nearestNeighbors(search []float32, candidates []*txt2vecmodels.NearestNeighbor, length int) []*txt2vecmodels.NearestNeighbor {
   262  	sort.Slice(candidates, func(a, b int) bool {
   263  		return pb.distance(candidates[a].Vector, search) < pb.distance(candidates[b].Vector, search)
   264  	})
   265  	return candidates[:length]
   266  }
   267  
   268  func (pb *PathBuilder) distance(a, b []float32) float32 {
   269  	var sums float32
   270  	for i := range a {
   271  		sums += (a[i] - b[i]) * (a[i] - b[i])
   272  	}
   273  
   274  	return float32(math.Sqrt(float64(sums)))
   275  }
   276  
   277  func (pb *PathBuilder) discardFurtherThan(candidates []*txt2vecmodels.NearestNeighbor, threshold float32, target []float32) []*txt2vecmodels.NearestNeighbor {
   278  	out := make([]*txt2vecmodels.NearestNeighbor, len(candidates))
   279  	i := 0
   280  	for _, c := range candidates {
   281  		if pb.distance(c.Vector, target) >= threshold {
   282  			continue
   283  		}
   284  
   285  		out[i] = c
   286  		i++
   287  	}
   288  
   289  	return out[:i]
   290  }
   291  
   292  // create an explicit deep copy that does not keep any references
   293  func copyNeighbors(in []*txt2vecmodels.NearestNeighbor) []*txt2vecmodels.NearestNeighbor {
   294  	out := make([]*txt2vecmodels.NearestNeighbor, len(in))
   295  	for i, n := range in {
   296  		out[i] = &txt2vecmodels.NearestNeighbor{
   297  			Concept:  n.Concept,
   298  			Distance: n.Distance,
   299  			Vector:   n.Vector,
   300  		}
   301  	}
   302  
   303  	return out
   304  }
   305  
   306  func (pb *PathBuilder) addDistancesToPath(path *txt2vecmodels.SemanticPath, neighbors []*txt2vecmodels.NearestNeighbor,
   307  	searchVector, targetVector []float32,
   308  ) (*txt2vecmodels.SemanticPath, error) {
   309  	for i, elem := range path.Path {
   310  		vec, ok := neighborVecByConcept(neighbors, elem.Concept)
   311  		if !ok {
   312  			return nil, fmt.Errorf("no vector present for concept: %s", elem.Concept)
   313  		}
   314  
   315  		if i != 0 {
   316  			// include previous
   317  			previousVec, ok := neighborVecByConcept(neighbors, path.Path[i-1].Concept)
   318  			if !ok {
   319  				return nil, fmt.Errorf("no vector present for previous concept: %s", path.Path[i-1].Concept)
   320  			}
   321  
   322  			d, err := cosineDist(vec, previousVec)
   323  			if err != nil {
   324  				return nil, errors.Wrap(err, "calculate distance between current path and previous element")
   325  			}
   326  
   327  			path.Path[i].DistanceToPrevious = &d
   328  		}
   329  
   330  		// target
   331  		d, err := cosineDist(vec, targetVector)
   332  		if err != nil {
   333  			return nil, errors.Wrap(err, "calculate distance between current path and result element")
   334  		}
   335  		path.Path[i].DistanceToResult = d
   336  
   337  		// query
   338  		d, err = cosineDist(vec, searchVector)
   339  		if err != nil {
   340  			return nil, errors.Wrap(err, "calculate distance between current path and query element")
   341  		}
   342  		path.Path[i].DistanceToQuery = d
   343  
   344  		if i != len(path.Path)-1 {
   345  			// include next
   346  			nextVec, ok := neighborVecByConcept(neighbors, path.Path[i+1].Concept)
   347  			if !ok {
   348  				return nil, fmt.Errorf("no vector present for next concept: %s", path.Path[i+1].Concept)
   349  			}
   350  
   351  			d, err := cosineDist(vec, nextVec)
   352  			if err != nil {
   353  				return nil, errors.Wrap(err, "calculate distance between current path and next element")
   354  			}
   355  
   356  			path.Path[i].DistanceToNext = &d
   357  		}
   358  	}
   359  
   360  	return path, nil
   361  }
   362  
   363  func neighborVecByConcept(neighbors []*txt2vecmodels.NearestNeighbor, concept string) ([]float32, bool) {
   364  	for _, n := range neighbors {
   365  		if n.Concept == concept {
   366  			return n.Vector, true
   367  		}
   368  	}
   369  
   370  	return nil, false
   371  }
   372  
   // cosineSim returns the cosine similarity dot(a, b) / (|a| * |b|) of a and b.
// It is computed in float64 and returned as float32.
//
// It returns an error if the vectors differ in length, or if either vector
// has zero magnitude (this includes empty vectors). In that case the
// similarity is undefined and would otherwise come out as NaN.
func cosineSim(a, b []float32) (float32, error) {
	if len(a) != len(b) {
		return 0, fmt.Errorf("vectors have different dimensions")
	}

	var sumProduct, sumASquare, sumBSquare float64
	for i := range a {
		// Widen before multiplying: a float32 product can overflow to Inf
		// (or lose precision) before it ever reaches the float64 sum.
		x, y := float64(a[i]), float64(b[i])
		sumProduct += x * y
		sumASquare += x * x
		sumBSquare += y * y
	}

	if sumASquare == 0 || sumBSquare == 0 {
		return 0, fmt.Errorf("cannot compute cosine similarity of a zero-magnitude vector")
	}

	return float32(sumProduct / (math.Sqrt(sumASquare) * math.Sqrt(sumBSquare))), nil
}
   392  
   393  func cosineDist(a, b []float32) (float32, error) {
   394  	sim, err := cosineSim(a, b)
   395  	if err != nil {
   396  		return 0, err
   397  	}
   398  
   399  	return 1 - sim, nil
   400  }