github.com/weaviate/weaviate@v1.24.6/modules/generative-cohere/clients/cohere.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"net/url"
    22  	"regexp"
    23  	"strings"
    24  	"time"
    25  
    26  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    27  
    28  	"github.com/pkg/errors"
    29  	"github.com/sirupsen/logrus"
    30  	"github.com/weaviate/weaviate/entities/moduletools"
    31  	"github.com/weaviate/weaviate/modules/generative-cohere/config"
    32  	generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
    33  )
    34  
    35  var compile, _ = regexp.Compile(`{([\w\s]*?)}`)
    36  
// cohere is the HTTP client behind this package's Cohere generative calls.
type cohere struct {
	// apiKey is the fallback API key, used only when the request context does
	// not supply an X-Cohere-Api-Key value.
	apiKey     string
	// httpClient sends the generate requests; New sets its Timeout.
	httpClient *http.Client
	// logger is stored by New.
	// NOTE(review): no method in this file references it — confirm it is used elsewhere.
	logger     logrus.FieldLogger
}
    42  
    43  func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *cohere {
    44  	return &cohere{
    45  		apiKey: apiKey,
    46  		httpClient: &http.Client{
    47  			Timeout: timeout,
    48  		},
    49  		logger: logger,
    50  	}
    51  }
    52  
    53  func (v *cohere) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
    54  	forPrompt, err := v.generateForPrompt(textProperties, prompt)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	return v.Generate(ctx, cfg, forPrompt)
    59  }
    60  
    61  func (v *cohere) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
    62  	forTask, err := v.generatePromptForTask(textProperties, task)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	return v.Generate(ctx, cfg, forTask)
    67  }
    68  
    69  func (v *cohere) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) {
    70  	settings := config.NewClassSettings(cfg)
    71  
    72  	cohereUrl, err := v.getCohereUrl(ctx, settings.BaseURL())
    73  	if err != nil {
    74  		return nil, errors.Wrap(err, "join Cohere API host and path")
    75  	}
    76  	input := generateInput{
    77  		Prompt:            prompt,
    78  		Model:             settings.Model(),
    79  		MaxTokens:         settings.MaxTokens(),
    80  		Temperature:       settings.Temperature(),
    81  		K:                 settings.K(),
    82  		StopSequences:     settings.StopSequences(),
    83  		ReturnLikelihoods: settings.ReturnLikelihoods(),
    84  	}
    85  
    86  	body, err := json.Marshal(input)
    87  	if err != nil {
    88  		return nil, errors.Wrap(err, "marshal body")
    89  	}
    90  
    91  	req, err := http.NewRequestWithContext(ctx, "POST", cohereUrl,
    92  		bytes.NewReader(body))
    93  	if err != nil {
    94  		return nil, errors.Wrap(err, "create POST request")
    95  	}
    96  	apiKey, err := v.getApiKey(ctx)
    97  	if err != nil {
    98  		return nil, errors.Wrapf(err, "Cohere API Key")
    99  	}
   100  	req.Header.Add("Authorization", fmt.Sprintf("BEARER %s", apiKey))
   101  	req.Header.Add("Content-Type", "application/json")
   102  	req.Header.Add("Request-Source", "unspecified:weaviate")
   103  
   104  	res, err := v.httpClient.Do(req)
   105  	if err != nil {
   106  		return nil, errors.Wrap(err, "send POST request")
   107  	}
   108  	defer res.Body.Close()
   109  
   110  	bodyBytes, err := io.ReadAll(res.Body)
   111  	if err != nil {
   112  		return nil, errors.Wrap(err, "read response body")
   113  	}
   114  
   115  	var resBody generateResponse
   116  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   117  		return nil, errors.Wrap(err, "unmarshal response body")
   118  	}
   119  
   120  	if res.StatusCode != 200 || resBody.Error != nil {
   121  		if resBody.Error != nil {
   122  			return nil, errors.Errorf("connection to Cohere API failed with status: %d error: %v", res.StatusCode, resBody.Error.Message)
   123  		}
   124  		return nil, errors.Errorf("connection to Cohere API failed with status: %d", res.StatusCode)
   125  	}
   126  
   127  	textResponse := resBody.Generations[0].Text
   128  
   129  	return &generativemodels.GenerateResponse{
   130  		Result: &textResponse,
   131  	}, nil
   132  }
   133  
   134  func (v *cohere) getCohereUrl(ctx context.Context, baseURL string) (string, error) {
   135  	passedBaseURL := baseURL
   136  	if headerBaseURL := v.getValueFromContext(ctx, "X-Cohere-Baseurl"); headerBaseURL != "" {
   137  		passedBaseURL = headerBaseURL
   138  	}
   139  	return url.JoinPath(passedBaseURL, "/v1/generate")
   140  }
   141  
   142  func (v *cohere) generatePromptForTask(textProperties []map[string]string, task string) (string, error) {
   143  	marshal, err := json.Marshal(textProperties)
   144  	if err != nil {
   145  		return "", err
   146  	}
   147  	return fmt.Sprintf(`'%v:
   148  %v`, task, string(marshal)), nil
   149  }
   150  
   151  func (v *cohere) generateForPrompt(textProperties map[string]string, prompt string) (string, error) {
   152  	all := compile.FindAll([]byte(prompt), -1)
   153  	for _, match := range all {
   154  		originalProperty := string(match)
   155  		replacedProperty := compile.FindStringSubmatch(originalProperty)[1]
   156  		replacedProperty = strings.TrimSpace(replacedProperty)
   157  		value := textProperties[replacedProperty]
   158  		if value == "" {
   159  			return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty)
   160  		}
   161  		prompt = strings.ReplaceAll(prompt, originalProperty, value)
   162  	}
   163  	return prompt, nil
   164  }
   165  
   166  func (v *cohere) getValueFromContext(ctx context.Context, key string) string {
   167  	if value := ctx.Value(key); value != nil {
   168  		if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
   169  			return keyHeader[0]
   170  		}
   171  	}
   172  	// try getting header from GRPC if not successful
   173  	if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
   174  		return apiKey[0]
   175  	}
   176  	return ""
   177  }
   178  
   179  func (v *cohere) getApiKey(ctx context.Context) (string, error) {
   180  	if apiKey := v.getValueFromContext(ctx, "X-Cohere-Api-Key"); apiKey != "" {
   181  		return apiKey, nil
   182  	}
   183  	if v.apiKey != "" {
   184  		return v.apiKey, nil
   185  	}
   186  	return "", errors.New("no api key found " +
   187  		"neither in request header: X-Cohere-Api-Key " +
   188  		"nor in environment variable under COHERE_APIKEY")
   189  }
   190  
// generateInput is the JSON request body posted to the Cohere /v1/generate
// endpoint; Generate fills it from the class settings. No field uses
// omitempty, so zero values (and a nil StopSequences, encoded as null) are
// always sent explicitly.
type generateInput struct {
	Prompt            string   `json:"prompt"`
	Model             string   `json:"model"`
	MaxTokens         int      `json:"max_tokens"`
	// NOTE(review): sampling temperatures are usually fractional; an int field
	// cannot express values such as 0.7 — confirm this is intended.
	Temperature       int      `json:"temperature"`
	// K is presumably Cohere's top-k sampling parameter — confirm against the API docs.
	K                 int      `json:"k"`
	StopSequences     []string `json:"stop_sequences"`
	ReturnLikelihoods string   `json:"return_likelihoods"`
}
   200  
   201  type generateResponse struct {
   202  	Generations []generation
   203  	Error       *cohereApiError `json:"error,omitempty"`
   204  }
   205  
   206  type generation struct {
   207  	Text string `json:"text"`
   208  }
   209  
   210  // need to check this
   211  // I think you just get message
   212  type cohereApiError struct {
   213  	Message string `json:"message"`
   214  }