github.com/weaviate/weaviate@v1.24.6/modules/text2vec-huggingface/vectorizer/class_settings.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package vectorizer 13 14 import ( 15 "fmt" 16 17 "github.com/pkg/errors" 18 19 "github.com/weaviate/weaviate/entities/models" 20 "github.com/weaviate/weaviate/entities/moduletools" 21 "github.com/weaviate/weaviate/entities/schema" 22 basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings" 23 ) 24 25 const ( 26 DefaultHuggingFaceModel = "sentence-transformers/msmarco-bert-base-dot-v5" 27 DefaultOptionWaitForModel = false 28 DefaultOptionUseGPU = false 29 DefaultOptionUseCache = true 30 DefaultVectorizeClassName = true 31 DefaultPropertyIndexed = true 32 DefaultVectorizePropertyName = false 33 ) 34 35 type classSettings struct { 36 basesettings.BaseClassSettings 37 cfg moduletools.ClassConfig 38 } 39 40 func NewClassSettings(cfg moduletools.ClassConfig) *classSettings { 41 return &classSettings{cfg: cfg, BaseClassSettings: *basesettings.NewBaseClassSettings(cfg)} 42 } 43 44 func (cs *classSettings) EndpointURL() string { 45 return cs.getEndpointURL() 46 } 47 48 func (cs *classSettings) PassageModel() string { 49 model := cs.getPassageModel() 50 if model == "" { 51 return DefaultHuggingFaceModel 52 } 53 return model 54 } 55 56 func (cs *classSettings) QueryModel() string { 57 model := cs.getQueryModel() 58 if model == "" { 59 return DefaultHuggingFaceModel 60 } 61 return model 62 } 63 64 func (cs *classSettings) OptionWaitForModel() bool { 65 return cs.getOptionOrDefault("waitForModel", DefaultOptionWaitForModel) 66 } 67 68 func (cs *classSettings) OptionUseGPU() bool { 69 return cs.getOptionOrDefault("useGPU", DefaultOptionUseGPU) 70 } 71 72 func (cs *classSettings) OptionUseCache() bool { 73 return cs.getOptionOrDefault("useCache", DefaultOptionUseCache) 74 } 75 76 func (cs *classSettings) Validate(class *models.Class) error { 77 if cs.cfg == nil { 78 // we would receive a nil-config on cross-class requests, such as Explore{} 79 return errors.New("empty config") 80 } 81 82 err := cs.validateClassSettings() 83 if err != nil { 84 return err 85 } 86 87 err = cs.validateIndexState(class, cs) 88 if err != nil { 89 return err 90 } 91 92 return nil 93 } 94 95 func (cs *classSettings) validateClassSettings() error { 96 if err := cs.BaseClassSettings.Validate(); err != nil { 97 return err 98 } 99 100 endpointURL := cs.getEndpointURL() 101 if endpointURL != "" { 102 // endpoint is set, should be used for feature extraction 103 // all other settings are not relevant 104 return nil 105 } 106 107 model := cs.getProperty("model") 108 passageModel := cs.getProperty("passageModel") 109 queryModel := cs.getProperty("queryModel") 110 111 if model != "" && (passageModel != "" || queryModel != "") { 112 return errors.New("only one setting must be set either 'model' or 'passageModel' with 'queryModel'") 113 } 114 115 if model == "" { 116 if passageModel != "" && queryModel == "" { 117 return errors.New("'passageModel' is set, but 'queryModel' is empty") 118 } 119 if passageModel == "" && queryModel != "" { 120 return errors.New("'queryModel' is set, but 'passageModel' is empty") 121 } 122 } 123 return nil 124 } 125 126 func (cs *classSettings) getPassageModel() string { 127 model := cs.getProperty("model") 128 if model == "" { 129 model = cs.getProperty("passageModel") 130 } 131 return model 132 } 133 134 func (cs *classSettings) getQueryModel() string { 135 model := cs.getProperty("model") 136 if model == "" { 137 model = cs.getProperty("queryModel") 138 } 139 return model 140 } 141 142 func (cs *classSettings) getEndpointURL() string { 143 endpointURL := cs.getProperty("endpointUrl") 144 if endpointURL == "" { 145 endpointURL = cs.getProperty("endpointURL") 146 } 147 return endpointURL 148 } 149 150 func (cs *classSettings) getOption(option string) *bool { 151 if cs.cfg != nil { 152 options, ok := cs.cfg.Class()["options"] 153 if ok { 154 asMap, ok := options.(map[string]interface{}) 155 if ok { 156 option, ok := asMap[option] 157 if ok { 158 asBool, ok := option.(bool) 159 if ok { 160 return &asBool 161 } 162 } 163 } 164 } 165 } 166 return nil 167 } 168 169 func (cs *classSettings) getOptionOrDefault(option string, defaultValue bool) bool { 170 optionValue := cs.getOption(option) 171 if optionValue != nil { 172 return *optionValue 173 } 174 return defaultValue 175 } 176 177 func (cs *classSettings) getProperty(name string) string { 178 return cs.BaseClassSettings.GetPropertyAsString(name, "") 179 } 180 181 func (cs *classSettings) validateIndexState(class *models.Class, settings ClassSettings) error { 182 if settings.VectorizeClassName() { 183 // if the user chooses to vectorize the classname, vector-building will 184 // always be possible, no need to investigate further 185 186 return nil 187 } 188 189 // search if there is at least one indexed, string/text prop. If found pass 190 // validation 191 for _, prop := range class.Properties { 192 if len(prop.DataType) < 1 { 193 return errors.Errorf("property %s must have at least one datatype: "+ 194 "got %v", prop.Name, prop.DataType) 195 } 196 197 if prop.DataType[0] != string(schema.DataTypeText) { 198 // we can only vectorize text-like props 199 continue 200 } 201 202 if settings.PropertyIndexed(prop.Name) { 203 // found at least one, this is a valid schema 204 return nil 205 } 206 } 207 208 return fmt.Errorf("invalid properties: didn't find a single property which is " + 209 "of type string or text and is not excluded from indexing. In addition the " + 210 "class name is excluded from vectorization as well, meaning that it cannot be " + 211 "used to determine the vector position. To fix this, set 'vectorizeClassName' " + 212 "to true if the class name is contextionary-valid. Alternatively add at least " + 213 "contextionary-valid text/string property which is not excluded from " + 214 "indexing.") 215 }