github.com/weaviate/weaviate@v1.24.6/modules/generative-palm/config/class_settings.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package config
    13  
    14  import (
    15  	"fmt"
    16  	"strings"
    17  
    18  	"github.com/pkg/errors"
    19  	"github.com/weaviate/weaviate/entities/models"
    20  	"github.com/weaviate/weaviate/entities/moduletools"
    21  	basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings"
    22  )
    23  
    24  const (
    25  	apiEndpointProperty = "apiEndpoint"
    26  	projectIDProperty   = "projectId"
    27  	endpointIDProperty  = "endpointId"
    28  	modelIDProperty     = "modelId"
    29  	temperatureProperty = "temperature"
    30  	tokenLimitProperty  = "tokenLimit"
    31  	topPProperty        = "topP"
    32  	topKProperty        = "topK"
    33  )
    34  
    35  var (
    36  	DefaultPaLMApiEndpoint        = "us-central1-aiplatform.googleapis.com"
    37  	DefaultPaLMModel              = "chat-bison"
    38  	DefaultPaLMTemperature        = 0.2
    39  	DefaultTokenLimit             = 256
    40  	DefaultPaLMTopP               = 0.95
    41  	DefaultPaLMTopK               = 40
    42  	DefaulGenerativeAIApiEndpoint = "generativelanguage.googleapis.com"
    43  	DefaulGenerativeAIModelID     = "chat-bison-001"
    44  )
    45  
    46  var supportedGenerativeAIModels = []string{
    47  	DefaulGenerativeAIModelID,
    48  	"gemini-pro",
    49  	"gemini-pro-vision",
    50  	"gemini-ultra",
    51  }
    52  
    53  type ClassSettings interface {
    54  	Validate(class *models.Class) error
    55  	// Module settings
    56  	ApiEndpoint() string
    57  	ProjectID() string
    58  	EndpointID() string
    59  	ModelID() string
    60  
    61  	// parameters
    62  	// 0.0 - 1.0
    63  	Temperature() float64
    64  	// 1 - 1024
    65  	TokenLimit() int
    66  	// 1 - 40
    67  	TopK() int
    68  	// 0.0 - 1.0
    69  	TopP() float64
    70  }
    71  
    72  type classSettings struct {
    73  	cfg                  moduletools.ClassConfig
    74  	propertyValuesHelper basesettings.PropertyValuesHelper
    75  }
    76  
    77  func NewClassSettings(cfg moduletools.ClassConfig) ClassSettings {
    78  	return &classSettings{cfg: cfg, propertyValuesHelper: basesettings.NewPropertyValuesHelper("generative-palm")}
    79  }
    80  
    81  func (ic *classSettings) Validate(class *models.Class) error {
    82  	if ic.cfg == nil {
    83  		// we would receive a nil-config on cross-class requests, such as Explore{}
    84  		return errors.New("empty config")
    85  	}
    86  
    87  	var errorMessages []string
    88  
    89  	apiEndpoint := ic.ApiEndpoint()
    90  	projectID := ic.ProjectID()
    91  	if apiEndpoint != DefaulGenerativeAIApiEndpoint && projectID == "" {
    92  		errorMessages = append(errorMessages, fmt.Sprintf("%s cannot be empty", projectIDProperty))
    93  	}
    94  	temperature := ic.Temperature()
    95  	if temperature < 0 || temperature > 1 {
    96  		errorMessages = append(errorMessages, fmt.Sprintf("%s has to be float value between 0 and 1", temperatureProperty))
    97  	}
    98  	tokenLimit := ic.TokenLimit()
    99  	if tokenLimit < 1 || tokenLimit > 1024 {
   100  		errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 1 and 1024", tokenLimitProperty))
   101  	}
   102  	topK := ic.TopK()
   103  	if topK < 1 || topK > 40 {
   104  		errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 1 and 40", topKProperty))
   105  	}
   106  	topP := ic.TopP()
   107  	if topP < 0 || topP > 1 {
   108  		errorMessages = append(errorMessages, fmt.Sprintf("%s has to be float value between 0 and 1", topPProperty))
   109  	}
   110  	// Google MakerSuite
   111  	model := ic.ModelID()
   112  	if apiEndpoint == DefaulGenerativeAIApiEndpoint && !contains[string](supportedGenerativeAIModels, model) {
   113  		errorMessages = append(errorMessages, fmt.Sprintf("%s is not supported available models are: %+v", model, supportedGenerativeAIModels))
   114  	}
   115  
   116  	if len(errorMessages) > 0 {
   117  		return fmt.Errorf("%s", strings.Join(errorMessages, ", "))
   118  	}
   119  
   120  	return nil
   121  }
   122  
   123  func (ic *classSettings) getStringProperty(name, defaultValue string) string {
   124  	return ic.propertyValuesHelper.GetPropertyAsString(ic.cfg, name, defaultValue)
   125  }
   126  
   127  func (ic *classSettings) getFloatProperty(name string, defaultValue float64) float64 {
   128  	asFloat64 := ic.propertyValuesHelper.GetPropertyAsFloat64(ic.cfg, name, &defaultValue)
   129  	return *asFloat64
   130  }
   131  
   132  func (ic *classSettings) getIntProperty(name string, defaultValue int) int {
   133  	asInt := ic.propertyValuesHelper.GetPropertyAsInt(ic.cfg, name, &defaultValue)
   134  	return *asInt
   135  }
   136  
   137  func (ic *classSettings) getDefaultModel(apiEndpoint string) string {
   138  	if apiEndpoint == DefaulGenerativeAIApiEndpoint {
   139  		return DefaulGenerativeAIModelID
   140  	}
   141  	return DefaultPaLMModel
   142  }
   143  
   144  // PaLM params
   145  func (ic *classSettings) ApiEndpoint() string {
   146  	return ic.getStringProperty(apiEndpointProperty, DefaultPaLMApiEndpoint)
   147  }
   148  
   149  func (ic *classSettings) ProjectID() string {
   150  	return ic.getStringProperty(projectIDProperty, "")
   151  }
   152  
   153  func (ic *classSettings) EndpointID() string {
   154  	return ic.getStringProperty(endpointIDProperty, "")
   155  }
   156  
   157  func (ic *classSettings) ModelID() string {
   158  	return ic.getStringProperty(modelIDProperty, ic.getDefaultModel(ic.ApiEndpoint()))
   159  }
   160  
   161  // parameters
   162  
   163  // 0.0 - 1.0
   164  func (ic *classSettings) Temperature() float64 {
   165  	return ic.getFloatProperty(temperatureProperty, DefaultPaLMTemperature)
   166  }
   167  
   168  // 1 - 1024
   169  func (ic *classSettings) TokenLimit() int {
   170  	return ic.getIntProperty(tokenLimitProperty, DefaultTokenLimit)
   171  }
   172  
   173  // 1 - 40
   174  func (ic *classSettings) TopK() int {
   175  	return ic.getIntProperty(topKProperty, DefaultPaLMTopK)
   176  }
   177  
   178  // 0.0 - 1.0
   179  func (ic *classSettings) TopP() float64 {
   180  	return ic.getFloatProperty(topPProperty, DefaultPaLMTopP)
   181  }
   182  
   183  func contains[T comparable](s []T, e T) bool {
   184  	for _, v := range s {
   185  		if v == e {
   186  			return true
   187  		}
   188  	}
   189  	return false
   190  }