github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-clip/vectorizer/class_settings.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package vectorizer 13 14 import ( 15 "fmt" 16 17 "github.com/pkg/errors" 18 19 "github.com/weaviate/weaviate/entities/moduletools" 20 basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings" 21 ) 22 23 type classSettings struct { 24 cfg moduletools.ClassConfig 25 base *basesettings.BaseClassSettings 26 } 27 28 func NewClassSettings(cfg moduletools.ClassConfig) *classSettings { 29 return &classSettings{cfg: cfg, base: basesettings.NewBaseClassSettings(cfg)} 30 } 31 32 func (ic *classSettings) ImageField(property string) bool { 33 return ic.field("imageFields", property) 34 } 35 36 func (ic *classSettings) ImageFieldsWeights() ([]float32, error) { 37 return ic.getFieldsWeights("image") 38 } 39 40 func (ic *classSettings) TextField(property string) bool { 41 return ic.field("textFields", property) 42 } 43 44 func (ic *classSettings) TextFieldsWeights() ([]float32, error) { 45 return ic.getFieldsWeights("text") 46 } 47 48 func (ic *classSettings) InferenceURL() string { 49 return ic.base.GetPropertyAsString("inferenceUrl", "") 50 } 51 52 func (ic *classSettings) field(name, property string) bool { 53 if ic.cfg == nil { 54 // we would receive a nil-config on cross-class requests, such as Explore{} 55 return false 56 } 57 58 fields, ok := ic.cfg.Class()[name] 59 if !ok { 60 return false 61 } 62 63 fieldsArray, ok := fields.([]interface{}) 64 if !ok { 65 return false 66 } 67 68 fieldNames := make([]string, len(fieldsArray)) 69 for i, value := range fieldsArray { 70 fieldNames[i] = value.(string) 71 } 72 73 for i := range fieldNames { 74 if fieldNames[i] == property { 75 return true 76 } 77 } 78 79 return false 80 } 81 82 func (ic *classSettings) Validate() error { 83 if ic.cfg == nil { 84 // we would receive a nil-config on cross-class requests, such as Explore{} 85 return errors.New("empty config") 86 } 87 88 imageFields, imageFieldsOk := ic.cfg.Class()["imageFields"] 89 textFields, textFieldsOk := ic.cfg.Class()["textFields"] 90 if !imageFieldsOk && !textFieldsOk { 91 return errors.New("textFields or imageFields setting needs to be present") 92 } 93 94 if imageFieldsOk { 95 imageFieldsCount, err := ic.validateFields("image", imageFields) 96 if err != nil { 97 return err 98 } 99 err = ic.validateWeights("image", imageFieldsCount) 100 if err != nil { 101 return err 102 } 103 } 104 105 if textFieldsOk { 106 textFieldsCount, err := ic.validateFields("text", textFields) 107 if err != nil { 108 return err 109 } 110 err = ic.validateWeights("text", textFieldsCount) 111 if err != nil { 112 return err 113 } 114 } 115 116 return nil 117 } 118 119 func (ic *classSettings) validateFields(name string, fields interface{}) (int, error) { 120 fieldsArray, ok := fields.([]interface{}) 121 if !ok { 122 return 0, errors.Errorf("%sFields must be an array", name) 123 } 124 125 if len(fieldsArray) == 0 { 126 return 0, errors.Errorf("must contain at least one %s field name in %sFields", name, name) 127 } 128 129 for _, value := range fieldsArray { 130 v, ok := value.(string) 131 if !ok { 132 return 0, errors.Errorf("%sField must be a string", name) 133 } 134 if len(v) == 0 { 135 return 0, errors.Errorf("%sField values cannot be empty", name) 136 } 137 } 138 139 return len(fieldsArray), nil 140 } 141 142 func (ic *classSettings) validateWeights(name string, count int) error { 143 weights, ok := ic.getWeights(name) 144 if ok { 145 if len(weights) != count { 146 return errors.Errorf("weights.%sFields does not equal number of %sFields", name, name) 147 } 148 _, err := ic.getWeightsArray(weights) 149 if err != nil { 150 return err 151 } 152 } 153 154 return nil 155 } 156 157 func (ic *classSettings) getWeights(name string) ([]interface{}, bool) { 158 weights, ok := ic.cfg.Class()["weights"] 159 if ok { 160 weightsObject, ok := weights.(map[string]interface{}) 161 if ok { 162 fieldWeights, ok := weightsObject[fmt.Sprintf("%sFields", name)] 163 if ok { 164 fieldWeightsArray, ok := fieldWeights.([]interface{}) 165 if ok { 166 return fieldWeightsArray, ok 167 } 168 } 169 } 170 } 171 172 return nil, false 173 } 174 175 func (ic *classSettings) getWeightsArray(weights []interface{}) ([]float32, error) { 176 weightsArray := make([]float32, len(weights)) 177 for i := range weights { 178 weight, err := ic.getNumber(weights[i]) 179 if err != nil { 180 return nil, err 181 } 182 weightsArray[i] = weight 183 } 184 return weightsArray, nil 185 } 186 187 func (ic *classSettings) getFieldsWeights(name string) ([]float32, error) { 188 weights, ok := ic.getWeights(name) 189 if ok { 190 return ic.getWeightsArray(weights) 191 } 192 return nil, nil 193 } 194 195 func (ic *classSettings) getNumber(in interface{}) (float32, error) { 196 return ic.base.GetNumber(in) 197 }