github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-bind/vectorizer/class_settings.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package vectorizer 13 14 import ( 15 "fmt" 16 17 "github.com/pkg/errors" 18 19 "github.com/weaviate/weaviate/entities/moduletools" 20 basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings" 21 ) 22 23 type classSettings struct { 24 cfg moduletools.ClassConfig 25 base *basesettings.BaseClassSettings 26 } 27 28 func NewClassSettings(cfg moduletools.ClassConfig) *classSettings { 29 return &classSettings{cfg: cfg, base: basesettings.NewBaseClassSettings(cfg)} 30 } 31 32 func (ic *classSettings) ImageField(property string) bool { 33 return ic.field("imageFields", property) 34 } 35 36 func (ic *classSettings) ImageFieldsWeights() ([]float32, error) { 37 return ic.getFieldsWeights("image") 38 } 39 40 func (ic *classSettings) TextField(property string) bool { 41 return ic.field("textFields", property) 42 } 43 44 func (ic *classSettings) TextFieldsWeights() ([]float32, error) { 45 return ic.getFieldsWeights("text") 46 } 47 48 func (ic *classSettings) AudioField(property string) bool { 49 return ic.field("audioFields", property) 50 } 51 52 func (ic *classSettings) AudioFieldsWeights() ([]float32, error) { 53 return ic.getFieldsWeights("audio") 54 } 55 56 func (ic *classSettings) VideoField(property string) bool { 57 return ic.field("videoFields", property) 58 } 59 60 func (ic *classSettings) VideoFieldsWeights() ([]float32, error) { 61 return ic.getFieldsWeights("video") 62 } 63 64 func (ic *classSettings) IMUField(property string) bool { 65 return ic.field("imuFields", property) 66 } 67 68 func (ic *classSettings) IMUFieldsWeights() ([]float32, error) { 69 return ic.getFieldsWeights("imu") 70 } 71 72 func (ic *classSettings) ThermalField(property string) bool { 73 return ic.field("thermalFields", property) 74 } 75 76 func (ic *classSettings) ThermalFieldsWeights() ([]float32, error) { 77 return ic.getFieldsWeights("thermal") 78 } 79 80 func (ic *classSettings) DepthField(property string) bool { 81 return ic.field("depthFields", property) 82 } 83 84 func (ic *classSettings) DepthFieldsWeights() ([]float32, error) { 85 return ic.getFieldsWeights("depth") 86 } 87 88 func (ic *classSettings) field(name, property string) bool { 89 if ic.cfg == nil { 90 // we would receive a nil-config on cross-class requests, such as Explore{} 91 return false 92 } 93 94 fields, ok := ic.cfg.Class()[name] 95 if !ok { 96 return false 97 } 98 99 fieldsArray, ok := fields.([]interface{}) 100 if !ok { 101 return false 102 } 103 104 fieldNames := make([]string, len(fieldsArray)) 105 for i, value := range fieldsArray { 106 fieldNames[i] = value.(string) 107 } 108 109 for i := range fieldNames { 110 if fieldNames[i] == property { 111 return true 112 } 113 } 114 115 return false 116 } 117 118 func (ic *classSettings) Validate() error { 119 if ic.cfg == nil { 120 // we would receive a nil-config on cross-class requests, such as Explore{} 121 return errors.New("empty config") 122 } 123 124 imageFields, imageFieldsOk := ic.cfg.Class()["imageFields"] 125 textFields, textFieldsOk := ic.cfg.Class()["textFields"] 126 audioFields, audioFieldsOk := ic.cfg.Class()["audioFields"] 127 videoFields, videoFieldsOk := ic.cfg.Class()["videoFields"] 128 imuFields, imuFieldsOk := ic.cfg.Class()["imuFields"] 129 thermalFields, thermalFieldsOk := ic.cfg.Class()["thermalFields"] 130 depthFields, depthFieldsOk := ic.cfg.Class()["depthFields"] 131 132 if !imageFieldsOk && !textFieldsOk && !audioFieldsOk && !videoFieldsOk && 133 !imuFieldsOk && !thermalFieldsOk && !depthFieldsOk { 134 return errors.New("textFields or imageFields or audioFields or videoFields " + 135 "or imuFields or thermalFields or depthFields setting needs to be present") 136 } 137 138 if imageFieldsOk { 139 if err := ic.validateWeightFieldCount("image", imageFields); err != nil { 140 return err 141 } 142 } 143 if textFieldsOk { 144 if err := ic.validateWeightFieldCount("text", textFields); err != nil { 145 return err 146 } 147 } 148 if audioFieldsOk { 149 if err := ic.validateWeightFieldCount("audio", audioFields); err != nil { 150 return err 151 } 152 } 153 if videoFieldsOk { 154 if err := ic.validateWeightFieldCount("video", videoFields); err != nil { 155 return err 156 } 157 } 158 if imuFieldsOk { 159 if err := ic.validateWeightFieldCount("imu", imuFields); err != nil { 160 return err 161 } 162 } 163 if thermalFieldsOk { 164 if err := ic.validateWeightFieldCount("thermal", thermalFields); err != nil { 165 return err 166 } 167 } 168 if depthFieldsOk { 169 if err := ic.validateWeightFieldCount("depth", depthFields); err != nil { 170 return err 171 } 172 } 173 174 return nil 175 } 176 177 func (ic *classSettings) validateWeightFieldCount(name string, fields interface{}) error { 178 imageFieldsCount, err := ic.validateFields(name, fields) 179 if err != nil { 180 return err 181 } 182 err = ic.validateWeights(name, imageFieldsCount) 183 if err != nil { 184 return err 185 } 186 return nil 187 } 188 189 func (ic *classSettings) validateFields(name string, fields interface{}) (int, error) { 190 fieldsArray, ok := fields.([]interface{}) 191 if !ok { 192 return 0, errors.Errorf("%sFields must be an array", name) 193 } 194 195 if len(fieldsArray) == 0 { 196 return 0, errors.Errorf("must contain at least one %s field name in %sFields", name, name) 197 } 198 199 for _, value := range fieldsArray { 200 v, ok := value.(string) 201 if !ok { 202 return 0, errors.Errorf("%sField must be a string", name) 203 } 204 if len(v) == 0 { 205 return 0, errors.Errorf("%sField values cannot be empty", name) 206 } 207 } 208 209 return len(fieldsArray), nil 210 } 211 212 func (ic *classSettings) validateWeights(name string, count int) error { 213 weights, ok := ic.getWeights(name) 214 if ok { 215 if len(weights) != count { 216 return errors.Errorf("weights.%sFields does not equal number of %sFields", name, name) 217 } 218 _, err := ic.getWeightsArray(weights) 219 if err != nil { 220 return err 221 } 222 } 223 224 return nil 225 } 226 227 func (ic *classSettings) getWeights(name string) ([]interface{}, bool) { 228 weights, ok := ic.cfg.Class()["weights"] 229 if ok { 230 weightsObject, ok := weights.(map[string]interface{}) 231 if ok { 232 fieldWeights, ok := weightsObject[fmt.Sprintf("%sFields", name)] 233 if ok { 234 fieldWeightsArray, ok := fieldWeights.([]interface{}) 235 if ok { 236 return fieldWeightsArray, ok 237 } 238 } 239 } 240 } 241 242 return nil, false 243 } 244 245 func (ic *classSettings) getWeightsArray(weights []interface{}) ([]float32, error) { 246 weightsArray := make([]float32, len(weights)) 247 for i := range weights { 248 weight, err := ic.getNumber(weights[i]) 249 if err != nil { 250 return nil, err 251 } 252 weightsArray[i] = weight 253 } 254 return weightsArray, nil 255 } 256 257 func (ic *classSettings) getFieldsWeights(name string) ([]float32, error) { 258 weights, ok := ic.getWeights(name) 259 if ok { 260 return ic.getWeightsArray(weights) 261 } 262 return nil, nil 263 } 264 265 func (ic *classSettings) getNumber(in interface{}) (float32, error) { 266 return ic.base.GetNumber(in) 267 }