github.com/weaviate/weaviate@v1.24.6/modules/text2vec-transformers/clients/transformers.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "time" 22 23 "github.com/pkg/errors" 24 "github.com/sirupsen/logrus" 25 "github.com/weaviate/weaviate/modules/text2vec-transformers/ent" 26 ) 27 28 type vectorizer struct { 29 originPassage string 30 originQuery string 31 httpClient *http.Client 32 logger logrus.FieldLogger 33 } 34 35 func New(originPassage, originQuery string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer { 36 return &vectorizer{ 37 originPassage: originPassage, 38 originQuery: originQuery, 39 httpClient: &http.Client{ 40 Timeout: timeout, 41 }, 42 logger: logger, 43 } 44 } 45 46 func (v *vectorizer) VectorizeObject(ctx context.Context, input string, 47 config ent.VectorizationConfig, 48 ) (*ent.VectorizationResult, error) { 49 return v.vectorize(ctx, input, config, v.urlPassage) 50 } 51 52 func (v *vectorizer) VectorizeQuery(ctx context.Context, input string, 53 config ent.VectorizationConfig, 54 ) (*ent.VectorizationResult, error) { 55 return v.vectorize(ctx, input, config, v.urlQuery) 56 } 57 58 func (v *vectorizer) vectorize(ctx context.Context, input string, 59 config ent.VectorizationConfig, url func(string, ent.VectorizationConfig) string, 60 ) (*ent.VectorizationResult, error) { 61 body, err := json.Marshal(vecRequest{ 62 Text: input, 63 Config: vecRequestConfig{ 64 PoolingStrategy: config.PoolingStrategy, 65 }, 66 }) 67 if err != nil { 68 return nil, errors.Wrapf(err, "marshal body") 69 } 70 71 req, err := http.NewRequestWithContext(ctx, "POST", url("/vectors", config), 72 bytes.NewReader(body)) 73 if err != nil { 74 return nil, errors.Wrap(err, "create POST request") 75 } 76 77 res, err := v.httpClient.Do(req) 78 if err != nil { 79 return nil, errors.Wrap(err, "send POST request") 80 } 81 defer res.Body.Close() 82 83 bodyBytes, err := io.ReadAll(res.Body) 84 if err != nil { 85 return nil, errors.Wrap(err, "read response body") 86 } 87 88 var resBody vecRequest 89 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 90 return nil, errors.Wrap(err, "unmarshal response body") 91 } 92 93 if res.StatusCode > 399 { 94 return nil, errors.Errorf("fail with status %d: %s", res.StatusCode, 95 resBody.Error) 96 } 97 98 return &ent.VectorizationResult{ 99 Text: resBody.Text, 100 Dimensions: resBody.Dims, 101 Vector: resBody.Vector, 102 }, nil 103 } 104 105 func (v *vectorizer) urlPassage(path string, config ent.VectorizationConfig) string { 106 baseURL := v.originPassage 107 if config.PassageInferenceURL != "" { 108 baseURL = config.PassageInferenceURL 109 } 110 if config.InferenceURL != "" { 111 baseURL = config.InferenceURL 112 } 113 return fmt.Sprintf("%s%s", baseURL, path) 114 } 115 116 func (v *vectorizer) urlQuery(path string, config ent.VectorizationConfig) string { 117 baseURL := v.originQuery 118 if config.QueryInferenceURL != "" { 119 baseURL = config.QueryInferenceURL 120 } 121 if config.InferenceURL != "" { 122 baseURL = config.InferenceURL 123 } 124 return fmt.Sprintf("%s%s", baseURL, path) 125 } 126 127 type vecRequest struct { 128 Text string `json:"text"` 129 Dims int `json:"dims"` 130 Vector []float32 `json:"vector"` 131 Error string `json:"error"` 132 Config vecRequestConfig `json:"config"` 133 } 134 135 type vecRequestConfig struct { 136 PoolingStrategy string `json:"pooling_strategy"` 137 }