// Source: github.com/weaviate/weaviate@v1.24.6/modules/sum-transformers/module.go

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modsum
    13  
    14  import (
    15  	"context"
    16  	"net/http"
    17  	"os"
    18  	"time"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    23  	"github.com/weaviate/weaviate/entities/moduletools"
    24  	sumadditional "github.com/weaviate/weaviate/modules/sum-transformers/additional"
    25  	sumadditionalsummary "github.com/weaviate/weaviate/modules/sum-transformers/additional/summary"
    26  	"github.com/weaviate/weaviate/modules/sum-transformers/client"
    27  	"github.com/weaviate/weaviate/modules/sum-transformers/ent"
    28  )
    29  
    30  func New() *SUMModule {
    31  	return &SUMModule{}
    32  }
    33  
    34  type SUMModule struct {
    35  	sum                          sumClient
    36  	additionalPropertiesProvider modulecapabilities.AdditionalProperties
    37  }
    38  
    39  type sumClient interface {
    40  	GetSummary(ctx context.Context, property, text string) ([]ent.SummaryResult, error)
    41  	MetaInfo() (map[string]interface{}, error)
    42  }
    43  
    44  func (m *SUMModule) Name() string {
    45  	return "sum-transformers"
    46  }
    47  
    48  func (m *SUMModule) Type() modulecapabilities.ModuleType {
    49  	return modulecapabilities.Text2TextSummarize
    50  }
    51  
    52  func (m *SUMModule) Init(ctx context.Context,
    53  	params moduletools.ModuleInitParams,
    54  ) error {
    55  	if err := m.initAdditional(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil {
    56  		return errors.Wrap(err, "init additional")
    57  	}
    58  	return nil
    59  }
    60  
    61  func (m *SUMModule) initAdditional(ctx context.Context, timeout time.Duration,
    62  	logger logrus.FieldLogger,
    63  ) error {
    64  	uri := os.Getenv("SUM_INFERENCE_API")
    65  	if uri == "" {
    66  		return errors.Errorf("required variable SUM_INFERENCE_API is not set")
    67  	}
    68  
    69  	client := client.New(uri, timeout, logger)
    70  	if err := client.WaitForStartup(ctx, 1*time.Second); err != nil {
    71  		return errors.Wrap(err, "init remote sum module")
    72  	}
    73  
    74  	m.sum = client
    75  
    76  	tokenProvider := sumadditionalsummary.New(m.sum)
    77  	m.additionalPropertiesProvider = sumadditional.New(tokenProvider)
    78  
    79  	return nil
    80  }
    81  
    82  func (m *SUMModule) RootHandler() http.Handler {
    83  	// TODO: remove once this is a capability interface
    84  	return nil
    85  }
    86  
    87  func (m *SUMModule) MetaInfo() (map[string]interface{}, error) {
    88  	return m.sum.MetaInfo()
    89  }
    90  
    91  func (m *SUMModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
    92  	return m.additionalPropertiesProvider.AdditionalProperties()
    93  }
    94  
    95  // verify we implement the modules.Module interface
    96  var (
    97  	_ = modulecapabilities.Module(New())
    98  	_ = modulecapabilities.AdditionalProperties(New())
    99  	_ = modulecapabilities.MetaProvider(New())
   100  )