github.com/weaviate/weaviate@v1.24.6/modules/reranker-transformers/module.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modrerankertransformers
    13  
    14  import (
    15  	"context"
    16  	"net/http"
    17  	"os"
    18  	"time"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    23  	"github.com/weaviate/weaviate/entities/moduletools"
    24  	client "github.com/weaviate/weaviate/modules/reranker-transformers/clients"
    25  	additionalprovider "github.com/weaviate/weaviate/usecases/modulecomponents/additional"
    26  	"github.com/weaviate/weaviate/usecases/modulecomponents/ent"
    27  )
    28  
    29  const Name = "reranker-transformers"
    30  
    31  func New() *ReRankerModule {
    32  	return &ReRankerModule{}
    33  }
    34  
    35  type ReRankerModule struct {
    36  	reranker                     ReRankerClient
    37  	additionalPropertiesProvider modulecapabilities.AdditionalProperties
    38  }
    39  
    40  type ReRankerClient interface {
    41  	Rank(ctx context.Context, query string, documents []string, cfg moduletools.ClassConfig) (*ent.RankResult, error)
    42  	MetaInfo() (map[string]interface{}, error)
    43  }
    44  
    45  func (m *ReRankerModule) Name() string {
    46  	return Name
    47  }
    48  
    49  func (m *ReRankerModule) Type() modulecapabilities.ModuleType {
    50  	return modulecapabilities.Text2TextReranker
    51  }
    52  
    53  func (m *ReRankerModule) Init(ctx context.Context,
    54  	params moduletools.ModuleInitParams,
    55  ) error {
    56  	if err := m.initAdditional(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil {
    57  		return errors.Wrap(err, "init re encoder")
    58  	}
    59  
    60  	return nil
    61  }
    62  
    63  func (m *ReRankerModule) initAdditional(ctx context.Context, timeout time.Duration,
    64  	logger logrus.FieldLogger,
    65  ) error {
    66  	uri := os.Getenv("RERANKER_INFERENCE_API")
    67  	if uri == "" {
    68  		return errors.Errorf("required variable RERANKER_INFERENCE_API is not set")
    69  	}
    70  
    71  	client := client.New(uri, timeout, logger)
    72  
    73  	m.reranker = client
    74  	if err := client.WaitForStartup(ctx, 1*time.Second); err != nil {
    75  		return errors.Wrap(err, "init remote sum module")
    76  	}
    77  
    78  	m.additionalPropertiesProvider = additionalprovider.NewRankerProvider(client)
    79  	return nil
    80  }
    81  
    82  func (m *ReRankerModule) MetaInfo() (map[string]interface{}, error) {
    83  	return m.reranker.MetaInfo()
    84  }
    85  
    86  func (m *ReRankerModule) RootHandler() http.Handler {
    87  	// TODO: remove once this is a capability interface
    88  	return nil
    89  }
    90  
    91  func (m *ReRankerModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
    92  	return m.additionalPropertiesProvider.AdditionalProperties()
    93  }
    94  
    95  // verify we implement the modules.Module interface
    96  var (
    97  	_ = modulecapabilities.Module(New())
    98  	_ = modulecapabilities.AdditionalProperties(New())
    99  	_ = modulecapabilities.MetaProvider(New())
   100  )