github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-palm/vectorizer/class_settings.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package vectorizer
    13  
    14  import (
    15  	"fmt"
    16  	"strings"
    17  
    18  	"github.com/pkg/errors"
    19  
    20  	"github.com/weaviate/weaviate/entities/moduletools"
    21  	basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings"
    22  )
    23  
// Keys of the module settings that classSettings reads through its
// string and int64 property accessors.
const (
	locationProperty             = "location"
	projectIDProperty            = "projectId"
	modelIDProperty              = "modelId"
	dimensionsProperty           = "dimensions"
	videoIntervalSecondsProperty = "videoIntervalSeconds"
)
    31  
// Exported default values of the module. Within the visible part of this file
// only DefaultModelID is referenced (as the fallback passed by ModelID).
const (
	DefaultVectorizeClassName    = false
	DefaultPropertyIndexed       = true
	DefaultVectorizePropertyName = false
	DefaultApiEndpoint           = "us-central1-aiplatform.googleapis.com"
	DefaultModelID               = "multimodalembedding@001"
	// NOTE(review): DefaultDime has no explicit value, so by Go's implicit
	// repetition rule it takes the preceding expression, i.e. the same string
	// as DefaultModelID. The name looks truncated - confirm against upstream
	// whether this constant should exist at all.
	DefaultDime
)
    40  
// Fallbacks and accepted values for the numeric settings.
//
// defaultDimensions1408 and defaultVideoIntervalSeconds are the fallbacks
// passed by Dimensions() and VideoIntervalSeconds(); the available* slices are
// the value sets Validate checks those settings against. Validate additionally
// requires defaultDimensions1408 whenever videoFields is configured.
var (
	defaultDimensions1408         = int64(1408)
	availableDimensions           = []int64{128, 256, 512, defaultDimensions1408}
	defaultVideoIntervalSeconds   = int64(120)
	availableVideoIntervalSeconds = []int64{4, 8, 15, defaultVideoIntervalSeconds}
)
    47  
