github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-clip/vectorizer/class_settings.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package vectorizer
    13  
    14  import (
    15  	"fmt"
    16  
    17  	"github.com/pkg/errors"
    18  
    19  	"github.com/weaviate/weaviate/entities/moduletools"
    20  	basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings"
    21  )
    22  
// classSettings reads the multi2vec-clip specific settings of a single class.
//
// cfg is the raw class configuration that the methods in this file query via
// cfg.Class(); it may be nil (see the nil checks in field and Validate).
// base is built from the same cfg by NewClassSettings and is used here for
// string property lookup and number conversion.
type classSettings struct {
	cfg  moduletools.ClassConfig
	base *basesettings.BaseClassSettings
}
    27  
    28  func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
    29  	return &classSettings{cfg: cfg, base: basesettings.NewBaseClassSettings(cfg)}
    30  }
    31  
    32  func (ic *classSettings) ImageField(property string) bool {
    33  	return ic.field("imageFields", property)
    34  }
    35  
    36  func (ic *classSettings) ImageFieldsWeights() ([]float32, error) {
    37  	return ic.getFieldsWeights("image")
    38  }
    39  
    40  func (ic *classSettings) TextField(property string) bool {
    41  	return ic.field("textFields", property)
    42  }
    43  
    44  func (ic *classSettings) TextFieldsWeights() ([]float32, error) {
    45  	return ic.getFieldsWeights("text")
    46  }
    47  
    48  func (ic *classSettings) InferenceURL() string {
    49  	return ic.base.GetPropertyAsString("inferenceUrl", "")
    50  }
    51  
    52  func (ic *classSettings) field(name, property string) bool {
    53  	if ic.cfg == nil {
    54  		// we would receive a nil-config on cross-class requests, such as Explore{}
    55  		return false
    56  	}
    57  
    58  	fields, ok := ic.cfg.Class()[name]
    59  	if !ok {
    60  		return false
    61  	}
    62  
    63  	fieldsArray, ok := fields.([]interface{})
    64  	if !ok {
    65  		return false
    66  	}
    67  
    68  	fieldNames := make([]string, len(fieldsArray))
    69  	for i, value := range fieldsArray {
    70  		fieldNames[i] = value.(string)
    71  	}
    72  
    73  	for i := range fieldNames {
    74  		if fieldNames[i] == property {
    75  			return true
    76  		}
    77  	}
    78  
    79  	return false
    80  }
    81  
    82  func (ic *classSettings) Validate() error {
    83  	if ic.cfg == nil {
    84  		// we would receive a nil-config on cross-class requests, such as Explore{}
    85  		return errors.New("empty config")
    86  	}
    87  
    88  	imageFields, imageFieldsOk := ic.cfg.Class()["imageFields"]
    89  	textFields, textFieldsOk := ic.cfg.Class()["textFields"]
    90  	if !imageFieldsOk && !textFieldsOk {
    91  		return errors.New("textFields or imageFields setting needs to be present")
    92  	}
    93  
    94  	if imageFieldsOk {
    95  		imageFieldsCount, err := ic.validateFields("image", imageFields)
    96  		if err != nil {
    97  			return err
    98  		}
    99  		err = ic.validateWeights("image", imageFieldsCount)
   100  		if err != nil {
   101  			return err
   102  		}
   103  	}
   104  
   105  	if textFieldsOk {
   106  		textFieldsCount, err := ic.validateFields("text", textFields)
   107  		if err != nil {
   108  			return err
   109  		}
   110  		err = ic.validateWeights("text", textFieldsCount)
   111  		if err != nil {
   112  			return err
   113  		}
   114  	}
   115  
   116  	return nil
   117  }
   118  
   119  func (ic *classSettings) validateFields(name string, fields interface{}) (int, error) {
   120  	fieldsArray, ok := fields.([]interface{})
   121  	if !ok {
   122  		return 0, errors.Errorf("%sFields must be an array", name)
   123  	}
   124  
   125  	if len(fieldsArray) == 0 {
   126  		return 0, errors.Errorf("must contain at least one %s field name in %sFields", name, name)
   127  	}
   128  
   129  	for _, value := range fieldsArray {
   130  		v, ok := value.(string)
   131  		if !ok {
   132  			return 0, errors.Errorf("%sField must be a string", name)
   133  		}
   134  		if len(v) == 0 {
   135  			return 0, errors.Errorf("%sField values cannot be empty", name)
   136  		}
   137  	}
   138  
   139  	return len(fieldsArray), nil
   140  }
   141  
   142  func (ic *classSettings) validateWeights(name string, count int) error {
   143  	weights, ok := ic.getWeights(name)
   144  	if ok {
   145  		if len(weights) != count {
   146  			return errors.Errorf("weights.%sFields does not equal number of %sFields", name, name)
   147  		}
   148  		_, err := ic.getWeightsArray(weights)
   149  		if err != nil {
   150  			return err
   151  		}
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  func (ic *classSettings) getWeights(name string) ([]interface{}, bool) {
   158  	weights, ok := ic.cfg.Class()["weights"]
   159  	if ok {
   160  		weightsObject, ok := weights.(map[string]interface{})
   161  		if ok {
   162  			fieldWeights, ok := weightsObject[fmt.Sprintf("%sFields", name)]
   163  			if ok {
   164  				fieldWeightsArray, ok := fieldWeights.([]interface{})
   165  				if ok {
   166  					return fieldWeightsArray, ok
   167  				}
   168  			}
   169  		}
   170  	}
   171  
   172  	return nil, false
   173  }
   174  
   175  func (ic *classSettings) getWeightsArray(weights []interface{}) ([]float32, error) {
   176  	weightsArray := make([]float32, len(weights))
   177  	for i := range weights {
   178  		weight, err := ic.getNumber(weights[i])
   179  		if err != nil {
   180  			return nil, err
   181  		}
   182  		weightsArray[i] = weight
   183  	}
   184  	return weightsArray, nil
   185  }
   186  
   187  func (ic *classSettings) getFieldsWeights(name string) ([]float32, error) {
   188  	weights, ok := ic.getWeights(name)
   189  	if ok {
   190  		return ic.getWeightsArray(weights)
   191  	}
   192  	return nil, nil
   193  }
   194  
   195  func (ic *classSettings) getNumber(in interface{}) (float32, error) {
   196  	return ic.base.GetNumber(in)
   197  }