github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/client/contextionary.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package client
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"strings"
    18  	"time"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	pb "github.com/weaviate/contextionary/contextionary"
    23  	"github.com/weaviate/weaviate/entities/models"
    24  	txt2vecmodels "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/models"
    25  	"github.com/weaviate/weaviate/modules/text2vec-contextionary/vectorizer"
    26  	"github.com/weaviate/weaviate/usecases/traverser"
    27  	"google.golang.org/grpc"
    28  	"google.golang.org/grpc/codes"
    29  	"google.golang.org/grpc/credentials/insecure"
    30  	"google.golang.org/grpc/status"
    31  )
    32  
// ModelUncontactable is the warning message this package logs when a call to
// the remote contextionary fails with a "connection refused" error.
const ModelUncontactable = "module uncontactable"
    34  
// Client establishes a gRPC connection to a remote contextionary service
//
// All exported methods forward their request to the generated gRPC stub and
// translate the protobuf response into plain Go or project model types.
type Client struct {
	grpcClient pb.ContextionaryClient // generated stub used for every remote call
	logger     logrus.FieldLogger     // receives warnings when the remote refuses the connection
}
    40  
    41  // NewClient from gRPC discovery url to connect to a remote contextionary service
    42  func NewClient(uri string, logger logrus.FieldLogger) (*Client, error) {
    43  	conn, err := grpc.Dial(uri,
    44  		grpc.WithTransportCredentials(insecure.NewCredentials()),
    45  		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(1024*1024*48)))
    46  	if err != nil {
    47  		return nil, fmt.Errorf("couldn't connect to remote contextionary gRPC server: %s", err)
    48  	}
    49  
    50  	client := pb.NewContextionaryClient(conn)
    51  	return &Client{
    52  		grpcClient: client,
    53  		logger:     logger,
    54  	}, nil
    55  }
    56  
    57  // IsStopWord returns true if the given word is a stopword, errors on connection errors
    58  func (c *Client) IsStopWord(ctx context.Context, word string) (bool, error) {
    59  	res, err := c.grpcClient.IsWordStopword(ctx, &pb.Word{Word: word})
    60  	if err != nil {
    61  		logConnectionRefused(c.logger, err)
    62  		return false, err
    63  	}
    64  
    65  	return res.Stopword, nil
    66  }
    67  
    68  // IsWordPresent returns true if the given word is a stopword, errors on connection errors
    69  func (c *Client) IsWordPresent(ctx context.Context, word string) (bool, error) {
    70  	res, err := c.grpcClient.IsWordPresent(ctx, &pb.Word{Word: word})
    71  	if err != nil {
    72  		logConnectionRefused(c.logger, err)
    73  		return false, err
    74  	}
    75  
    76  	return res.Present, nil
    77  }
    78  
    79  // SafeGetSimilarWordsWithCertainty will always return a list words - unless there is a network error
    80  func (c *Client) SafeGetSimilarWordsWithCertainty(ctx context.Context, word string, certainty float32) ([]string, error) {
    81  	res, err := c.grpcClient.SafeGetSimilarWordsWithCertainty(ctx, &pb.SimilarWordsParams{Word: word, Certainty: certainty})
    82  	if err != nil {
    83  		logConnectionRefused(c.logger, err)
    84  		return nil, err
    85  	}
    86  
    87  	output := make([]string, len(res.Words))
    88  	for i, word := range res.Words {
    89  		output[i] = word.Word
    90  	}
    91  
    92  	return output, nil
    93  }
    94  
    95  // SchemaSearch for related classes and properties
    96  // TODO: is this still used?
    97  func (c *Client) SchemaSearch(ctx context.Context, params traverser.SearchParams) (traverser.SearchResults, error) {
    98  	pbParams := &pb.SchemaSearchParams{
    99  		Certainty:  params.Certainty,
   100  		Name:       params.Name,
   101  		SearchType: searchTypeToProto(params.SearchType),
   102  	}
   103  
   104  	res, err := c.grpcClient.SchemaSearch(ctx, pbParams)
   105  	if err != nil {
   106  		logConnectionRefused(c.logger, err)
   107  		return traverser.SearchResults{}, err
   108  	}
   109  
   110  	return schemaSearchResultsFromProto(res), nil
   111  }
   112  
   113  func searchTypeToProto(input traverser.SearchType) pb.SearchType {
   114  	switch input {
   115  	case traverser.SearchTypeClass:
   116  		return pb.SearchType_CLASS
   117  	case traverser.SearchTypeProperty:
   118  		return pb.SearchType_PROPERTY
   119  	default:
   120  		panic(fmt.Sprintf("unknown search type %v", input))
   121  	}
   122  }
   123  
   124  func searchTypeFromProto(input pb.SearchType) traverser.SearchType {
   125  	switch input {
   126  	case pb.SearchType_CLASS:
   127  		return traverser.SearchTypeClass
   128  	case pb.SearchType_PROPERTY:
   129  		return traverser.SearchTypeProperty
   130  	default:
   131  		panic(fmt.Sprintf("unknown search type %v", input))
   132  	}
   133  }
   134  
   135  func schemaSearchResultsFromProto(res *pb.SchemaSearchResults) traverser.SearchResults {
   136  	return traverser.SearchResults{
   137  		Type:    searchTypeFromProto(res.Type),
   138  		Results: searchResultsFromProto(res.Results),
   139  	}
   140  }
   141  
   142  func searchResultsFromProto(input []*pb.SchemaSearchResult) []traverser.SearchResult {
   143  	output := make([]traverser.SearchResult, len(input))
   144  	for i, res := range input {
   145  		output[i] = traverser.SearchResult{
   146  			Certainty: res.Certainty,
   147  			Name:      res.Name,
   148  		}
   149  	}
   150  
   151  	return output
   152  }
   153  
   154  func (c *Client) VectorForWord(ctx context.Context, word string) ([]float32, error) {
   155  	res, err := c.grpcClient.VectorForWord(ctx, &pb.Word{Word: word})
   156  	if err != nil {
   157  		logConnectionRefused(c.logger, err)
   158  		return nil, fmt.Errorf("could not get vector from remote: %v", err)
   159  	}
   160  	v, _, _ := vectorFromProto(res)
   161  	return v, nil
   162  }
   163  
   164  func logConnectionRefused(logger logrus.FieldLogger, err error) {
   165  	if strings.Contains(fmt.Sprintf("%v", err), "connect: connection refused") {
   166  		logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable)
   167  	} else if strings.Contains(err.Error(), "connectex: No connection could be made because the target machine actively refused it.") {
   168  		logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable)
   169  	}
   170  }
   171  
   172  func (c *Client) MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error) {
   173  	out := make([][]float32, len(words))
   174  	wordParams := make([]*pb.Word, len(words))
   175  
   176  	for i, word := range words {
   177  		wordParams[i] = &pb.Word{Word: word}
   178  	}
   179  
   180  	res, err := c.grpcClient.MultiVectorForWord(ctx, &pb.WordList{Words: wordParams})
   181  	if err != nil {
   182  		logConnectionRefused(c.logger, err)
   183  		return nil, err
   184  	}
   185  
   186  	for i, elem := range res.Vectors {
   187  		if len(elem.Entries) == 0 {
   188  			// indicates word not found
   189  			continue
   190  		}
   191  
   192  		out[i], _, _ = vectorFromProto(elem)
   193  	}
   194  
   195  	return out, nil
   196  }
   197  
   198  func (c *Client) MultiNearestWordsByVector(ctx context.Context, vectors [][]float32, k, n int) ([]*txt2vecmodels.NearestNeighbors, error) {
   199  	out := make([]*txt2vecmodels.NearestNeighbors, len(vectors))
   200  	searchParams := make([]*pb.VectorNNParams, len(vectors))
   201  
   202  	for i, vector := range vectors {
   203  		searchParams[i] = &pb.VectorNNParams{
   204  			Vector: vectorToProto(vector),
   205  			K:      int32(k),
   206  			N:      int32(n),
   207  		}
   208  	}
   209  
   210  	res, err := c.grpcClient.MultiNearestWordsByVector(ctx, &pb.VectorNNParamsList{Params: searchParams})
   211  	if err != nil {
   212  		logConnectionRefused(c.logger, err)
   213  		return nil, err
   214  	}
   215  
   216  	for i, elem := range res.Words {
   217  		out[i] = &txt2vecmodels.NearestNeighbors{
   218  			Neighbors: c.extractNeighbors(elem),
   219  		}
   220  	}
   221  
   222  	return out, nil
   223  }
   224  
   225  func (c *Client) extractNeighbors(elem *pb.NearestWords) []*txt2vecmodels.NearestNeighbor {
   226  	out := make([]*txt2vecmodels.NearestNeighbor, len(elem.Words))
   227  
   228  	for i := range out {
   229  		vec, _, _ := vectorFromProto(elem.Vectors.Vectors[i])
   230  		out[i] = &txt2vecmodels.NearestNeighbor{
   231  			Concept:  elem.Words[i],
   232  			Distance: elem.Distances[i],
   233  			Vector:   vec,
   234  		}
   235  	}
   236  	return out
   237  }
   238  
   239  func vectorFromProto(in *pb.Vector) ([]float32, []txt2vecmodels.InterpretationSource, error) {
   240  	output := make([]float32, len(in.Entries))
   241  	for i, entry := range in.Entries {
   242  		output[i] = entry.Entry
   243  	}
   244  
   245  	source := make([]txt2vecmodels.InterpretationSource, len(in.Source))
   246  	for i, s := range in.Source {
   247  		source[i].Concept = s.Concept
   248  		source[i].Weight = float64(s.Weight)
   249  		source[i].Occurrence = s.Occurrence
   250  	}
   251  
   252  	return output, source, nil
   253  }
   254  
   255  func (c *Client) VectorForCorpi(ctx context.Context, corpi []string, overridesMap map[string]string) ([]float32, []txt2vecmodels.InterpretationSource, error) {
   256  	overrides := overridesFromMap(overridesMap)
   257  	res, err := c.grpcClient.VectorForCorpi(ctx, &pb.Corpi{Corpi: corpi, Overrides: overrides})
   258  	if err != nil {
   259  		if strings.Contains(err.Error(), "connect: connection refused") {
   260  			c.logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable)
   261  		} else if strings.Contains(err.Error(), "connectex: No connection could be made because the target machine actively refused it.") {
   262  			c.logger.WithError(err).WithField("module", "contextionary").Warnf(ModelUncontactable)
   263  		}
   264  		st, ok := status.FromError(err)
   265  		if !ok || st.Code() != codes.InvalidArgument {
   266  			return nil, nil, fmt.Errorf("could not get vector from remote: %v", err)
   267  		}
   268  
   269  		return nil, nil, vectorizer.NewErrNoUsableWordsf(st.Message())
   270  	}
   271  
   272  	return vectorFromProto(res)
   273  }
   274  
   275  func (c *Client) VectorOnlyForCorpi(ctx context.Context, corpi []string, overrides map[string]string) ([]float32, error) {
   276  	vec, _, err := c.VectorForCorpi(ctx, corpi, overrides)
   277  	return vec, err
   278  }
   279  
   280  func (c *Client) NearestWordsByVector(ctx context.Context, vector []float32, n int, k int) ([]string, []float32, error) {
   281  	res, err := c.grpcClient.NearestWordsByVector(ctx, &pb.VectorNNParams{
   282  		K:      int32(k),
   283  		N:      int32(n),
   284  		Vector: vectorToProto(vector),
   285  	})
   286  	if err != nil {
   287  		logConnectionRefused(c.logger, err)
   288  		return nil, nil, fmt.Errorf("could not get nearest words by vector: %v", err)
   289  	}
   290  
   291  	return res.Words, res.Distances, nil
   292  }
   293  
   294  func (c *Client) AddExtension(ctx context.Context, extension *models.C11yExtension) error {
   295  	_, err := c.grpcClient.AddExtension(ctx, &pb.ExtensionInput{
   296  		Concept:    extension.Concept,
   297  		Definition: strings.ToLower(extension.Definition),
   298  		Weight:     extension.Weight,
   299  	})
   300  
   301  	return err
   302  }
   303  
   304  func vectorToProto(in []float32) *pb.Vector {
   305  	output := make([]*pb.VectorEntry, len(in))
   306  	for i, entry := range in {
   307  		output[i] = &pb.VectorEntry{
   308  			Entry: entry,
   309  		}
   310  	}
   311  
   312  	return &pb.Vector{Entries: output}
   313  }
   314  
   315  func (c *Client) WaitForStartupAndValidateVersion(startupCtx context.Context,
   316  	requiredMinimumVersion string, interval time.Duration,
   317  ) error {
   318  	for {
   319  		if err := startupCtx.Err(); err != nil {
   320  			return errors.Wrap(err, "wait for contextionary remote inference service")
   321  		}
   322  
   323  		time.Sleep(interval)
   324  
   325  		ctx, cancel := context.WithTimeout(startupCtx, 2*time.Second)
   326  		defer cancel()
   327  		v, err := c.version(ctx)
   328  		if err != nil {
   329  			c.logger.WithField("action", "startup_check_contextionary").WithError(err).
   330  				Warnf("could not connect to contextionary at startup, trying again in 1 sec")
   331  			continue
   332  		}
   333  
   334  		ok, err := extractVersionAndCompare(v, requiredMinimumVersion)
   335  		if err != nil {
   336  			c.logger.WithField("action", "startup_check_contextionary").
   337  				WithField("requiredMinimumContextionaryVersion", requiredMinimumVersion).
   338  				WithField("contextionaryVersion", v).
   339  				WithError(err).
   340  				Warnf("cannot determine if contextionary version is compatible. " +
   341  					"This is fine in development, but probelematic if you see this production")
   342  			return nil
   343  		}
   344  
   345  		if ok {
   346  			c.logger.WithField("action", "startup_check_contextionary").
   347  				WithField("requiredMinimumContextionaryVersion", requiredMinimumVersion).
   348  				WithField("contextionaryVersion", v).
   349  				Infof("found a valid contextionary version")
   350  			return nil
   351  		} else {
   352  			return errors.Errorf("insuffcient contextionary version: need at least %s, got %s",
   353  				requiredMinimumVersion, v)
   354  		}
   355  	}
   356  }
   357  
   358  func overridesFromMap(in map[string]string) []*pb.Override {
   359  	if in == nil {
   360  		return nil
   361  	}
   362  
   363  	out := make([]*pb.Override, len(in))
   364  	i := 0
   365  	for key, value := range in {
   366  		out[i] = &pb.Override{
   367  			Word:       key,
   368  			Expression: value,
   369  		}
   370  		i++
   371  	}
   372  
   373  	return out
   374  }