// classSettings gives typed access to, and validation of, the module settings
// stored in a class config.
type classSettings struct {
	// base performs the string, int64 and number lookups/conversions.
	base *basesettings.BaseClassSettings
	// cfg is the raw class config; field and Validate treat a nil cfg as a
	// valid (cross-class) situation and check for it explicitly.
	cfg  moduletools.ClassConfig
}
    52  
    53  func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
    54  	return &classSettings{cfg: cfg, base: basesettings.NewBaseClassSettings(cfg)}
    55  }
    56  
    57  // PaLM params
    58  func (ic *classSettings) Location() string {
    59  	return ic.getStringProperty(locationProperty, "")
    60  }
    61  
    62  func (ic *classSettings) ProjectID() string {
    63  	return ic.getStringProperty(projectIDProperty, "")
    64  }
    65  
    66  func (ic *classSettings) ModelID() string {
    67  	return ic.getStringProperty(modelIDProperty, DefaultModelID)
    68  }
    69  
    70  func (ic *classSettings) Dimensions() int64 {
    71  	return ic.getInt64Property(dimensionsProperty, defaultDimensions1408)
    72  }
    73  
    74  func (ic *classSettings) VideoIntervalSeconds() int64 {
    75  	return ic.getInt64Property(videoIntervalSecondsProperty, defaultVideoIntervalSeconds)
    76  }
    77  
    78  // CLIP module specific settings
    79  func (ic *classSettings) ImageField(property string) bool {
    80  	return ic.field("imageFields", property)
    81  }
    82  
    83  func (ic *classSettings) ImageFieldsWeights() ([]float32, error) {
    84  	return ic.getFieldsWeights("image")
    85  }
    86  
    87  func (ic *classSettings) TextField(property string) bool {
    88  	return ic.field("textFields", property)
    89  }
    90  
    91  func (ic *classSettings) TextFieldsWeights() ([]float32, error) {
    92  	return ic.getFieldsWeights("text")
    93  }
    94  
    95  func (ic *classSettings) VideoField(property string) bool {
    96  	return ic.field("videoFields", property)
    97  }
    98  
    99  func (ic *classSettings) VideoFieldsWeights() ([]float32, error) {
   100  	return ic.getFieldsWeights("video")
   101  }
   102  
   103  func (ic *classSettings) field(name, property string) bool {
   104  	if ic.cfg == nil {
   105  		// we would receive a nil-config on cross-class requests, such as Explore{}
   106  		return false
   107  	}
   108  
   109  	fields, ok := ic.cfg.ClassByModuleName("multi2vec-palm")[name]
   110  	if !ok {
   111  		return false
   112  	}
   113  
   114  	fieldsArray, ok := fields.([]interface{})
   115  	if !ok {
   116  		return false
   117  	}
   118  
   119  	fieldNames := make([]string, len(fieldsArray))
   120  	for i, value := range fieldsArray {
   121  		fieldNames[i] = value.(string)
   122  	}
   123  
   124  	for i := range fieldNames {
   125  		if fieldNames[i] == property {
   126  			return true
   127  		}
   128  	}
   129  
   130  	return false
   131  }
   132  
   133  func (ic *classSettings) getStringProperty(name, defaultValue string) string {
   134  	return ic.base.GetPropertyAsString(name, defaultValue)
   135  }
   136  
   137  func (ic *classSettings) getInt64Property(name string, defaultValue int64) int64 {
   138  	if val := ic.base.GetPropertyAsInt64(name, &defaultValue); val != nil {
   139  		return *val
   140  	}
   141  	return defaultValue
   142  }
   143  
   144  func (ic *classSettings) Validate() error {
   145  	if ic.cfg == nil {
   146  		// we would receive a nil-config on cross-class requests, such as Explore{}
   147  		return errors.New("empty config")
   148  	}
   149  
   150  	var errorMessages []string
   151  
   152  	model := ic.ModelID()
   153  	location := ic.Location()
   154  	if location == "" {
   155  		errorMessages = append(errorMessages, "location setting needs to be present")
   156  	}
   157  
   158  	projectID := ic.ProjectID()
   159  	if projectID == "" {
   160  		errorMessages = append(errorMessages, "projectId setting needs to be present")
   161  	}
   162  
   163  	dimensions := ic.Dimensions()
   164  	if !validateSetting[int64](dimensions, availableDimensions) {
   165  		return errors.Errorf("wrong dimensions setting for %s model, available dimensions are: %v", model, availableDimensions)
   166  	}
   167  
   168  	videoIntervalSeconds := ic.VideoIntervalSeconds()
   169  	if !validateSetting[int64](videoIntervalSeconds, availableVideoIntervalSeconds) {
   170  		return errors.Errorf("wrong videoIntervalSeconds setting for %s model, available videoIntervalSeconds are: %v", model, availableVideoIntervalSeconds)
   171  	}
   172  
   173  	imageFields, imageFieldsOk := ic.cfg.Class()["imageFields"]
   174  	textFields, textFieldsOk := ic.cfg.Class()["textFields"]
   175  	videoFields, videoFieldsOk := ic.cfg.Class()["videoFields"]
   176  	if !imageFieldsOk && !textFieldsOk && !videoFieldsOk {
   177  		errorMessages = append(errorMessages, "textFields or imageFields or videoFields setting needs to be present")
   178  	}
   179  
   180  	if videoFieldsOk && dimensions != defaultDimensions1408 {
   181  		errorMessages = append(errorMessages, fmt.Sprintf("videoFields support only %d dimensions setting", defaultDimensions1408))
   182  	}
   183  
   184  	if imageFieldsOk {
   185  		imageFieldsCount, err := ic.validateFields("image", imageFields)
   186  		if err != nil {
   187  			errorMessages = append(errorMessages, err.Error())
   188  		}
   189  		err = ic.validateWeights("image", imageFieldsCount)
   190  		if err != nil {
   191  			errorMessages = append(errorMessages, err.Error())
   192  		}
   193  	}
   194  
   195  	if textFieldsOk {
   196  		textFieldsCount, err := ic.validateFields("text", textFields)
   197  		if err != nil {
   198  			errorMessages = append(errorMessages, err.Error())
   199  		}
   200  		err = ic.validateWeights("text", textFieldsCount)
   201  		if err != nil {
   202  			errorMessages = append(errorMessages, err.Error())
   203  		}
   204  	}
   205  
   206  	if videoFieldsOk {
   207  		videoFieldsCount, err := ic.validateFields("video", videoFields)
   208  		if err != nil {
   209  			errorMessages = append(errorMessages, err.Error())
   210  		}
   211  		err = ic.validateWeights("video", videoFieldsCount)
   212  		if err != nil {
   213  			errorMessages = append(errorMessages, err.Error())
   214  		}
   215  	}
   216  
   217  	if len(errorMessages) > 0 {
   218  		return fmt.Errorf("%s", strings.Join(errorMessages, ", "))
   219  	}
   220  
   221  	return nil
   222  }
   223  
   224  func (ic *classSettings) validateFields(name string, fields interface{}) (int, error) {
   225  	fieldsArray, ok := fields.([]interface{})
   226  	if !ok {
   227  		return 0, errors.Errorf("%sFields must be an array", name)
   228  	}
   229  
   230  	if len(fieldsArray) == 0 {
   231  		return 0, errors.Errorf("must contain at least one %s field name in %sFields", name, name)
   232  	}
   233  
   234  	for _, value := range fieldsArray {
   235  		v, ok := value.(string)
   236  		if !ok {
   237  			return 0, errors.Errorf("%sField must be a string", name)
   238  		}
   239  		if len(v) == 0 {
   240  			return 0, errors.Errorf("%sField values cannot be empty", name)
   241  		}
   242  	}
   243  
   244  	return len(fieldsArray), nil
   245  }
   246  
   247  func (ic *classSettings) validateWeights(name string, count int) error {
   248  	weights, ok := ic.getWeights(name)
   249  	if ok {
   250  		if len(weights) != count {
   251  			return errors.Errorf("weights.%sFields does not equal number of %sFields", name, name)
   252  		}
   253  		_, err := ic.getWeightsArray(weights)
   254  		if err != nil {
   255  			return err
   256  		}
   257  	}
   258  
   259  	return nil
   260  }
   261  
   262  func (ic *classSettings) getWeights(name string) ([]interface{}, bool) {
   263  	weights, ok := ic.cfg.Class()["weights"]
   264  	if ok {
   265  		weightsObject, ok := weights.(map[string]interface{})
   266  		if ok {
   267  			fieldWeights, ok := weightsObject[fmt.Sprintf("%sFields", name)]
   268  			if ok {
   269  				fieldWeightsArray, ok := fieldWeights.([]interface{})
   270  				if ok {
   271  					return fieldWeightsArray, ok
   272  				}
   273  			}
   274  		}
   275  	}
   276  
   277  	return nil, false
   278  }
   279  
   280  func (ic *classSettings) getWeightsArray(weights []interface{}) ([]float32, error) {
   281  	weightsArray := make([]float32, len(weights))
   282  	for i := range weights {
   283  		weight, err := ic.getNumber(weights[i])
   284  		if err != nil {
   285  			return nil, err
   286  		}
   287  		weightsArray[i] = weight
   288  	}
   289  	return weightsArray, nil
   290  }
   291  
   292  func (ic *classSettings) getFieldsWeights(name string) ([]float32, error) {
   293  	weights, ok := ic.getWeights(name)
   294  	if ok {
   295  		return ic.getWeightsArray(weights)
   296  	}
   297  	return nil, nil
   298  }
   299  
   300  func (ic *classSettings) getNumber(in interface{}) (float32, error) {
   301  	return ic.base.GetNumber(in)
   302  }
   303  
   304  func validateSetting[T string | int64](value T, availableValues []T) bool {
   305  	for i := range availableValues {
   306  		if value == availableValues[i] {
   307  			return true
   308  		}
   309  	}
   310  	return false
   311  }