github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-palm/vectorizer/class_settings.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package vectorizer 13 14 import ( 15 "fmt" 16 "strings" 17 18 "github.com/pkg/errors" 19 20 "github.com/weaviate/weaviate/entities/moduletools" 21 basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings" 22 ) 23 24 const ( 25 locationProperty = "location" 26 projectIDProperty = "projectId" 27 modelIDProperty = "modelId" 28 dimensionsProperty = "dimensions" 29 videoIntervalSecondsProperty = "videoIntervalSeconds" 30 ) 31 32 const ( 33 DefaultVectorizeClassName = false 34 DefaultPropertyIndexed = true 35 DefaultVectorizePropertyName = false 36 DefaultApiEndpoint = "us-central1-aiplatform.googleapis.com" 37 DefaultModelID = "multimodalembedding@001" 38 DefaultDime 39 ) 40 41 var ( 42 defaultDimensions1408 = int64(1408) 43 availableDimensions = []int64{128, 256, 512, defaultDimensions1408} 44 defaultVideoIntervalSeconds = int64(120) 45 availableVideoIntervalSeconds = []int64{4, 8, 15, defaultVideoIntervalSeconds} 46 ) 47 48 type classSettings struct { 49 base *basesettings.BaseClassSettings 50 cfg moduletools.ClassConfig 51 } 52 53 func NewClassSettings(cfg moduletools.ClassConfig) *classSettings { 54 return &classSettings{cfg: cfg, base: basesettings.NewBaseClassSettings(cfg)} 55 } 56 57 // PaLM params 58 func (ic *classSettings) Location() string { 59 return ic.getStringProperty(locationProperty, "") 60 } 61 62 func (ic *classSettings) ProjectID() string { 63 return ic.getStringProperty(projectIDProperty, "") 64 } 65 66 func (ic *classSettings) ModelID() string { 67 return ic.getStringProperty(modelIDProperty, DefaultModelID) 68 } 69 70 func (ic *classSettings) Dimensions() int64 { 71 return ic.getInt64Property(dimensionsProperty, defaultDimensions1408) 72 } 73 74 func (ic *classSettings) VideoIntervalSeconds() int64 { 75 return ic.getInt64Property(videoIntervalSecondsProperty, defaultVideoIntervalSeconds) 76 } 77 78 // CLIP module specific settings 79 func (ic *classSettings) ImageField(property string) bool { 80 return ic.field("imageFields", property) 81 } 82 83 func (ic *classSettings) ImageFieldsWeights() ([]float32, error) { 84 return ic.getFieldsWeights("image") 85 } 86 87 func (ic *classSettings) TextField(property string) bool { 88 return ic.field("textFields", property) 89 } 90 91 func (ic *classSettings) TextFieldsWeights() ([]float32, error) { 92 return ic.getFieldsWeights("text") 93 } 94 95 func (ic *classSettings) VideoField(property string) bool { 96 return ic.field("videoFields", property) 97 } 98 99 func (ic *classSettings) VideoFieldsWeights() ([]float32, error) { 100 return ic.getFieldsWeights("video") 101 } 102 103 func (ic *classSettings) field(name, property string) bool { 104 if ic.cfg == nil { 105 // we would receive a nil-config on cross-class requests, such as Explore{} 106 return false 107 } 108 109 fields, ok := ic.cfg.ClassByModuleName("multi2vec-palm")[name] 110 if !ok { 111 return false 112 } 113 114 fieldsArray, ok := fields.([]interface{}) 115 if !ok { 116 return false 117 } 118 119 fieldNames := make([]string, len(fieldsArray)) 120 for i, value := range fieldsArray { 121 fieldNames[i] = value.(string) 122 } 123 124 for i := range fieldNames { 125 if fieldNames[i] == property { 126 return true 127 } 128 } 129 130 return false 131 } 132 133 func (ic *classSettings) getStringProperty(name, defaultValue string) string { 134 return ic.base.GetPropertyAsString(name, defaultValue) 135 } 136 137 func (ic *classSettings) getInt64Property(name string, defaultValue int64) int64 { 138 if val := ic.base.GetPropertyAsInt64(name, &defaultValue); val != nil { 139 return *val 140 } 141 return defaultValue 142 } 143 144 func (ic *classSettings) Validate() error { 145 if ic.cfg == nil { 146 // we would receive a nil-config on cross-class requests, such as Explore{} 147 return errors.New("empty config") 148 } 149 150 var errorMessages []string 151 152 model := ic.ModelID() 153 location := ic.Location() 154 if location == "" { 155 errorMessages = append(errorMessages, "location setting needs to be present") 156 } 157 158 projectID := ic.ProjectID() 159 if projectID == "" { 160 errorMessages = append(errorMessages, "projectId setting needs to be present") 161 } 162 163 dimensions := ic.Dimensions() 164 if !validateSetting[int64](dimensions, availableDimensions) { 165 return errors.Errorf("wrong dimensions setting for %s model, available dimensions are: %v", model, availableDimensions) 166 } 167 168 videoIntervalSeconds := ic.VideoIntervalSeconds() 169 if !validateSetting[int64](videoIntervalSeconds, availableVideoIntervalSeconds) { 170 return errors.Errorf("wrong videoIntervalSeconds setting for %s model, available videoIntervalSeconds are: %v", model, availableVideoIntervalSeconds) 171 } 172 173 imageFields, imageFieldsOk := ic.cfg.Class()["imageFields"] 174 textFields, textFieldsOk := ic.cfg.Class()["textFields"] 175 videoFields, videoFieldsOk := ic.cfg.Class()["videoFields"] 176 if !imageFieldsOk && !textFieldsOk && !videoFieldsOk { 177 errorMessages = append(errorMessages, "textFields or imageFields or videoFields setting needs to be present") 178 } 179 180 if videoFieldsOk && dimensions != defaultDimensions1408 { 181 errorMessages = append(errorMessages, fmt.Sprintf("videoFields support only %d dimensions setting", defaultDimensions1408)) 182 } 183 184 if imageFieldsOk { 185 imageFieldsCount, err := ic.validateFields("image", imageFields) 186 if err != nil { 187 errorMessages = append(errorMessages, err.Error()) 188 } 189 err = ic.validateWeights("image", imageFieldsCount) 190 if err != nil { 191 errorMessages = append(errorMessages, err.Error()) 192 } 193 } 194 195 if textFieldsOk { 196 textFieldsCount, err := ic.validateFields("text", textFields) 197 if err != nil { 198 errorMessages = append(errorMessages, err.Error()) 199 } 200 err = ic.validateWeights("text", textFieldsCount) 201 if err != nil { 202 errorMessages = append(errorMessages, err.Error()) 203 } 204 } 205 206 if videoFieldsOk { 207 videoFieldsCount, err := ic.validateFields("video", videoFields) 208 if err != nil { 209 errorMessages = append(errorMessages, err.Error()) 210 } 211 err = ic.validateWeights("video", videoFieldsCount) 212 if err != nil { 213 errorMessages = append(errorMessages, err.Error()) 214 } 215 } 216 217 if len(errorMessages) > 0 { 218 return fmt.Errorf("%s", strings.Join(errorMessages, ", ")) 219 } 220 221 return nil 222 } 223 224 func (ic *classSettings) validateFields(name string, fields interface{}) (int, error) { 225 fieldsArray, ok := fields.([]interface{}) 226 if !ok { 227 return 0, errors.Errorf("%sFields must be an array", name) 228 } 229 230 if len(fieldsArray) == 0 { 231 return 0, errors.Errorf("must contain at least one %s field name in %sFields", name, name) 232 } 233 234 for _, value := range fieldsArray { 235 v, ok := value.(string) 236 if !ok { 237 return 0, errors.Errorf("%sField must be a string", name) 238 } 239 if len(v) == 0 { 240 return 0, errors.Errorf("%sField values cannot be empty", name) 241 } 242 } 243 244 return len(fieldsArray), nil 245 } 246 247 func (ic *classSettings) validateWeights(name string, count int) error { 248 weights, ok := ic.getWeights(name) 249 if ok { 250 if len(weights) != count { 251 return errors.Errorf("weights.%sFields does not equal number of %sFields", name, name) 252 } 253 _, err := ic.getWeightsArray(weights) 254 if err != nil { 255 return err 256 } 257 } 258 259 return nil 260 } 261 262 func (ic *classSettings) getWeights(name string) ([]interface{}, bool) { 263 weights, ok := ic.cfg.Class()["weights"] 264 if ok { 265 weightsObject, ok := weights.(map[string]interface{}) 266 if ok { 267 fieldWeights, ok := weightsObject[fmt.Sprintf("%sFields", name)] 268 if ok { 269 fieldWeightsArray, ok := fieldWeights.([]interface{}) 270 if ok { 271 return fieldWeightsArray, ok 272 } 273 } 274 } 275 } 276 277 return nil, false 278 } 279 280 func (ic *classSettings) getWeightsArray(weights []interface{}) ([]float32, error) { 281 weightsArray := make([]float32, len(weights)) 282 for i := range weights { 283 weight, err := ic.getNumber(weights[i]) 284 if err != nil { 285 return nil, err 286 } 287 weightsArray[i] = weight 288 } 289 return weightsArray, nil 290 } 291 292 func (ic *classSettings) getFieldsWeights(name string) ([]float32, error) { 293 weights, ok := ic.getWeights(name) 294 if ok { 295 return ic.getWeightsArray(weights) 296 } 297 return nil, nil 298 } 299 300 func (ic *classSettings) getNumber(in interface{}) (float32, error) { 301 return ic.base.GetNumber(in) 302 } 303 304 func validateSetting[T string | int64](value T, availableValues []T) bool { 305 for i := range availableValues { 306 if value == availableValues[i] { 307 return true 308 } 309 } 310 return false 311 }