github.com/weaviate/weaviate@v1.24.6/modules/ref2vec-centroid/vectorizer/vectorizer.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package vectorizer
    13  
import (
	"context"
	"fmt"
	"sort"

	"github.com/go-openapi/strfmt"
	"github.com/weaviate/weaviate/entities/additional"
	"github.com/weaviate/weaviate/entities/models"
	"github.com/weaviate/weaviate/entities/modulecapabilities"
	"github.com/weaviate/weaviate/entities/moduletools"
	"github.com/weaviate/weaviate/entities/schema/crossref"
	"github.com/weaviate/weaviate/entities/search"
	"github.com/weaviate/weaviate/modules/ref2vec-centroid/config"
)
    27  
    28  type calcFn func(vecs ...[]float32) ([]float32, error)
    29  
    30  type Vectorizer struct {
    31  	config       *config.Config
    32  	calcFn       calcFn
    33  	findObjectFn modulecapabilities.FindObjectFn
    34  }
    35  
    36  func New(cfg moduletools.ClassConfig, findFn modulecapabilities.FindObjectFn) *Vectorizer {
    37  	v := &Vectorizer{
    38  		config:       config.New(cfg),
    39  		findObjectFn: findFn,
    40  	}
    41  
    42  	switch v.config.CalculationMethod() {
    43  	case config.MethodMean:
    44  		v.calcFn = calculateMean
    45  	default:
    46  		v.calcFn = calculateMean
    47  	}
    48  
    49  	return v
    50  }
    51  
    52  func (v *Vectorizer) Object(ctx context.Context, obj *models.Object) ([]float32, error) {
    53  	props := v.config.ReferenceProperties()
    54  
    55  	refVecs, err := v.referenceVectorSearch(ctx, obj, props)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	if len(refVecs) == 0 {
    61  		obj.Vector = nil
    62  		return nil, nil
    63  	}
    64  
    65  	vec, err := v.calcFn(refVecs...)
    66  	if err != nil {
    67  		return nil, fmt.Errorf("calculate vector: %w", err)
    68  	}
    69  
    70  	return vec, nil
    71  }
    72  
    73  func (v *Vectorizer) referenceVectorSearch(ctx context.Context,
    74  	obj *models.Object, refProps map[string]struct{},
    75  ) ([][]float32, error) {
    76  	var refVecs [][]float32
    77  	props := obj.Properties.(map[string]interface{})
    78  
    79  	// use the ids from parent's beacons to find the referenced objects
    80  	beacons := beaconsForVectorization(props, refProps)
    81  	for _, beacon := range beacons {
    82  		res, err := v.findReferenceObject(ctx, beacon, obj.Tenant)
    83  		if err != nil {
    84  			return nil, err
    85  		}
    86  
    87  		// if the ref'd object has a vector, we grab it.
    88  		// these will be used to compute the parent's
    89  		// vector eventually
    90  		if res.Vector != nil {
    91  			refVecs = append(refVecs, res.Vector)
    92  		}
    93  	}
    94  
    95  	return refVecs, nil
    96  }
    97  
    98  func (v *Vectorizer) findReferenceObject(ctx context.Context, beacon strfmt.URI, tenant string) (res *search.Result, err error) {
    99  	ref, err := crossref.Parse(beacon.String())
   100  	if err != nil {
   101  		return nil, fmt.Errorf("parse beacon %q: %w", beacon, err)
   102  	}
   103  
   104  	res, err = v.findObjectFn(ctx, ref.Class, ref.TargetID,
   105  		search.SelectProperties{}, additional.Properties{}, tenant)
   106  	if err != nil || res == nil {
   107  		if err == nil {
   108  			err = fmt.Errorf("not found")
   109  		}
   110  		err = fmt.Errorf("find object with beacon %q': %w", beacon, err)
   111  	}
   112  	return
   113  }
   114  
   115  func beaconsForVectorization(allProps map[string]interface{},
   116  	targetRefProps map[string]struct{},
   117  ) []strfmt.URI {
   118  	var beacons []strfmt.URI
   119  
   120  	// add any refs that were supplied as a part of the parent
   121  	// object, like when caller is AddObject/UpdateObject
   122  	for prop, val := range allProps {
   123  		if _, ok := targetRefProps[prop]; ok {
   124  			switch refs := val.(type) {
   125  			case []interface{}:
   126  				// due to the fix introduced in https://github.com/weaviate/weaviate/pull/2320,
   127  				// MultipleRef's can appear as empty []interface{} when no actual refs are provided for
   128  				// an object's reference property.
   129  				//
   130  				// if we encounter []interface{}, assume it indicates an empty ref prop, and skip it.
   131  				continue
   132  			case models.MultipleRef:
   133  				for _, ref := range refs {
   134  					beacons = append(beacons, ref.Beacon)
   135  				}
   136  			}
   137  		}
   138  	}
   139  
   140  	return beacons
   141  }