github.com/weaviate/weaviate@v1.24.6/modules/reranker-cohere/module.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modrerankercohere
    13  
    14  import (
    15  	"context"
    16  	"net/http"
    17  	"os"
    18  	"time"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    23  	"github.com/weaviate/weaviate/entities/moduletools"
    24  	"github.com/weaviate/weaviate/modules/reranker-cohere/clients"
    25  	rerankeradditional "github.com/weaviate/weaviate/usecases/modulecomponents/additional"
    26  	"github.com/weaviate/weaviate/usecases/modulecomponents/ent"
    27  )
    28  
// Name is the module name; it is the value returned by
// (*ReRankerCohereModule).Name.
const Name = "reranker-cohere"
    30  
    31  func New() *ReRankerCohereModule {
    32  	return &ReRankerCohereModule{}
    33  }
    34  
// ReRankerCohereModule is the reranker-cohere module. Both fields are nil on
// a value returned by New and are populated by Init (via initAdditional).
type ReRankerCohereModule struct {
	// reranker is the Cohere client used for ranking and for MetaInfo.
	reranker                     ReRankerCohereClient
	// additionalPropertiesProvider is built from reranker in initAdditional and
	// backs AdditionalProperties.
	additionalPropertiesProvider modulecapabilities.AdditionalProperties
}
    39  
    40  type ReRankerCohereClient interface {
    41  	Rank(ctx context.Context, query string, documents []string, cfg moduletools.ClassConfig) (*ent.RankResult, error)
    42  	MetaInfo() (map[string]interface{}, error)
    43  }
    44  
    45  func (m *ReRankerCohereModule) Name() string {
    46  	return Name
    47  }
    48  
    49  func (m *ReRankerCohereModule) Type() modulecapabilities.ModuleType {
    50  	return modulecapabilities.Text2TextReranker
    51  }
    52  
    53  func (m *ReRankerCohereModule) Init(ctx context.Context,
    54  	params moduletools.ModuleInitParams,
    55  ) error {
    56  	if err := m.initAdditional(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil {
    57  		return errors.Wrap(err, "init cross encoder")
    58  	}
    59  
    60  	return nil
    61  }
    62  
    63  func (m *ReRankerCohereModule) initAdditional(ctx context.Context, timeout time.Duration,
    64  	logger logrus.FieldLogger,
    65  ) error {
    66  	apiKey := os.Getenv("COHERE_APIKEY")
    67  	client := clients.New(apiKey, timeout, logger)
    68  	m.reranker = client
    69  	m.additionalPropertiesProvider = rerankeradditional.NewRankerProvider(m.reranker)
    70  	return nil
    71  }
    72  
    73  func (m *ReRankerCohereModule) MetaInfo() (map[string]interface{}, error) {
    74  	return m.reranker.MetaInfo()
    75  }
    76  
    77  func (m *ReRankerCohereModule) RootHandler() http.Handler {
    78  	// TODO: remove once this is a capability interface
    79  	return nil
    80  }
    81  
    82  func (m *ReRankerCohereModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
    83  	return m.additionalPropertiesProvider.AdditionalProperties()
    84  }
    85  
    86  // verify we implement the modules.Module interface
    87  var (
    88  	_ = modulecapabilities.Module(New())
    89  	_ = modulecapabilities.AdditionalProperties(New())
    90  	_ = modulecapabilities.MetaProvider(New())
    91  )