github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-palm/vectorizer/vectorizer.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package vectorizer 13 14 import ( 15 "context" 16 17 "github.com/pkg/errors" 18 19 "github.com/go-openapi/strfmt" 20 "github.com/weaviate/weaviate/entities/models" 21 "github.com/weaviate/weaviate/entities/moduletools" 22 "github.com/weaviate/weaviate/modules/multi2vec-palm/ent" 23 objectsvectorizer "github.com/weaviate/weaviate/usecases/modulecomponents/vectorizer" 24 libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer" 25 ) 26 27 type Vectorizer struct { 28 client Client 29 objectVectorizer *objectsvectorizer.ObjectVectorizer 30 } 31 32 func New(client Client) *Vectorizer { 33 return &Vectorizer{ 34 client: client, 35 objectVectorizer: objectsvectorizer.New(), 36 } 37 } 38 39 type Client interface { 40 Vectorize(ctx context.Context, 41 texts, images, videos []string, config ent.VectorizationConfig) (*ent.VectorizationResult, error) 42 } 43 44 type ClassSettings interface { 45 ImageField(property string) bool 46 ImageFieldsWeights() ([]float32, error) 47 TextField(property string) bool 48 TextFieldsWeights() ([]float32, error) 49 VideoField(property string) bool 50 VideoFieldsWeights() ([]float32, error) 51 } 52 53 func (v *Vectorizer) Object(ctx context.Context, object *models.Object, 54 comp moduletools.VectorizablePropsComparator, cfg moduletools.ClassConfig, 55 ) ([]float32, models.AdditionalProperties, error) { 56 vec, err := v.object(ctx, object.ID, comp, cfg) 57 return vec, nil, err 58 } 59 60 func (v *Vectorizer) VectorizeImage(ctx context.Context, id, image string, cfg moduletools.ClassConfig) ([]float32, error) { 61 res, err := v.client.Vectorize(ctx, nil, []string{image}, nil, v.getVectorizationConfig(cfg)) 62 if err != nil { 63 return nil, err 64 } 65 if len(res.ImageVectors) != 1 { 66 return nil, errors.New("empty vector") 67 } 68 69 return res.ImageVectors[0], nil 70 } 71 72 func (v *Vectorizer) VectorizeVideo(ctx context.Context, 73 video string, cfg moduletools.ClassConfig, 74 ) ([]float32, error) { 75 res, err := v.client.Vectorize(ctx, nil, nil, []string{video}, v.getVectorizationConfig(cfg)) 76 if err != nil { 77 return nil, err 78 } 79 if len(res.VideoVectors) != 1 { 80 return nil, errors.New("empty vector") 81 } 82 83 return res.VideoVectors[0], nil 84 } 85 86 func (v *Vectorizer) object(ctx context.Context, id strfmt.UUID, 87 comp moduletools.VectorizablePropsComparator, cfg moduletools.ClassConfig, 88 ) ([]float32, error) { 89 ichek := NewClassSettings(cfg) 90 prevVector := comp.PrevVector() 91 if cfg.TargetVector() != "" { 92 prevVector = comp.PrevVectorForName(cfg.TargetVector()) 93 } 94 95 vectorize := prevVector == nil 96 97 // vectorize image and text 98 texts := []string{} 99 images := []string{} 100 videos := []string{} 101 102 it := comp.PropsIterator() 103 for propName, propValue, ok := it.Next(); ok; propName, propValue, ok = it.Next() { 104 switch typed := propValue.(type) { 105 case string: 106 if ichek.ImageField(propName) { 107 vectorize = vectorize || comp.IsChanged(propName) 108 images = append(images, typed) 109 } 110 if ichek.TextField(propName) { 111 vectorize = vectorize || comp.IsChanged(propName) 112 texts = append(texts, typed) 113 } 114 if ichek.VideoField(propName) { 115 vectorize = vectorize || comp.IsChanged(propName) 116 videos = append(videos, typed) 117 } 118 119 case []string: 120 if ichek.TextField(propName) { 121 vectorize = vectorize || comp.IsChanged(propName) 122 texts = append(texts, typed...) 123 } 124 125 case nil: 126 if ichek.ImageField(propName) || ichek.TextField(propName) || ichek.VideoField(propName) { 127 vectorize = vectorize || comp.IsChanged(propName) 128 } 129 } 130 } 131 132 // no property was changed, old vector can be used 133 if !vectorize { 134 return prevVector, nil 135 } 136 137 vectors := [][]float32{} 138 if len(texts) > 0 || len(images) > 0 || len(videos) > 0 { 139 res, err := v.client.Vectorize(ctx, texts, images, videos, v.getVectorizationConfig(cfg)) 140 if err != nil { 141 return nil, err 142 } 143 vectors = append(vectors, res.TextVectors...) 144 vectors = append(vectors, res.ImageVectors...) 145 vectors = append(vectors, res.VideoVectors...) 146 } 147 weights, err := v.getWeights(ichek) 148 if err != nil { 149 return nil, err 150 } 151 152 return libvectorizer.CombineVectorsWithWeights(vectors, weights), nil 153 } 154 155 func (v *Vectorizer) getWeights(ichek ClassSettings) ([]float32, error) { 156 weights := []float32{} 157 textFieldsWeights, err := ichek.TextFieldsWeights() 158 if err != nil { 159 return nil, err 160 } 161 imageFieldsWeights, err := ichek.ImageFieldsWeights() 162 if err != nil { 163 return nil, err 164 } 165 videoFieldsWeights, err := ichek.VideoFieldsWeights() 166 if err != nil { 167 return nil, err 168 } 169 170 weights = append(weights, textFieldsWeights...) 171 weights = append(weights, imageFieldsWeights...) 172 weights = append(weights, videoFieldsWeights...) 173 174 normalizedWeights := v.normalizeWeights(weights) 175 176 return normalizedWeights, nil 177 } 178 179 func (v *Vectorizer) normalizeWeights(weights []float32) []float32 { 180 if len(weights) > 0 { 181 var denominator float32 182 for i := range weights { 183 denominator += weights[i] 184 } 185 normalizer := 1 / denominator 186 normalized := make([]float32, len(weights)) 187 for i := range weights { 188 normalized[i] = weights[i] * normalizer 189 } 190 return normalized 191 } 192 return nil 193 } 194 195 func (v *Vectorizer) getVectorizationConfig(cfg moduletools.ClassConfig) ent.VectorizationConfig { 196 settings := NewClassSettings(cfg) 197 return ent.VectorizationConfig{ 198 Location: settings.Location(), 199 ProjectID: settings.ProjectID(), 200 Model: settings.ModelID(), 201 Dimensions: settings.Dimensions(), 202 VideoIntervalSeconds: settings.VideoIntervalSeconds(), 203 } 204 }