github.com/weaviate/weaviate@v1.24.6/usecases/classification/classifier.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "context" 16 "encoding/json" 17 "fmt" 18 "time" 19 20 enterrors "github.com/weaviate/weaviate/entities/errors" 21 22 "github.com/go-openapi/strfmt" 23 "github.com/google/uuid" 24 "github.com/pkg/errors" 25 "github.com/sirupsen/logrus" 26 "github.com/weaviate/weaviate/adapters/handlers/rest/filterext" 27 "github.com/weaviate/weaviate/entities/additional" 28 "github.com/weaviate/weaviate/entities/dto" 29 libfilters "github.com/weaviate/weaviate/entities/filters" 30 "github.com/weaviate/weaviate/entities/models" 31 "github.com/weaviate/weaviate/entities/modulecapabilities" 32 "github.com/weaviate/weaviate/entities/search" 33 "github.com/weaviate/weaviate/usecases/objects" 34 schemaUC "github.com/weaviate/weaviate/usecases/schema" 35 libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer" 36 ) 37 38 type classificationFilters struct { 39 source *libfilters.LocalFilter 40 target *libfilters.LocalFilter 41 trainingSet *libfilters.LocalFilter 42 } 43 44 func (f classificationFilters) Source() *libfilters.LocalFilter { 45 return f.source 46 } 47 48 func (f classificationFilters) Target() *libfilters.LocalFilter { 49 return f.target 50 } 51 52 func (f classificationFilters) TrainingSet() *libfilters.LocalFilter { 53 return f.trainingSet 54 } 55 56 type distancer func(a, b []float32) (float32, error) 57 58 type Classifier struct { 59 schemaGetter schemaUC.SchemaGetter 60 repo Repo 61 vectorRepo vectorRepo 62 vectorClassSearchRepo modulecapabilities.VectorClassSearchRepo 63 authorizer authorizer 64 distancer distancer 65 modulesProvider ModulesProvider 66 logger logrus.FieldLogger 67 } 68 69 type authorizer interface { 70 Authorize(principal *models.Principal, verb, resource string) error 71 } 72 73 type ModulesProvider interface { 74 ParseClassifierSettings(name string, 75 params *models.Classification) error 76 GetClassificationFn(className, name string, 77 params modulecapabilities.ClassifyParams) (modulecapabilities.ClassifyItemFn, error) 78 } 79 80 func New(sg schemaUC.SchemaGetter, cr Repo, vr vectorRepo, authorizer authorizer, 81 logger logrus.FieldLogger, modulesProvider ModulesProvider, 82 ) *Classifier { 83 return &Classifier{ 84 logger: logger, 85 schemaGetter: sg, 86 repo: cr, 87 vectorRepo: vr, 88 authorizer: authorizer, 89 distancer: libvectorizer.NormalizedDistance, 90 vectorClassSearchRepo: newVectorClassSearchRepo(vr), 91 modulesProvider: modulesProvider, 92 } 93 } 94 95 // Repo to manage classification state, should be consistent, not used to store 96 // actual data object vectors, see VectorRepo 97 type Repo interface { 98 Put(ctx context.Context, classification models.Classification) error 99 Get(ctx context.Context, id strfmt.UUID) (*models.Classification, error) 100 } 101 102 type VectorRepo interface { 103 GetUnclassified(ctx context.Context, class string, 104 properties []string, filter *libfilters.LocalFilter) ([]search.Result, error) 105 AggregateNeighbors(ctx context.Context, vector []float32, 106 class string, properties []string, k int, 107 filter *libfilters.LocalFilter) ([]NeighborRef, error) 108 VectorSearch(ctx context.Context, params dto.GetParams) ([]search.Result, error) 109 ZeroShotSearch(ctx context.Context, vector []float32, 110 class string, properties []string, 111 filter *libfilters.LocalFilter) ([]search.Result, error) 112 } 113 114 type vectorRepo interface { 115 VectorRepo 116 BatchPutObjects(ctx context.Context, objects objects.BatchObjects, 117 repl *additional.ReplicationProperties) (objects.BatchObjects, error) 118 } 119 120 // NeighborRef is the result of an aggregation of the ref properties of k 121 // neighbors 122 type NeighborRef struct { 123 // Property indicates which property was aggregated 124 Property string 125 126 // The beacon of the most common (kNN) reference 127 Beacon strfmt.URI 128 129 OverallCount int 130 WinningCount int 131 LosingCount int 132 133 Distances NeighborRefDistances 134 } 135 136 func (c *Classifier) Schedule(ctx context.Context, principal *models.Principal, params models.Classification) (*models.Classification, error) { 137 err := c.authorizer.Authorize(principal, "create", "classifications/*") 138 if err != nil { 139 return nil, err 140 } 141 142 err = c.parseAndSetDefaults(¶ms) 143 if err != nil { 144 return nil, err 145 } 146 147 err = NewValidator(c.schemaGetter, params).Do() 148 if err != nil { 149 return nil, err 150 } 151 152 if err := c.assignNewID(¶ms); err != nil { 153 return nil, fmt.Errorf("classification: assign id: %v", err) 154 } 155 156 params.Status = models.ClassificationStatusRunning 157 params.Meta = &models.ClassificationMeta{ 158 Started: strfmt.DateTime(time.Now()), 159 } 160 161 if err := c.repo.Put(ctx, params); err != nil { 162 return nil, fmt.Errorf("classification: put: %v", err) 163 } 164 165 // asynchronously trigger the classification 166 filters, err := c.extractFilters(params) 167 if err != nil { 168 return nil, err 169 } 170 171 enterrors.GoWrapper(func() { c.run(params, filters) }, c.logger) 172 173 return ¶ms, nil 174 } 175 176 func (c *Classifier) extractFilters(params models.Classification) (Filters, error) { 177 if params.Filters == nil { 178 return classificationFilters{}, nil 179 } 180 181 source, err := filterext.Parse(params.Filters.SourceWhere, params.Class) 182 if err != nil { 183 return classificationFilters{}, fmt.Errorf("field 'sourceWhere': %v", err) 184 } 185 186 trainingSet, err := filterext.Parse(params.Filters.TrainingSetWhere, params.Class) 187 if err != nil { 188 return classificationFilters{}, fmt.Errorf("field 'trainingSetWhere': %v", err) 189 } 190 191 target, err := filterext.Parse(params.Filters.TargetWhere, params.Class) 192 if err != nil { 193 return classificationFilters{}, fmt.Errorf("field 'targetWhere': %v", err) 194 } 195 196 filters := classificationFilters{ 197 source: source, 198 trainingSet: trainingSet, 199 target: target, 200 } 201 202 if err = c.validateFilters(¶ms, &filters); err != nil { 203 return nil, err 204 } 205 206 return filters, nil 207 } 208 209 func (c *Classifier) validateFilters(params *models.Classification, filters *classificationFilters) (err error) { 210 if params.Type == TypeKNN { 211 if err = c.validateFilter(filters.Source()); err != nil { 212 return fmt.Errorf("invalid sourceWhere: %s", err) 213 } 214 if err = c.validateFilter(filters.TrainingSet()); err != nil { 215 return fmt.Errorf("invalid trainingSetWhere: %s", err) 216 } 217 } 218 219 if params.Type == TypeContextual || params.Type == TypeZeroShot { 220 if err = c.validateFilter(filters.Source()); err != nil { 221 return fmt.Errorf("invalid sourceWhere: %s", err) 222 } 223 if err = c.validateFilter(filters.Target()); err != nil { 224 return fmt.Errorf("invalid targetWhere: %s", err) 225 } 226 } 227 228 return 229 } 230 231 func (c *Classifier) validateFilter(filter *libfilters.LocalFilter) error { 232 if filter == nil { 233 return nil 234 } 235 return libfilters.ValidateFilters(c.schemaGetter.GetSchemaSkipAuth(), filter) 236 } 237 238 func (c *Classifier) assignNewID(params *models.Classification) error { 239 id, err := uuid.NewRandom() 240 if err != nil { 241 return err 242 } 243 244 params.ID = strfmt.UUID(id.String()) 245 return nil 246 } 247 248 func (c *Classifier) Get(ctx context.Context, principal *models.Principal, id strfmt.UUID) (*models.Classification, error) { 249 err := c.authorizer.Authorize(principal, "get", "classifications/*") 250 if err != nil { 251 return nil, err 252 } 253 254 return c.repo.Get(ctx, id) 255 } 256 257 func (c *Classifier) parseAndSetDefaults(params *models.Classification) error { 258 if params.Type == "" { 259 defaultType := "knn" 260 params.Type = defaultType 261 } 262 263 if params.Type == "knn" { 264 if err := c.parseKNNSettings(params); err != nil { 265 return errors.Wrapf(err, "parse knn specific settings") 266 } 267 return nil 268 } 269 270 if c.modulesProvider != nil { 271 if err := c.modulesProvider.ParseClassifierSettings(params.Type, params); err != nil { 272 return errors.Wrapf(err, "parse %s specific settings", params.Type) 273 } 274 return nil 275 } 276 277 return nil 278 } 279 280 func (c *Classifier) parseKNNSettings(params *models.Classification) error { 281 raw := params.Settings 282 settings := &ParamsKNN{} 283 if raw == nil { 284 settings.SetDefaults() 285 params.Settings = settings 286 return nil 287 } 288 289 asMap, ok := raw.(map[string]interface{}) 290 if !ok { 291 return errors.Errorf("settings must be an object got %T", raw) 292 } 293 294 v, err := extractNumberFromMap(asMap, "k") 295 if err != nil { 296 return err 297 } 298 settings.K = v 299 300 settings.SetDefaults() 301 params.Settings = settings 302 303 return nil 304 } 305 306 type ParamsKNN struct { 307 K *int32 `json:"k"` 308 } 309 310 func (params *ParamsKNN) SetDefaults() { 311 if params.K == nil { 312 defaultK := int32(3) 313 params.K = &defaultK 314 } 315 } 316 317 func extractNumberFromMap(in map[string]interface{}, field string) (*int32, error) { 318 unparsed, present := in[field] 319 if present { 320 parsed, ok := unparsed.(json.Number) 321 if !ok { 322 return nil, errors.Errorf("settings.%s must be number, got %T", 323 field, unparsed) 324 } 325 326 asInt64, err := parsed.Int64() 327 if err != nil { 328 return nil, errors.Wrapf(err, "settings.%s", field) 329 } 330 331 asInt32 := int32(asInt64) 332 return &asInt32, nil 333 } 334 335 return nil, nil 336 }