github.com/weaviate/weaviate@v1.24.6/usecases/classification/classifier_run.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "context" 16 "fmt" 17 "runtime" 18 "time" 19 20 "github.com/go-openapi/strfmt" 21 "github.com/pkg/errors" 22 "github.com/sirupsen/logrus" 23 "github.com/weaviate/weaviate/entities/additional" 24 "github.com/weaviate/weaviate/entities/models" 25 "github.com/weaviate/weaviate/entities/modulecapabilities" 26 "github.com/weaviate/weaviate/entities/schema" 27 "github.com/weaviate/weaviate/entities/search" 28 ) 29 30 // the contents of this file deal with anything about a classification run 31 // which is generic, whereas the individual classify_item fns can be found in 32 // the respective files such as classifier_run_knn.go 33 34 func (c *Classifier) run(params models.Classification, 35 filters Filters, 36 ) { 37 ctx, cancel := contextWithTimeout(30 * time.Minute) 38 defer cancel() 39 40 go c.monitorClassification(ctx, cancel, schema.ClassName(params.Class)) 41 42 c.logBegin(params, filters) 43 unclassifiedItems, err := c.vectorRepo.GetUnclassified(ctx, 44 params.Class, params.ClassifyProperties, filters.Source()) 45 if err != nil { 46 c.failRunWithError(params, errors.Wrap(err, "retrieve to-be-classifieds")) 47 return 48 } 49 50 if len(unclassifiedItems) == 0 { 51 c.failRunWithError(params, 52 fmt.Errorf("no classes to be classified - did you run a previous classification already?")) 53 return 54 } 55 c.logItemsFetched(params, unclassifiedItems) 56 57 classifyItem, err := c.prepareRun(params, filters, unclassifiedItems) 58 if err != nil { 59 c.failRunWithError(params, errors.Wrap(err, "prepare classification")) 60 return 61 } 62 63 params, err = c.runItems(ctx, classifyItem, params, filters, unclassifiedItems) 64 if err != nil { 65 c.failRunWithError(params, err) 66 return 67 } 68 69 c.succeedRun(params) 70 } 71 72 func (c *Classifier) monitorClassification(ctx context.Context, cancelFn context.CancelFunc, 73 className schema.ClassName, 74 ) { 75 ticker := time.NewTicker(100 * time.Millisecond) 76 defer ticker.Stop() 77 for { 78 select { 79 case <-ctx.Done(): 80 return 81 case <-ticker.C: 82 schema := c.schemaGetter.GetSchemaSkipAuth() 83 class := schema.FindClassByName(className) 84 if class == nil { 85 cancelFn() 86 return 87 } 88 } 89 } 90 } 91 92 func (c *Classifier) prepareRun(params models.Classification, filters Filters, 93 unclassifiedItems []search.Result, 94 ) (ClassifyItemFn, error) { 95 c.logBeginPreparation(params) 96 defer c.logFinishPreparation(params) 97 98 if params.Type == "knn" { 99 return c.classifyItemUsingKNN, nil 100 } 101 102 if params.Type == "zeroshot" { 103 return c.classifyItemUsingZeroShot, nil 104 } 105 106 if c.modulesProvider != nil { 107 classifyItemFn, err := c.modulesProvider.GetClassificationFn(params.Class, params.Type, 108 c.getClassifyParams(params, filters, unclassifiedItems)) 109 if err != nil { 110 return nil, errors.Wrapf(err, "cannot classify") 111 } 112 if classifyItemFn == nil { 113 return nil, errors.Errorf("cannot classify: empty classifier for %s", params.Type) 114 } 115 classification := &moduleClassification{classifyItemFn} 116 return classification.classifyFn, nil 117 } 118 119 return nil, errors.Errorf("unsupported type '%s', have no classify item fn for this", params.Type) 120 } 121 122 func (c *Classifier) getClassifyParams(params models.Classification, 123 filters Filters, unclassifiedItems []search.Result, 124 ) modulecapabilities.ClassifyParams { 125 return modulecapabilities.ClassifyParams{ 126 Schema: c.schemaGetter.GetSchemaSkipAuth(), 127 Params: params, 128 Filters: filters, 129 UnclassifiedItems: unclassifiedItems, 130 VectorRepo: c.vectorClassSearchRepo, 131 } 132 } 133 134 // runItems splits the job list into batches that can be worked on parallelly 135 // depending on the available CPUs 136 func (c *Classifier) runItems(ctx context.Context, classifyItem ClassifyItemFn, params models.Classification, filters Filters, 137 items []search.Result, 138 ) (models.Classification, error) { 139 workerCount := runtime.GOMAXPROCS(0) 140 if len(items) < workerCount { 141 workerCount = len(items) 142 } 143 144 workers := newRunWorkers(workerCount, classifyItem, params, filters, c.vectorRepo, c.logger) 145 workers.addJobs(items) 146 res := workers.work(ctx) 147 148 params.Meta.Completed = strfmt.DateTime(time.Now()) 149 params.Meta.CountSucceeded = res.successCount 150 params.Meta.CountFailed = res.errorCount 151 params.Meta.Count = res.successCount + res.errorCount 152 153 return params, res.err 154 } 155 156 func (c *Classifier) succeedRun(params models.Classification) { 157 params.Status = models.ClassificationStatusCompleted 158 ctx, cancel := contextWithTimeout(2 * time.Second) 159 defer cancel() 160 err := c.repo.Put(ctx, params) 161 if err != nil { 162 c.logExecutionError("store succeeded run", err, params) 163 } 164 c.logFinish(params) 165 } 166 167 func (c *Classifier) failRunWithError(params models.Classification, err error) { 168 params.Status = models.ClassificationStatusFailed 169 params.Error = fmt.Sprintf("classification failed: %v", err) 170 err = c.repo.Put(context.Background(), params) 171 if err != nil { 172 c.logExecutionError("store failed run", err, params) 173 } 174 c.logFinish(params) 175 } 176 177 func (c *Classifier) extendItemWithObjectMeta(item *search.Result, 178 params models.Classification, classified []string, 179 ) { 180 // don't overwrite existing non-classification meta info 181 if item.AdditionalProperties == nil { 182 item.AdditionalProperties = models.AdditionalProperties{} 183 } 184 185 item.AdditionalProperties["classification"] = additional.Classification{ 186 ID: params.ID, 187 Scope: params.ClassifyProperties, 188 ClassifiedFields: classified, 189 Completed: strfmt.DateTime(time.Now()), 190 } 191 } 192 193 func contextWithTimeout(d time.Duration) (context.Context, context.CancelFunc) { 194 return context.WithTimeout(context.Background(), d) 195 } 196 197 // Logging helper methods 198 func (c *Classifier) logBase(params models.Classification, event string) *logrus.Entry { 199 return c.logger.WithField("action", "classification_run"). 200 WithField("event", event). 201 WithField("params", params). 202 WithField("classification_type", params.Type) 203 } 204 205 func (c *Classifier) logBegin(params models.Classification, filters Filters) { 206 c.logBase(params, "classification_begin"). 207 WithField("filters", filters). 208 Debug("classification started") 209 } 210 211 func (c *Classifier) logFinish(params models.Classification) { 212 c.logBase(params, "classification_finish"). 213 WithField("status", params.Status). 214 Debug("classification finished") 215 } 216 217 func (c *Classifier) logItemsFetched(params models.Classification, items search.Results) { 218 c.logBase(params, "classification_items_fetched"). 219 WithField("status", params.Status). 220 WithField("item_count", len(items)). 221 Debug("fetched source items") 222 } 223 224 func (c *Classifier) logBeginPreparation(params models.Classification) { 225 c.logBase(params, "classification_preparation_begin"). 226 Debug("begin run preparation") 227 } 228 229 func (c *Classifier) logFinishPreparation(params models.Classification) { 230 c.logBase(params, "classification_preparation_finish"). 231 Debug("finish run preparation") 232 } 233 234 func (c *Classifier) logExecutionError(event string, err error, params models.Classification) { 235 c.logBase(params, event). 236 WithError(err). 237 Error("classification execution failure") 238 }