github.com/weaviate/weaviate@v1.24.6/modules/text2vec-openai/vectorizer/class_settings.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package vectorizer
    13  
    14  import (
    15  	"fmt"
    16  
    17  	"github.com/pkg/errors"
    18  
    19  	"github.com/weaviate/weaviate/entities/models"
    20  	"github.com/weaviate/weaviate/entities/moduletools"
    21  	"github.com/weaviate/weaviate/entities/schema"
    22  	basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings"
    23  )
    24  
// Default values for the class-level settings of this module.
const (
	// DefaultOpenAIDocumentType is the default handed to the "type" setting
	// lookup in Type.
	DefaultOpenAIDocumentType    = "text"
	// DefaultOpenAIModel is the default handed to the "model" setting lookup
	// in Model.
	DefaultOpenAIModel           = "ada"
	// DefaultVectorizeClassName is not referenced in this file; presumably the
	// default for the "vectorizeClassName" setting — confirm against callers.
	DefaultVectorizeClassName    = true
	// DefaultPropertyIndexed is not referenced in this file; presumably the
	// default for whether a property takes part in vectorization — confirm
	// against callers.
	DefaultPropertyIndexed       = true
	// DefaultVectorizePropertyName is not referenced in this file; presumably
	// the default for the "vectorizePropertyName" setting — confirm against
	// callers.
	DefaultVectorizePropertyName = false
	// DefaultBaseURL is the default handed to the "baseURL" setting lookup in
	// BaseURL.
	DefaultBaseURL               = "https://api.openai.com"
)
    33  
// Names of the v3 embedding models. They are listed in availableV3Models and
// used as keys of availableV3ModelsDimensions.
const (
	TextEmbedding3Small = "text-embedding-3-small"
	TextEmbedding3Large = "text-embedding-3-large"
)
    38  
// Default "dimensions" values of the v3 embedding models. They are used by
// PickDefaultDimensions and are part of the accepted values listed in
// availableV3ModelsDimensions.
var (
	TextEmbedding3SmallDefaultDimensions int64 = 1536
	TextEmbedding3LargeDefaultDimensions int64 = 3072
)
    43  
// availableOpenAITypes lists the values Validate accepts for the "type"
// setting.
var availableOpenAITypes = []string{"text", "code"}
    45  
// availableV3Models lists the v3 embedding models. Validate accepts them as
// "model" values, validateModelVersion skips its version checks for them and
// PickDefaultModelVersion returns an empty default version for them.
var availableV3Models = []string{
	// new v3 models
	TextEmbedding3Small,
	TextEmbedding3Large,
}
    51  
// availableV3ModelsDimensions maps each v3 model name to the "dimensions"
// values that Validate accepts for it.
var availableV3ModelsDimensions = map[string][]int64{
	TextEmbedding3Small: {512, TextEmbedding3SmallDefaultDimensions},
	TextEmbedding3Large: {256, 1024, TextEmbedding3LargeDefaultDimensions},
}
    56  
// availableOpenAIModels lists the non-v3 model names that Validate accepts
// for the "model" setting. Which versions each of them may be combined with
// is enforced by validateModelVersion.
var availableOpenAIModels = []string{
	"ada",     // supports 001 and 002
	"babbage", // only supports 001
	"curie",   // only supports 001
	"davinci", // only supports 001
}
    63  
