github.com/weaviate/weaviate@v1.24.6/modules/qna-transformers/clients/qna.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"time"
    22  
    23  	"github.com/pkg/errors"
    24  	"github.com/sirupsen/logrus"
    25  	"github.com/weaviate/weaviate/entities/additional"
    26  	"github.com/weaviate/weaviate/modules/qna-transformers/ent"
    27  )
    28  
// qna is an HTTP client for the question-answering inference service used by
// the qna-transformers module.
type qna struct {
	// origin is the base URL of the service; request paths are appended to it
	// verbatim (see url).
	origin string
	// httpClient performs the outgoing requests; New configures its timeout.
	httpClient *http.Client
	// logger is stored at construction time.
	logger logrus.FieldLogger
}
    34  
    35  func New(origin string, timeout time.Duration, logger logrus.FieldLogger) *qna {
    36  	return &qna{
    37  		origin:     origin,
    38  		httpClient: &http.Client{Timeout: timeout},
    39  		logger:     logger,
    40  	}
    41  }
    42  
    43  func (q *qna) Answer(ctx context.Context,
    44  	text, question string,
    45  ) (*ent.AnswerResult, error) {
    46  	body, err := json.Marshal(answersInput{
    47  		Text:     text,
    48  		Question: question,
    49  	})
    50  	if err != nil {
    51  		return nil, errors.Wrapf(err, "marshal body")
    52  	}
    53  
    54  	req, err := http.NewRequestWithContext(ctx, "POST", q.url("/answers/"),
    55  		bytes.NewReader(body))
    56  	if err != nil {
    57  		return nil, errors.Wrap(err, "create POST request")
    58  	}
    59  
    60  	res, err := q.httpClient.Do(req)
    61  	if err != nil {
    62  		return nil, errors.Wrap(err, "send POST request")
    63  	}
    64  	defer res.Body.Close()
    65  
    66  	bodyBytes, err := io.ReadAll(res.Body)
    67  	if err != nil {
    68  		return nil, errors.Wrap(err, "read response body")
    69  	}
    70  
    71  	var resBody answersResponse
    72  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
    73  		return nil, errors.Wrap(err, "unmarshal response body")
    74  	}
    75  
    76  	if res.StatusCode > 399 {
    77  		return nil, errors.Errorf("fail with status %d: %s", res.StatusCode,
    78  			resBody.Error)
    79  	}
    80  
    81  	return &ent.AnswerResult{
    82  		Text:      resBody.Text,
    83  		Question:  resBody.Question,
    84  		Answer:    resBody.Answer,
    85  		Certainty: resBody.Certainty,
    86  		Distance:  additional.CertaintyToDistPtr(resBody.Certainty),
    87  	}, nil
    88  }
    89  
    90  func (q *qna) url(path string) string {
    91  	return fmt.Sprintf("%s%s", q.origin, path)
    92  }
    93  
// answersInput is the JSON request payload that Answer sends to the
// /answers/ endpoint.
type answersInput struct {
	// Text is the passage passed to Answer as its text argument.
	Text     string `json:"text"`
	// Question is the question passed to Answer as its question argument.
	Question string `json:"question"`
}
    98  
// answersResponse is the JSON document that Answer decodes from the
// /answers/ reply.
type answersResponse struct {
	// NOTE(review): encoding/json treats an embedded struct that carries a
	// JSON name as a regular named field, so Text and Question are read from
	// a nested "answersInput" object rather than from top-level "text" and
	// "question" keys. Confirm this matches the shape the service returns.
	answersInput `json:"answersInput"`
	// Answer is a pointer so that an absent or null value decodes to nil.
	Answer       *string  `json:"answer"`
	// Certainty is a pointer so that an absent or null value decodes to nil.
	Certainty    *float64 `json:"certainty"`
	// Distance is decoded from the reply; Answer does not read it and derives
	// the distance from Certainty instead.
	Distance     *float64 `json:"distance"`
	// Error is the message Answer reports when the status code is above 399.
	Error        string   `json:"error"`
}