github.com/weaviate/weaviate@v1.24.6/modules/text-spellcheck/clients/spellcheck.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "time" 22 23 "github.com/pkg/errors" 24 "github.com/sirupsen/logrus" 25 "github.com/weaviate/weaviate/modules/text-spellcheck/ent" 26 ) 27 28 type spellCheckInput struct { 29 Text []string `json:"text"` 30 } 31 32 type spellCheckCorrection struct { 33 Original string `json:"original"` 34 Correction string `json:"correction"` 35 } 36 37 type spellCheckResponse struct { 38 spellCheckInput 39 Changes []spellCheckCorrection `json:"changes"` 40 } 41 42 type spellCheck struct { 43 origin string 44 httpClient *http.Client 45 logger logrus.FieldLogger 46 } 47 48 func New(origin string, timeout time.Duration, logger logrus.FieldLogger) *spellCheck { 49 return &spellCheck{ 50 origin: origin, 51 httpClient: &http.Client{ 52 Timeout: timeout, 53 }, 54 logger: logger, 55 } 56 } 57 58 func (s *spellCheck) Check(ctx context.Context, text []string) (*ent.SpellCheckResult, error) { 59 body, err := json.Marshal(spellCheckInput{ 60 Text: text, 61 }) 62 if err != nil { 63 return nil, errors.Wrapf(err, "marshal body") 64 } 65 66 req, err := http.NewRequestWithContext(ctx, "POST", s.url("/spellcheck/"), 67 bytes.NewReader(body)) 68 if err != nil { 69 return nil, errors.Wrap(err, "create POST request") 70 } 71 72 res, err := s.httpClient.Do(req) 73 if err != nil { 74 return nil, errors.Wrap(err, "send POST request") 75 } 76 defer res.Body.Close() 77 78 bodyBytes, err := io.ReadAll(res.Body) 79 if err != nil { 80 return nil, errors.Wrap(err, "read response body") 81 } 82 83 var resBody spellCheckResponse 84 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 85 return nil, errors.Wrap(err, "unmarshal response body") 86 } 87 88 if res.StatusCode > 399 { 89 return nil, errors.Errorf("fail with status %d", res.StatusCode) 90 } 91 92 return &ent.SpellCheckResult{ 93 Text: resBody.Text, 94 Changes: s.getCorrections(resBody.Changes), 95 }, nil 96 } 97 98 func (s *spellCheck) url(path string) string { 99 return fmt.Sprintf("%s%s", s.origin, path) 100 } 101 102 func (s *spellCheck) getCorrections(changes []spellCheckCorrection) []ent.SpellCheckCorrection { 103 if len(changes) == 0 { 104 return nil 105 } 106 corrections := make([]ent.SpellCheckCorrection, len(changes)) 107 for i := range changes { 108 corrections[i] = ent.SpellCheckCorrection{ 109 Original: changes[i].Original, 110 Correction: changes[i].Correction, 111 } 112 } 113 return corrections 114 }