github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-bind/vectorizer/class_settings.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package vectorizer
    13  
    14  import (
    15  	"fmt"
    16  
    17  	"github.com/pkg/errors"
    18  
    19  	"github.com/weaviate/weaviate/entities/moduletools"
    20  	basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings"
    21  )
    22  
    23  type classSettings struct {
    24  	cfg  moduletools.ClassConfig
    25  	base *basesettings.BaseClassSettings
    26  }
    27  
    28  func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
    29  	return &classSettings{cfg: cfg, base: basesettings.NewBaseClassSettings(cfg)}
    30  }
    31  
    32  func (ic *classSettings) ImageField(property string) bool {
    33  	return ic.field("imageFields", property)
    34  }
    35  
    36  func (ic *classSettings) ImageFieldsWeights() ([]float32, error) {
    37  	return ic.getFieldsWeights("image")
    38  }
    39  
    40  func (ic *classSettings) TextField(property string) bool {
    41  	return ic.field("textFields", property)
    42  }
    43  
    44  func (ic *classSettings) TextFieldsWeights() ([]float32, error) {
    45  	return ic.getFieldsWeights("text")
    46  }
    47  
    48  func (ic *classSettings) AudioField(property string) bool {
    49  	return ic.field("audioFields", property)
    50  }
    51  
    52  func (ic *classSettings) AudioFieldsWeights() ([]float32, error) {
    53  	return ic.getFieldsWeights("audio")
    54  }
    55  
    56  func (ic *classSettings) VideoField(property string) bool {
    57  	return ic.field("videoFields", property)
    58  }
    59  
    60  func (ic *classSettings) VideoFieldsWeights() ([]float32, error) {
    61  	return ic.getFieldsWeights("video")
    62  }
    63  
    64  func (ic *classSettings) IMUField(property string) bool {
    65  	return ic.field("imuFields", property)
    66  }
    67  
    68  func (ic *classSettings) IMUFieldsWeights() ([]float32, error) {
    69  	return ic.getFieldsWeights("imu")
    70  }
    71  
    72  func (ic *classSettings) ThermalField(property string) bool {
    73  	return ic.field("thermalFields", property)
    74  }
    75  
    76  func (ic *classSettings) ThermalFieldsWeights() ([]float32, error) {
    77  	return ic.getFieldsWeights("thermal")
    78  }
    79  
    80  func (ic *classSettings) DepthField(property string) bool {
    81  	return ic.field("depthFields", property)
    82  }
    83  
    84  func (ic *classSettings) DepthFieldsWeights() ([]float32, error) {
    85  	return ic.getFieldsWeights("depth")
    86  }
    87  
    88  func (ic *classSettings) field(name, property string) bool {
    89  	if ic.cfg == nil {
    90  		// we would receive a nil-config on cross-class requests, such as Explore{}
    91  		return false
    92  	}
    93  
    94  	fields, ok := ic.cfg.Class()[name]
    95  	if !ok {
    96  		return false
    97  	}
    98  
    99  	fieldsArray, ok := fields.([]interface{})
   100  	if !ok {
   101  		return false
   102  	}
   103  
   104  	fieldNames := make([]string, len(fieldsArray))
   105  	for i, value := range fieldsArray {
   106  		fieldNames[i] = value.(string)
   107  	}
   108  
   109  	for i := range fieldNames {
   110  		if fieldNames[i] == property {
   111  			return true
   112  		}
   113  	}
   114  
   115  	return false
   116  }
   117  
   118  func (ic *classSettings) Validate() error {
   119  	if ic.cfg == nil {
   120  		// we would receive a nil-config on cross-class requests, such as Explore{}
   121  		return errors.New("empty config")
   122  	}
   123  
   124  	imageFields, imageFieldsOk := ic.cfg.Class()["imageFields"]
   125  	textFields, textFieldsOk := ic.cfg.Class()["textFields"]
   126  	audioFields, audioFieldsOk := ic.cfg.Class()["audioFields"]
   127  	videoFields, videoFieldsOk := ic.cfg.Class()["videoFields"]
   128  	imuFields, imuFieldsOk := ic.cfg.Class()["imuFields"]
   129  	thermalFields, thermalFieldsOk := ic.cfg.Class()["thermalFields"]
   130  	depthFields, depthFieldsOk := ic.cfg.Class()["depthFields"]
   131  
   132  	if !imageFieldsOk && !textFieldsOk && !audioFieldsOk && !videoFieldsOk &&
   133  		!imuFieldsOk && !thermalFieldsOk && !depthFieldsOk {
   134  		return errors.New("textFields or imageFields or audioFields or videoFields " +
   135  			"or imuFields or thermalFields or depthFields setting needs to be present")
   136  	}
   137  
   138  	if imageFieldsOk {
   139  		if err := ic.validateWeightFieldCount("image", imageFields); err != nil {
   140  			return err
   141  		}
   142  	}
   143  	if textFieldsOk {
   144  		if err := ic.validateWeightFieldCount("text", textFields); err != nil {
   145  			return err
   146  		}
   147  	}
   148  	if audioFieldsOk {
   149  		if err := ic.validateWeightFieldCount("audio", audioFields); err != nil {
   150  			return err
   151  		}
   152  	}
   153  	if videoFieldsOk {
   154  		if err := ic.validateWeightFieldCount("video", videoFields); err != nil {
   155  			return err
   156  		}
   157  	}
   158  	if imuFieldsOk {
   159  		if err := ic.validateWeightFieldCount("imu", imuFields); err != nil {
   160  			return err
   161  		}
   162  	}
   163  	if thermalFieldsOk {
   164  		if err := ic.validateWeightFieldCount("thermal", thermalFields); err != nil {
   165  			return err
   166  		}
   167  	}
   168  	if depthFieldsOk {
   169  		if err := ic.validateWeightFieldCount("depth", depthFields); err != nil {
   170  			return err
   171  		}
   172  	}
   173  
   174  	return nil
   175  }
   176  
   177  func (ic *classSettings) validateWeightFieldCount(name string, fields interface{}) error {
   178  	imageFieldsCount, err := ic.validateFields(name, fields)
   179  	if err != nil {
   180  		return err
   181  	}
   182  	err = ic.validateWeights(name, imageFieldsCount)
   183  	if err != nil {
   184  		return err
   185  	}
   186  	return nil
   187  }
   188  
   189  func (ic *classSettings) validateFields(name string, fields interface{}) (int, error) {
   190  	fieldsArray, ok := fields.([]interface{})
   191  	if !ok {
   192  		return 0, errors.Errorf("%sFields must be an array", name)
   193  	}
   194  
   195  	if len(fieldsArray) == 0 {
   196  		return 0, errors.Errorf("must contain at least one %s field name in %sFields", name, name)
   197  	}
   198  
   199  	for _, value := range fieldsArray {
   200  		v, ok := value.(string)
   201  		if !ok {
   202  			return 0, errors.Errorf("%sField must be a string", name)
   203  		}
   204  		if len(v) == 0 {
   205  			return 0, errors.Errorf("%sField values cannot be empty", name)
   206  		}
   207  	}
   208  
   209  	return len(fieldsArray), nil
   210  }
   211  
   212  func (ic *classSettings) validateWeights(name string, count int) error {
   213  	weights, ok := ic.getWeights(name)
   214  	if ok {
   215  		if len(weights) != count {
   216  			return errors.Errorf("weights.%sFields does not equal number of %sFields", name, name)
   217  		}
   218  		_, err := ic.getWeightsArray(weights)
   219  		if err != nil {
   220  			return err
   221  		}
   222  	}
   223  
   224  	return nil
   225  }
   226  
   227  func (ic *classSettings) getWeights(name string) ([]interface{}, bool) {
   228  	weights, ok := ic.cfg.Class()["weights"]
   229  	if ok {
   230  		weightsObject, ok := weights.(map[string]interface{})
   231  		if ok {
   232  			fieldWeights, ok := weightsObject[fmt.Sprintf("%sFields", name)]
   233  			if ok {
   234  				fieldWeightsArray, ok := fieldWeights.([]interface{})
   235  				if ok {
   236  					return fieldWeightsArray, ok
   237  				}
   238  			}
   239  		}
   240  	}
   241  
   242  	return nil, false
   243  }
   244  
   245  func (ic *classSettings) getWeightsArray(weights []interface{}) ([]float32, error) {
   246  	weightsArray := make([]float32, len(weights))
   247  	for i := range weights {
   248  		weight, err := ic.getNumber(weights[i])
   249  		if err != nil {
   250  			return nil, err
   251  		}
   252  		weightsArray[i] = weight
   253  	}
   254  	return weightsArray, nil
   255  }
   256  
   257  func (ic *classSettings) getFieldsWeights(name string) ([]float32, error) {
   258  	weights, ok := ic.getWeights(name)
   259  	if ok {
   260  		return ic.getWeightsArray(weights)
   261  	}
   262  	return nil, nil
   263  }
   264  
   265  func (ic *classSettings) getNumber(in interface{}) (float32, error) {
   266  	return ic.base.GetNumber(in)
   267  }