github.com/weaviate/weaviate@v1.24.6/usecases/classification/validation.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "fmt" 16 17 "github.com/weaviate/weaviate/entities/errorcompounder" 18 "github.com/weaviate/weaviate/entities/models" 19 "github.com/weaviate/weaviate/entities/schema" 20 schemaUC "github.com/weaviate/weaviate/usecases/schema" 21 ) 22 23 const ( 24 TypeKNN = "knn" 25 TypeContextual = "text2vec-contextionary-contextual" 26 TypeZeroShot = "zeroshot" 27 ) 28 29 type Validator struct { 30 schema schema.Schema 31 errors *errorcompounder.SafeErrorCompounder 32 subject models.Classification 33 } 34 35 func NewValidator(sg schemaUC.SchemaGetter, subject models.Classification) *Validator { 36 schema := sg.GetSchemaSkipAuth() 37 return &Validator{ 38 schema: schema, 39 errors: &errorcompounder.SafeErrorCompounder{}, 40 subject: subject, 41 } 42 } 43 44 func (v *Validator) Do() error { 45 v.validate() 46 47 err := v.errors.ToError() 48 if err != nil { 49 return fmt.Errorf("invalid classification: %v", err) 50 } 51 52 return nil 53 } 54 55 func (v *Validator) validate() { 56 if v.subject.Class == "" { 57 v.errors.Add(fmt.Errorf("class must be set")) 58 return 59 } 60 61 class := v.schema.FindClassByName(schema.ClassName(v.subject.Class)) 62 if class == nil { 63 v.errors.Addf("class '%s' not found in schema", v.subject.Class) 64 return 65 } 66 67 v.contextualTypeFeasibility() 68 v.knnTypeFeasibility() 69 v.basedOnProperties(class) 70 v.classifyProperties(class) 71 } 72 73 func (v *Validator) contextualTypeFeasibility() { 74 if !v.typeText2vecContextionaryContextual() { 75 return 76 } 77 78 if v.subject.Filters != nil && v.subject.Filters.TrainingSetWhere != nil { 79 v.errors.Addf("type is 'text2vec-contextionary-contextual', but 'trainingSetWhere' filter is set, for 'text2vec-contextionary-contextual' there is no training data, instead limit possible target data directly through setting 'targetWhere'") 80 } 81 } 82 83 func (v *Validator) knnTypeFeasibility() { 84 if !v.typeKNN() { 85 return 86 } 87 88 if v.subject.Filters != nil && v.subject.Filters.TargetWhere != nil { 89 v.errors.Addf("type is 'knn', but 'targetWhere' filter is set, for 'knn' you cannot limit target data directly, instead limit training data through setting 'trainingSetWhere'") 90 } 91 } 92 93 func (v *Validator) basedOnProperties(class *models.Class) { 94 if v.subject.BasedOnProperties == nil || len(v.subject.BasedOnProperties) == 0 { 95 v.errors.Addf("basedOnProperties must have at least one property") 96 return 97 } 98 99 if len(v.subject.BasedOnProperties) > 1 { 100 v.errors.Addf("only a single property in basedOnProperties supported at the moment, got %v", 101 v.subject.BasedOnProperties) 102 return 103 } 104 105 for _, prop := range v.subject.BasedOnProperties { 106 v.basedOnProperty(class, prop) 107 } 108 } 109 110 func (v *Validator) basedOnProperty(class *models.Class, propName string) { 111 prop, ok := v.propertyByName(class, propName) 112 if !ok { 113 v.errors.Addf("basedOnProperties: property '%s' does not exist", propName) 114 return 115 } 116 117 dt, err := v.schema.FindPropertyDataType(prop.DataType) 118 if err != nil { 119 v.errors.Addf("basedOnProperties: %v", err) 120 return 121 } 122 123 if !dt.IsPrimitive() { 124 v.errors.Addf("basedOnProperties: property '%s' must be of type 'text'", propName) 125 return 126 } 127 128 if dt.AsPrimitive() != schema.DataTypeText { 129 v.errors.Addf("basedOnProperties: property '%s' must be of type 'text'", propName) 130 return 131 } 132 } 133 134 func (v *Validator) classifyProperties(class *models.Class) { 135 if v.subject.ClassifyProperties == nil || len(v.subject.ClassifyProperties) == 0 { 136 v.errors.Addf("classifyProperties must have at least one property") 137 return 138 } 139 140 for _, prop := range v.subject.ClassifyProperties { 141 v.classifyProperty(class, prop) 142 } 143 } 144 145 func (v *Validator) classifyProperty(class *models.Class, propName string) { 146 prop, ok := v.propertyByName(class, propName) 147 if !ok { 148 v.errors.Addf("classifyProperties: property '%s' does not exist", propName) 149 return 150 } 151 152 dt, err := v.schema.FindPropertyDataType(prop.DataType) 153 if err != nil { 154 v.errors.Addf("classifyProperties: %v", err) 155 return 156 } 157 158 if !dt.IsReference() { 159 v.errors.Addf("classifyProperties: property '%s' must be of reference type (cref)", propName) 160 return 161 } 162 163 if v.typeText2vecContextionaryContextual() { 164 if len(dt.Classes()) > 1 { 165 v.errors.Addf("classifyProperties: property '%s'"+ 166 " has more than one target class, classification of type 'text2vec-contextionary-contextual' requires exactly one target class", propName) 167 return 168 } 169 } 170 } 171 172 func (v *Validator) propertyByName(class *models.Class, propName string) (*models.Property, bool) { 173 for _, prop := range class.Properties { 174 if prop.Name == propName { 175 return prop, true 176 } 177 } 178 179 return nil, false 180 } 181 182 func (v *Validator) typeText2vecContextionaryContextual() bool { 183 if v.subject.Type == "" { 184 return false 185 } 186 187 return v.subject.Type == TypeContextual 188 } 189 190 func (v *Validator) typeKNN() bool { 191 if v.subject.Type == "" { 192 return true 193 } 194 195 return v.subject.Type == TypeKNN 196 }