github.com/weaviate/weaviate@v1.24.6/modules/ner-transformers/clients/ner.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"time"
    22  
    23  	"github.com/pkg/errors"
    24  	"github.com/sirupsen/logrus"
    25  	"github.com/weaviate/weaviate/modules/ner-transformers/ent"
    26  )
    27  
// ner is the HTTP client this module (ner-transformers) uses to reach its
// NER inference service.
//
// Fields:
//   - origin: base URL of the service; url() appends request paths to it
//     verbatim.
//   - httpClient: performs every request; New gives it the configured
//     timeout.
//   - logger: stored by New; not used by the code in this file.
type ner struct {
	origin     string
	httpClient *http.Client
	logger     logrus.FieldLogger
}
    33  
    34  type nerInput struct {
    35  	Text string `json:"text"`
    36  }
    37  
    38  type tokenResponse struct {
    39  	// Property      string  `json:"property"`
    40  	Entity        string  `json:"entity"`
    41  	Certainty     float64 `json:"certainty"`
    42  	Distance      float64 `json:"distance"`
    43  	Word          string  `json:"word"`
    44  	StartPosition int     `json:"startPosition"`
    45  	EndPosition   int     `json:"endPosition"`
    46  }
    47  
    48  type nerResponse struct {
    49  	Error string
    50  	nerInput
    51  	Tokens []tokenResponse `json:"tokens"`
    52  }
    53  
    54  func New(origin string, timeout time.Duration, logger logrus.FieldLogger) *ner {
    55  	return &ner{
    56  		origin:     origin,
    57  		httpClient: &http.Client{Timeout: timeout},
    58  		logger:     logger,
    59  	}
    60  }
    61  
    62  func (n *ner) GetTokens(ctx context.Context, property,
    63  	text string,
    64  ) ([]ent.TokenResult, error) {
    65  	body, err := json.Marshal(nerInput{
    66  		Text: text,
    67  	})
    68  	if err != nil {
    69  		return nil, errors.Wrapf(err, "marshal body")
    70  	}
    71  
    72  	req, err := http.NewRequestWithContext(ctx, "POST", n.url("/ner/"),
    73  		bytes.NewReader(body))
    74  	if err != nil {
    75  		return nil, errors.Wrap(err, "create POST request")
    76  	}
    77  
    78  	res, err := n.httpClient.Do(req)
    79  	if err != nil {
    80  		return nil, errors.Wrap(err, "send POST request")
    81  	}
    82  	defer res.Body.Close()
    83  
    84  	bodyBytes, err := io.ReadAll(res.Body)
    85  	if err != nil {
    86  		return nil, errors.Wrap(err, "read response body")
    87  	}
    88  
    89  	var resBody nerResponse
    90  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
    91  		return nil, errors.Wrap(err, "unmarshal response body")
    92  	}
    93  
    94  	if res.StatusCode > 399 {
    95  		return nil, errors.Errorf("fail with status %d: %s", res.StatusCode, resBody.Error)
    96  	}
    97  
    98  	out := make([]ent.TokenResult, len(resBody.Tokens))
    99  
   100  	for i, elem := range resBody.Tokens {
   101  		out[i].Certainty = elem.Certainty
   102  		out[i].Distance = elem.Distance
   103  		out[i].Entity = elem.Entity
   104  		out[i].Word = elem.Word
   105  		out[i].StartPosition = elem.StartPosition
   106  		out[i].EndPosition = elem.EndPosition
   107  		out[i].Property = property
   108  	}
   109  
   110  	// format resBody to nerResult
   111  	return out, nil
   112  }
   113  
   114  func (n *ner) url(path string) string {
   115  	return fmt.Sprintf("%s%s", n.origin, path)
   116  }