// classSettings gives typed access to the module settings of a single class.
// It embeds basesettings.BaseClassSettings, which the accessors in this file
// use for the actual property lookups.
type classSettings struct {
	basesettings.BaseClassSettings
	// cfg is the class config passed to NewClassSettings. Validate rejects a
	// nil cfg with an "empty config" error.
	cfg moduletools.ClassConfig
}
    68  
    69  func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
    70  	return &classSettings{cfg: cfg, BaseClassSettings: *basesettings.NewBaseClassSettings(cfg)}
    71  }
    72  
    73  func (cs *classSettings) Model() string {
    74  	return cs.getProperty("model", DefaultOpenAIModel)
    75  }
    76  
    77  func (cs *classSettings) Type() string {
    78  	return cs.getProperty("type", DefaultOpenAIDocumentType)
    79  }
    80  
    81  func (cs *classSettings) ModelVersion() string {
    82  	defaultVersion := PickDefaultModelVersion(cs.Model(), cs.Type())
    83  	return cs.getProperty("modelVersion", defaultVersion)
    84  }
    85  
    86  func (cs *classSettings) ResourceName() string {
    87  	return cs.getProperty("resourceName", "")
    88  }
    89  
    90  func (cs *classSettings) BaseURL() string {
    91  	return cs.getProperty("baseURL", DefaultBaseURL)
    92  }
    93  
    94  func (cs *classSettings) DeploymentID() string {
    95  	return cs.getProperty("deploymentId", "")
    96  }
    97  
    98  func (cs *classSettings) IsAzure() bool {
    99  	return cs.ResourceName() != "" && cs.DeploymentID() != ""
   100  }
   101  
   102  func (cs *classSettings) Dimensions() *int64 {
   103  	defaultValue := PickDefaultDimensions(cs.Model())
   104  	return cs.getPropertyAsInt("dimensions", defaultValue)
   105  }
   106  
   107  func (cs *classSettings) Validate(class *models.Class) error {
   108  	if cs.cfg == nil {
   109  		// we would receive a nil-config on cross-class requests, such as Explore{}
   110  		return errors.New("empty config")
   111  	}
   112  
   113  	if err := cs.BaseClassSettings.Validate(); err != nil {
   114  		return err
   115  	}
   116  
   117  	docType := cs.Type()
   118  	if !validateOpenAISetting[string](docType, availableOpenAITypes) {
   119  		return errors.Errorf("wrong OpenAI type name, available model names are: %v", availableOpenAITypes)
   120  	}
   121  
   122  	availableModels := append(availableOpenAIModels, availableV3Models...)
   123  	model := cs.Model()
   124  	if !validateOpenAISetting[string](model, availableModels) {
   125  		return errors.Errorf("wrong OpenAI model name, available model names are: %v", availableModels)
   126  	}
   127  
   128  	dimensions := cs.Dimensions()
   129  	if dimensions != nil {
   130  		if !validateOpenAISetting[string](model, availableV3Models) {
   131  			return errors.Errorf("dimensions setting can only be used with V3 embedding models: %v", availableV3Models)
   132  		}
   133  		availableDimensions := availableV3ModelsDimensions[model]
   134  		if !validateOpenAISetting[int64](*dimensions, availableDimensions) {
   135  			return errors.Errorf("wrong dimensions setting for %s model, available dimensions are: %v", model, availableDimensions)
   136  		}
   137  	}
   138  
   139  	version := cs.ModelVersion()
   140  	if err := cs.validateModelVersion(version, model, docType); err != nil {
   141  		return err
   142  	}
   143  
   144  	err := cs.validateAzureConfig(cs.ResourceName(), cs.DeploymentID())
   145  	if err != nil {
   146  		return err
   147  	}
   148  
   149  	err = cs.validateIndexState(class, cs)
   150  	if err != nil {
   151  		return err
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  func (cs *classSettings) validateModelVersion(version, model, docType string) error {
   158  	for i := range availableV3Models {
   159  		if model == availableV3Models[i] {
   160  			return nil
   161  		}
   162  	}
   163  
   164  	if version == "001" {
   165  		// no restrictions
   166  		return nil
   167  	}
   168  
   169  	if version == "002" {
   170  		// only ada/davinci 002
   171  		if model != "ada" && model != "davinci" {
   172  			return fmt.Errorf("unsupported version %s", version)
   173  		}
   174  	}
   175  
   176  	if version == "003" && model != "davinci" {
   177  		// only davinci 003
   178  		return fmt.Errorf("unsupported version %s", version)
   179  	}
   180  
   181  	if version != "002" && version != "003" {
   182  		// all other fallback
   183  		return fmt.Errorf("model %s is only available in version 001", model)
   184  	}
   185  
   186  	if docType != "text" {
   187  		return fmt.Errorf("ada-002 no longer distinguishes between text/code, use 'text' for all use cases")
   188  	}
   189  
   190  	return nil
   191  }
   192  
   193  func (cs *classSettings) getProperty(name, defaultValue string) string {
   194  	return cs.BaseClassSettings.GetPropertyAsString(name, defaultValue)
   195  }
   196  
   197  func (cs *classSettings) getPropertyAsInt(name string, defaultValue *int64) *int64 {
   198  	return cs.BaseClassSettings.GetPropertyAsInt64(name, defaultValue)
   199  }
   200  
   201  func (cs *classSettings) validateIndexState(class *models.Class, settings ClassSettings) error {
   202  	if settings.VectorizeClassName() {
   203  		// if the user chooses to vectorize the classname, vector-building will
   204  		// always be possible, no need to investigate further
   205  
   206  		return nil
   207  	}
   208  
   209  	// search if there is at least one indexed, string/text prop. If found pass
   210  	// validation
   211  	for _, prop := range class.Properties {
   212  		if len(prop.DataType) < 1 {
   213  			return errors.Errorf("property %s must have at least one datatype: "+
   214  				"got %v", prop.Name, prop.DataType)
   215  		}
   216  
   217  		if prop.DataType[0] != string(schema.DataTypeText) {
   218  			// we can only vectorize text-like props
   219  			continue
   220  		}
   221  
   222  		if settings.PropertyIndexed(prop.Name) {
   223  			// found at least one, this is a valid schema
   224  			return nil
   225  		}
   226  	}
   227  
   228  	return fmt.Errorf("invalid properties: didn't find a single property which is " +
   229  		"of type string or text and is not excluded from indexing. In addition the " +
   230  		"class name is excluded from vectorization as well, meaning that it cannot be " +
   231  		"used to determine the vector position. To fix this, set 'vectorizeClassName' " +
   232  		"to true if the class name is contextionary-valid. Alternatively add at least " +
   233  		"contextionary-valid text/string property which is not excluded from " +
   234  		"indexing.")
   235  }
   236  
   237  func (cs *classSettings) validateAzureConfig(resourceName string, deploymentId string) error {
   238  	if (resourceName == "" && deploymentId != "") || (resourceName != "" && deploymentId == "") {
   239  		return fmt.Errorf("both resourceName and deploymentId must be provided")
   240  	}
   241  	return nil
   242  }
   243  
   244  func validateOpenAISetting[T string | int64](value T, availableValues []T) bool {
   245  	for i := range availableValues {
   246  		if value == availableValues[i] {
   247  			return true
   248  		}
   249  	}
   250  	return false
   251  }
   252  
   253  func PickDefaultModelVersion(model, docType string) string {
   254  	for i := range availableV3Models {
   255  		if model == availableV3Models[i] {
   256  			return ""
   257  		}
   258  	}
   259  	if model == "ada" && docType == "text" {
   260  		return "002"
   261  	}
   262  	// for all other combinations stick with "001"
   263  	return "001"
   264  }
   265  
   266  func PickDefaultDimensions(model string) *int64 {
   267  	if model == TextEmbedding3Small {
   268  		return &TextEmbedding3SmallDefaultDimensions
   269  	}
   270  	if model == TextEmbedding3Large {
   271  		return &TextEmbedding3LargeDefaultDimensions
   272  	}
   273  	return nil
   274  }