github.com/weaviate/weaviate@v1.24.6/modules/ner-transformers/clients/ner.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "time" 22 23 "github.com/pkg/errors" 24 "github.com/sirupsen/logrus" 25 "github.com/weaviate/weaviate/modules/ner-transformers/ent" 26 ) 27 28 type ner struct { 29 origin string 30 httpClient *http.Client 31 logger logrus.FieldLogger 32 } 33 34 type nerInput struct { 35 Text string `json:"text"` 36 } 37 38 type tokenResponse struct { 39 // Property string `json:"property"` 40 Entity string `json:"entity"` 41 Certainty float64 `json:"certainty"` 42 Distance float64 `json:"distance"` 43 Word string `json:"word"` 44 StartPosition int `json:"startPosition"` 45 EndPosition int `json:"endPosition"` 46 } 47 48 type nerResponse struct { 49 Error string 50 nerInput 51 Tokens []tokenResponse `json:"tokens"` 52 } 53 54 func New(origin string, timeout time.Duration, logger logrus.FieldLogger) *ner { 55 return &ner{ 56 origin: origin, 57 httpClient: &http.Client{Timeout: timeout}, 58 logger: logger, 59 } 60 } 61 62 func (n *ner) GetTokens(ctx context.Context, property, 63 text string, 64 ) ([]ent.TokenResult, error) { 65 body, err := json.Marshal(nerInput{ 66 Text: text, 67 }) 68 if err != nil { 69 return nil, errors.Wrapf(err, "marshal body") 70 } 71 72 req, err := http.NewRequestWithContext(ctx, "POST", n.url("/ner/"), 73 bytes.NewReader(body)) 74 if err != nil { 75 return nil, errors.Wrap(err, "create POST request") 76 } 77 78 res, err := n.httpClient.Do(req) 79 if err != nil { 80 return nil, errors.Wrap(err, "send POST request") 81 } 82 defer res.Body.Close() 83 84 bodyBytes, err := io.ReadAll(res.Body) 85 if err != nil { 86 return nil, errors.Wrap(err, "read response body") 87 } 88 89 var resBody nerResponse 90 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 91 return nil, errors.Wrap(err, "unmarshal response body") 92 } 93 94 if res.StatusCode > 399 { 95 return nil, errors.Errorf("fail with status %d: %s", res.StatusCode, resBody.Error) 96 } 97 98 out := make([]ent.TokenResult, len(resBody.Tokens)) 99 100 for i, elem := range resBody.Tokens { 101 out[i].Certainty = elem.Certainty 102 out[i].Distance = elem.Distance 103 out[i].Entity = elem.Entity 104 out[i].Word = elem.Word 105 out[i].StartPosition = elem.StartPosition 106 out[i].EndPosition = elem.EndPosition 107 out[i].Property = property 108 } 109 110 // format resBody to nerResult 111 return out, nil 112 } 113 114 func (n *ner) url(path string) string { 115 return fmt.Sprintf("%s%s", n.origin, path) 116 }