github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-bind/vectorizer/vectorizer.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package vectorizer
    13  
    14  import (
    15  	"context"
    16  
    17  	"github.com/pkg/errors"
    18  
    19  	"github.com/go-openapi/strfmt"
    20  	"github.com/weaviate/weaviate/entities/models"
    21  	"github.com/weaviate/weaviate/entities/moduletools"
    22  	"github.com/weaviate/weaviate/modules/multi2vec-bind/ent"
    23  	libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer"
    24  )
    25  
    26  type Vectorizer struct {
    27  	client Client
    28  }
    29  
    30  func New(client Client) *Vectorizer {
    31  	return &Vectorizer{
    32  		client: client,
    33  	}
    34  }
    35  
    36  type Client interface {
    37  	Vectorize(ctx context.Context,
    38  		texts, images, audio, video, imu, thermal, depth []string,
    39  	) (*ent.VectorizationResult, error)
    40  }
    41  
    42  type ClassSettings interface {
    43  	ImageField(property string) bool
    44  	ImageFieldsWeights() ([]float32, error)
    45  	TextField(property string) bool
    46  	TextFieldsWeights() ([]float32, error)
    47  	AudioField(property string) bool
    48  	AudioFieldsWeights() ([]float32, error)
    49  	VideoField(property string) bool
    50  	VideoFieldsWeights() ([]float32, error)
    51  	IMUField(property string) bool
    52  	IMUFieldsWeights() ([]float32, error)
    53  	ThermalField(property string) bool
    54  	ThermalFieldsWeights() ([]float32, error)
    55  	DepthField(property string) bool
    56  	DepthFieldsWeights() ([]float32, error)
    57  }
    58  
    59  func (v *Vectorizer) Object(ctx context.Context, object *models.Object,
    60  	comp moduletools.VectorizablePropsComparator, cfg moduletools.ClassConfig,
    61  ) ([]float32, models.AdditionalProperties, error) {
    62  	vec, err := v.object(ctx, object.ID, comp, cfg)
    63  	return vec, nil, err
    64  }
    65  
    66  func (v *Vectorizer) VectorizeImage(ctx context.Context, id, image string, cfg moduletools.ClassConfig) ([]float32, error) {
    67  	res, err := v.client.Vectorize(ctx, nil, []string{image}, nil, nil, nil, nil, nil)
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  	return v.getVector(res.ImageVectors)
    72  }
    73  
    74  func (v *Vectorizer) VectorizeAudio(ctx context.Context, audio string, cfg moduletools.ClassConfig) ([]float32, error) {
    75  	res, err := v.client.Vectorize(ctx, nil, nil, []string{audio}, nil, nil, nil, nil)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	return v.getVector(res.AudioVectors)
    80  }
    81  
    82  func (v *Vectorizer) VectorizeVideo(ctx context.Context, video string, cfg moduletools.ClassConfig) ([]float32, error) {
    83  	res, err := v.client.Vectorize(ctx, nil, nil, nil, []string{video}, nil, nil, nil)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  	return v.getVector(res.VideoVectors)
    88  }
    89  
    90  func (v *Vectorizer) VectorizeIMU(ctx context.Context, imu string, cfg moduletools.ClassConfig) ([]float32, error) {
    91  	res, err := v.client.Vectorize(ctx, nil, nil, nil, nil, []string{imu}, nil, nil)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	return v.getVector(res.IMUVectors)
    96  }
    97  
    98  func (v *Vectorizer) VectorizeThermal(ctx context.Context, thermal string, cfg moduletools.ClassConfig) ([]float32, error) {
    99  	res, err := v.client.Vectorize(ctx, nil, nil, nil, nil, nil, []string{thermal}, nil)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	return v.getVector(res.ThermalVectors)
   104  }
   105  
   106  func (v *Vectorizer) VectorizeDepth(ctx context.Context, depth string, cfg moduletools.ClassConfig) ([]float32, error) {
   107  	res, err := v.client.Vectorize(ctx, nil, nil, nil, nil, nil, nil, []string{depth})
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	return v.getVector(res.DepthVectors)
   112  }
   113  
   114  func (v *Vectorizer) getVector(vectors [][]float32) ([]float32, error) {
   115  	if len(vectors) != 1 {
   116  		return nil, errors.New("empty vector")
   117  	}
   118  	return vectors[0], nil
   119  }
   120  
   121  func (v *Vectorizer) object(ctx context.Context, id strfmt.UUID,
   122  	comp moduletools.VectorizablePropsComparator, cfg moduletools.ClassConfig,
   123  ) ([]float32, error) {
   124  	icheck := NewClassSettings(cfg)
   125  	prevVector := comp.PrevVector()
   126  	if cfg.TargetVector() != "" {
   127  		prevVector = comp.PrevVectorForName(cfg.TargetVector())
   128  	}
   129  
   130  	vectorize := prevVector == nil
   131  
   132  	// vectorize image and text
   133  	var texts, images, audio, video, imu, thermal, depth []string
   134  
   135  	it := comp.PropsIterator()
   136  	for propName, propValue, ok := it.Next(); ok; propName, propValue, ok = it.Next() {
   137  		switch typed := propValue.(type) {
   138  		case string:
   139  			if icheck.ImageField(propName) {
   140  				vectorize = vectorize || comp.IsChanged(propName)
   141  				images = append(images, typed)
   142  			}
   143  			if icheck.TextField(propName) {
   144  				vectorize = vectorize || comp.IsChanged(propName)
   145  				texts = append(texts, typed)
   146  			}
   147  			if icheck.AudioField(propName) {
   148  				vectorize = vectorize || comp.IsChanged(propName)
   149  				audio = append(audio, typed)
   150  			}
   151  			if icheck.VideoField(propName) {
   152  				vectorize = vectorize || comp.IsChanged(propName)
   153  				video = append(video, typed)
   154  			}
   155  			if icheck.IMUField(propName) {
   156  				vectorize = vectorize || comp.IsChanged(propName)
   157  				imu = append(imu, typed)
   158  			}
   159  			if icheck.ThermalField(propName) {
   160  				vectorize = vectorize || comp.IsChanged(propName)
   161  				thermal = append(thermal, typed)
   162  			}
   163  			if icheck.DepthField(propName) {
   164  				vectorize = vectorize || comp.IsChanged(propName)
   165  				depth = append(depth, typed)
   166  			}
   167  
   168  		case []string:
   169  			if icheck.TextField(propName) {
   170  				vectorize = vectorize || comp.IsChanged(propName)
   171  				texts = append(texts, typed...)
   172  			}
   173  
   174  		case nil:
   175  			if icheck.ImageField(propName) || icheck.TextField(propName) ||
   176  				icheck.AudioField(propName) || icheck.VideoField(propName) ||
   177  				icheck.IMUField(propName) || icheck.ThermalField(propName) ||
   178  				icheck.DepthField(propName) {
   179  				vectorize = vectorize || comp.IsChanged(propName)
   180  			}
   181  		}
   182  	}
   183  
   184  	// no property was changed, old vector can be used
   185  	if !vectorize {
   186  		return prevVector, nil
   187  	}
   188  
   189  	vectors := [][]float32{}
   190  	if len(texts) > 0 || len(images) > 0 || len(audio) > 0 || len(video) > 0 ||
   191  		len(imu) > 0 || len(thermal) > 0 || len(depth) > 0 {
   192  		res, err := v.client.Vectorize(ctx, texts, images, audio, video, imu, thermal, depth)
   193  		if err != nil {
   194  			return nil, err
   195  		}
   196  		vectors = append(vectors, res.TextVectors...)
   197  		vectors = append(vectors, res.ImageVectors...)
   198  		vectors = append(vectors, res.AudioVectors...)
   199  		vectors = append(vectors, res.VideoVectors...)
   200  		vectors = append(vectors, res.IMUVectors...)
   201  		vectors = append(vectors, res.ThermalVectors...)
   202  		vectors = append(vectors, res.DepthVectors...)
   203  	}
   204  	weights, err := v.getWeights(icheck)
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  
   209  	return libvectorizer.CombineVectorsWithWeights(vectors, weights), nil
   210  }
   211  
   212  func (v *Vectorizer) getWeights(ichek ClassSettings) ([]float32, error) {
   213  	weights := []float32{}
   214  	textFieldsWeights, err := ichek.TextFieldsWeights()
   215  	if err != nil {
   216  		return nil, err
   217  	}
   218  	imageFieldsWeights, err := ichek.ImageFieldsWeights()
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  	audioFieldsWeights, err := ichek.AudioFieldsWeights()
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  	videoFieldsWeights, err := ichek.VideoFieldsWeights()
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  	imuFieldsWeights, err := ichek.IMUFieldsWeights()
   231  	if err != nil {
   232  		return nil, err
   233  	}
   234  	thermalFieldsWeights, err := ichek.ThermalFieldsWeights()
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  	depthFieldsWeights, err := ichek.DepthFieldsWeights()
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	weights = append(weights, textFieldsWeights...)
   244  	weights = append(weights, imageFieldsWeights...)
   245  	weights = append(weights, audioFieldsWeights...)
   246  	weights = append(weights, videoFieldsWeights...)
   247  	weights = append(weights, imuFieldsWeights...)
   248  	weights = append(weights, thermalFieldsWeights...)
   249  	weights = append(weights, depthFieldsWeights...)
   250  
   251  	normalizedWeights := v.normalizeWeights(weights)
   252  
   253  	return normalizedWeights, nil
   254  }
   255  
   256  func (v *Vectorizer) normalizeWeights(weights []float32) []float32 {
   257  	if len(weights) > 0 {
   258  		var denominator float32
   259  		for i := range weights {
   260  			denominator += weights[i]
   261  		}
   262  		normalizer := 1 / denominator
   263  		normalized := make([]float32, len(weights))
   264  		for i := range weights {
   265  			normalized[i] = weights[i] * normalizer
   266  		}
   267  		return normalized
   268  	}
   269  	return nil
   270  }