github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/client/contextionary.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package client 13 14 import ( 15 "context" 16 "fmt" 17 "strings" 18 "time" 19 20 "github.com/pkg/errors" 21 "github.com/sirupsen/logrus" 22 pb "github.com/weaviate/contextionary/contextionary" 23 "github.com/weaviate/weaviate/entities/models" 24 txt2vecmodels "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/models" 25 "github.com/weaviate/weaviate/modules/text2vec-contextionary/vectorizer" 26 "github.com/weaviate/weaviate/usecases/traverser" 27 "google.golang.org/grpc" 28 "google.golang.org/grpc/codes" 29 "google.golang.org/grpc/credentials/insecure" 30 "google.golang.org/grpc/status" 31 ) 32 33 const ModelUncontactable = "module uncontactable" 34 35 // Client establishes a gRPC connection to a remote contextionary service 36 type Client struct { 37 grpcClient pb.ContextionaryClient 38 logger logrus.FieldLogger 39 } 40 41 // NewClient from gRPC discovery url to connect to a remote contextionary service 42 func NewClient(uri string, logger logrus.FieldLogger) (*Client, error) { 43 conn, err := grpc.Dial(uri, 44 grpc.WithTransportCredentials(insecure.NewCredentials()), 45 grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(1024*1024*48))) 46 if err != nil { 47 return nil, fmt.Errorf("couldn't connect to remote contextionary gRPC server: %s", err) 48 } 49 50 client := pb.NewContextionaryClient(conn) 51 return &Client{ 52 grpcClient: client, 53 logger: logger, 54 }, nil 55 } 56 57 // IsStopWord returns true if the given word is a stopword, errors on connection errors 58 func (c *Client) IsStopWord(ctx context.Context, word string) (bool, error) { 59 res, err := c.grpcClient.IsWordStopword(ctx, &pb.Word{Word: word}) 60 if err != nil { 61 logConnectionRefused(c.logger, err) 62 return false, err 63 } 64 65 return res.Stopword, nil 66 } 67 68 // IsWordPresent returns true if the given word is a stopword, errors on connection errors 69 func (c *Client) IsWordPresent(ctx context.Context, word string) (bool, error) { 70 res, err := c.grpcClient.IsWordPresent(ctx, &pb.Word{Word: word}) 71 if err != nil { 72 logConnectionRefused(c.logger, err) 73 return false, err 74 } 75 76 return res.Present, nil 77 } 78 79 // SafeGetSimilarWordsWithCertainty will always return a list words - unless there is a network error 80 func (c *Client) SafeGetSimilarWordsWithCertainty(ctx context.Context, word string, certainty float32) ([]string, error) { 81 res, err := c.grpcClient.SafeGetSimilarWordsWithCertainty(ctx, &pb.SimilarWordsParams{Word: word, Certainty: certainty}) 82 if err != nil { 83 logConnectionRefused(c.logger, err) 84 return nil, err 85 } 86 87 output := make([]string, len(res.Words)) 88 for i, word := range res.Words { 89 output[i] = word.Word 90 } 91 92 return output, nil 93 } 94 95 // SchemaSearch for related classes and properties 96 // TODO: is this still used? 97 func (c *Client) SchemaSearch(ctx context.Context, params traverser.SearchParams) (traverser.SearchResults, error) { 98 pbParams := &pb.SchemaSearchParams{ 99 Certainty: params.Certainty, 100 Name: params.Name, 101 SearchType: searchTypeToProto(params.SearchType), 102 } 103 104 res, err := c.grpcClient.SchemaSearch(ctx, pbParams) 105 if err != nil { 106 logConnectionRefused(c.logger, err) 107 return traverser.SearchResults{}, err 108 } 109 110 return schemaSearchResultsFromProto(res), nil 111 } 112 113 func searchTypeToProto(input traverser.SearchType) pb.SearchType { 114 switch input { 115 case traverser.SearchTypeClass: 116 return pb.SearchType_CLASS 117 case traverser.SearchTypeProperty: 118 return pb.SearchType_PROPERTY 119 default: 120 panic(fmt.Sprintf("unknown search type %v", input)) 121 } 122 } 123 124 func searchTypeFromProto(input pb.SearchType) traverser.SearchType { 125 switch input { 126 case pb.SearchType_CLASS: 127 return traverser.SearchTypeClass 128 case pb.SearchType_PROPERTY: 129 return traverser.SearchTypeProperty 130 default: 131 panic(fmt.Sprintf("unknown search type %v", input)) 132 } 133 } 134 135 func schemaSearchResultsFromProto(res *pb.SchemaSearchResults) traverser.SearchResults { 136 return traverser.SearchResults{ 137 Type: searchTypeFromProto(res.Type), 138 Results: searchResultsFromProto(res.Results), 139 } 140 } 141 142 func searchResultsFromProto(input []*pb.SchemaSearchResult) []traverser.SearchResult { 143 output := make([]traverser.SearchResult, len(input)) 144 for i, res := range input { 145 output[i] = traverser.SearchResult{ 146 Certainty: res.Certainty, 147 Name: res.Name, 148 } 149 } 150 151 return output 152 } 153 154 func (c *Client) VectorForWord(ctx context.Context, word string) ([]float32, error) { 155 res, err := c.grpcClient.VectorForWord(ctx, &pb.Word{Word: word}) 156 if err != nil { 157 logConnectionRefused(c.logger, err) 158 return nil, fmt.Errorf("could not get vector from remote: %v", err) 159 } 160 v, _, _ := vectorFromProto(res) 161 return v, nil 162 } 163 164 func logConnectionRefused(logger logrus.FieldLogger, err error) { 165 if strings.Contains(fmt.Sprintf("%v", err), "connect: connection refused") { 166 logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable) 167 } else if strings.Contains(err.Error(), "connectex: No connection could be made because the target machine actively refused it.") { 168 logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable) 169 } 170 } 171 172 func (c *Client) MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error) { 173 out := make([][]float32, len(words)) 174 wordParams := make([]*pb.Word, len(words)) 175 176 for i, word := range words { 177 wordParams[i] = &pb.Word{Word: word} 178 } 179 180 res, err := c.grpcClient.MultiVectorForWord(ctx, &pb.WordList{Words: wordParams}) 181 if err != nil { 182 logConnectionRefused(c.logger, err) 183 return nil, err 184 } 185 186 for i, elem := range res.Vectors { 187 if len(elem.Entries) == 0 { 188 // indicates word not found 189 continue 190 } 191 192 out[i], _, _ = vectorFromProto(elem) 193 } 194 195 return out, nil 196 } 197 198 func (c *Client) MultiNearestWordsByVector(ctx context.Context, vectors [][]float32, k, n int) ([]*txt2vecmodels.NearestNeighbors, error) { 199 out := make([]*txt2vecmodels.NearestNeighbors, len(vectors)) 200 searchParams := make([]*pb.VectorNNParams, len(vectors)) 201 202 for i, vector := range vectors { 203 searchParams[i] = &pb.VectorNNParams{ 204 Vector: vectorToProto(vector), 205 K: int32(k), 206 N: int32(n), 207 } 208 } 209 210 res, err := c.grpcClient.MultiNearestWordsByVector(ctx, &pb.VectorNNParamsList{Params: searchParams}) 211 if err != nil { 212 logConnectionRefused(c.logger, err) 213 return nil, err 214 } 215 216 for i, elem := range res.Words { 217 out[i] = &txt2vecmodels.NearestNeighbors{ 218 Neighbors: c.extractNeighbors(elem), 219 } 220 } 221 222 return out, nil 223 } 224 225 func (c *Client) extractNeighbors(elem *pb.NearestWords) []*txt2vecmodels.NearestNeighbor { 226 out := make([]*txt2vecmodels.NearestNeighbor, len(elem.Words)) 227 228 for i := range out { 229 vec, _, _ := vectorFromProto(elem.Vectors.Vectors[i]) 230 out[i] = &txt2vecmodels.NearestNeighbor{ 231 Concept: elem.Words[i], 232 Distance: elem.Distances[i], 233 Vector: vec, 234 } 235 } 236 return out 237 } 238 239 func vectorFromProto(in *pb.Vector) ([]float32, []txt2vecmodels.InterpretationSource, error) { 240 output := make([]float32, len(in.Entries)) 241 for i, entry := range in.Entries { 242 output[i] = entry.Entry 243 } 244 245 source := make([]txt2vecmodels.InterpretationSource, len(in.Source)) 246 for i, s := range in.Source { 247 source[i].Concept = s.Concept 248 source[i].Weight = float64(s.Weight) 249 source[i].Occurrence = s.Occurrence 250 } 251 252 return output, source, nil 253 } 254 255 func (c *Client) VectorForCorpi(ctx context.Context, corpi []string, overridesMap map[string]string) ([]float32, []txt2vecmodels.InterpretationSource, error) { 256 overrides := overridesFromMap(overridesMap) 257 res, err := c.grpcClient.VectorForCorpi(ctx, &pb.Corpi{Corpi: corpi, Overrides: overrides}) 258 if err != nil { 259 if strings.Contains(err.Error(), "connect: connection refused") { 260 c.logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable) 261 } else if strings.Contains(err.Error(), "connectex: No connection could be made because the target machine actively refused it.") { 262 c.logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable) 263 } 264 st, ok := status.FromError(err) 265 if !ok || st.Code() != codes.InvalidArgument { 266 return nil, nil, fmt.Errorf("could not get vector from remote: %v", err) 267 } 268 269 return nil, nil, vectorizer.NewErrNoUsableWordsf(st.Message()) 270 } 271 272 return vectorFromProto(res) 273 } 274 275 func (c *Client) VectorOnlyForCorpi(ctx context.Context, corpi []string, overrides map[string]string) ([]float32, error) { 276 vec, _, err := c.VectorForCorpi(ctx, corpi, overrides) 277 return vec, err 278 } 279 280 func (c *Client) NearestWordsByVector(ctx context.Context, vector []float32, n int, k int) ([]string, []float32, error) { 281 res, err := c.grpcClient.NearestWordsByVector(ctx, &pb.VectorNNParams{ 282 K: int32(k), 283 N: int32(n), 284 Vector: vectorToProto(vector), 285 }) 286 if err != nil { 287 logConnectionRefused(c.logger, err) 288 return nil, nil, fmt.Errorf("could not get nearest words by vector: %v", err) 289 } 290 291 return res.Words, res.Distances, nil 292 } 293 294 func (c *Client) AddExtension(ctx context.Context, extension *models.C11yExtension) error { 295 _, err := c.grpcClient.AddExtension(ctx, &pb.ExtensionInput{ 296 Concept: extension.Concept, 297 Definition: strings.ToLower(extension.Definition), 298 Weight: extension.Weight, 299 }) 300 301 return err 302 } 303 304 func vectorToProto(in []float32) *pb.Vector { 305 output := make([]*pb.VectorEntry, len(in)) 306 for i, entry := range in { 307 output[i] = &pb.VectorEntry{ 308 Entry: entry, 309 } 310 } 311 312 return &pb.Vector{Entries: output} 313 } 314 315 func (c *Client) WaitForStartupAndValidateVersion(startupCtx context.Context, 316 requiredMinimumVersion string, interval time.Duration, 317 ) error { 318 for { 319 if err := startupCtx.Err(); err != nil { 320 return errors.Wrap(err, "wait for contextionary remote inference service") 321 } 322 323 time.Sleep(interval) 324 325 ctx, cancel := context.WithTimeout(startupCtx, 2*time.Second) 326 defer cancel() 327 v, err := c.version(ctx) 328 if err != nil { 329 c.logger.WithField("action", "startup_check_contextionary").WithError(err). 330 Warnf("could not connect to contextionary at startup, trying again in 1 sec") 331 continue 332 } 333 334 ok, err := extractVersionAndCompare(v, requiredMinimumVersion) 335 if err != nil { 336 c.logger.WithField("action", "startup_check_contextionary"). 337 WithField("requiredMinimumContextionaryVersion", requiredMinimumVersion). 338 WithField("contextionaryVersion", v). 339 WithError(err). 340 Warnf("cannot determine if contextionary version is compatible. " + 341 "This is fine in development, but probelematic if you see this production") 342 return nil 343 } 344 345 if ok { 346 c.logger.WithField("action", "startup_check_contextionary"). 347 WithField("requiredMinimumContextionaryVersion", requiredMinimumVersion). 348 WithField("contextionaryVersion", v). 349 Infof("found a valid contextionary version") 350 return nil 351 } else { 352 return errors.Errorf("insuffcient contextionary version: need at least %s, got %s", 353 requiredMinimumVersion, v) 354 } 355 } 356 } 357 358 func overridesFromMap(in map[string]string) []*pb.Override { 359 if in == nil { 360 return nil 361 } 362 363 out := make([]*pb.Override, len(in)) 364 i := 0 365 for key, value := range in { 366 out[i] = &pb.Override{ 367 Word: key, 368 Expression: value, 369 } 370 i++ 371 } 372 373 return out 374 }