github.com/weaviate/weaviate@v1.24.6/usecases/classification/classifier_run_worker.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "context" 16 "sync" 17 "sync/atomic" 18 19 "github.com/sirupsen/logrus" 20 enterrors "github.com/weaviate/weaviate/entities/errors" 21 22 "github.com/pkg/errors" 23 "github.com/weaviate/weaviate/entities/errorcompounder" 24 "github.com/weaviate/weaviate/entities/models" 25 "github.com/weaviate/weaviate/entities/search" 26 ) 27 28 type runWorker struct { 29 jobs []search.Result 30 successCount *int64 31 errorCount *int64 32 ec *errorcompounder.SafeErrorCompounder 33 classify ClassifyItemFn 34 batchWriter Writer 35 params models.Classification 36 filters Filters 37 id int 38 workerCount int 39 } 40 41 func (w *runWorker) addJob(job search.Result) { 42 w.jobs = append(w.jobs, job) 43 } 44 45 func (w *runWorker) work(ctx context.Context, wg *sync.WaitGroup) { 46 defer wg.Done() 47 48 for i, item := range w.jobs { 49 // check if the whole classification operation has been cancelled 50 // if yes, then abort the classifier worker 51 if err := ctx.Err(); err != nil { 52 w.ec.Add(err) 53 atomic.AddInt64(w.errorCount, 1) 54 break 55 } 56 originalIndex := (i * w.workerCount) + w.id 57 err := w.classify(item, originalIndex, w.params, w.filters, w.batchWriter) 58 if err != nil { 59 w.ec.Add(err) 60 atomic.AddInt64(w.errorCount, 1) 61 } else { 62 atomic.AddInt64(w.successCount, 1) 63 } 64 } 65 } 66 67 func newRunWorker(id int, workerCount int, rw *runWorkers) *runWorker { 68 return &runWorker{ 69 successCount: rw.successCount, 70 errorCount: rw.errorCount, 71 ec: rw.ec, 72 params: rw.params, 73 filters: rw.filters, 74 classify: rw.classify, 75 batchWriter: rw.batchWriter, 76 id: id, 77 workerCount: workerCount, 78 } 79 } 80 81 type runWorkers struct { 82 workers []*runWorker 83 successCount *int64 84 errorCount *int64 85 ec *errorcompounder.SafeErrorCompounder 86 classify ClassifyItemFn 87 params models.Classification 88 filters Filters 89 batchWriter Writer 90 logger logrus.FieldLogger 91 } 92 93 func newRunWorkers(amount int, classifyFn ClassifyItemFn, 94 params models.Classification, filters Filters, vectorRepo vectorRepo, logger logrus.FieldLogger, 95 ) *runWorkers { 96 var successCount int64 97 var errorCount int64 98 99 rw := &runWorkers{ 100 workers: make([]*runWorker, amount), 101 successCount: &successCount, 102 errorCount: &errorCount, 103 ec: &errorcompounder.SafeErrorCompounder{}, 104 classify: classifyFn, 105 params: params, 106 filters: filters, 107 batchWriter: newBatchWriter(vectorRepo, logger), 108 logger: logger, 109 } 110 111 for i := 0; i < amount; i++ { 112 rw.workers[i] = newRunWorker(i, amount, rw) 113 } 114 115 return rw 116 } 117 118 func (ws *runWorkers) addJobs(jobs []search.Result) { 119 for i, job := range jobs { 120 ws.workers[i%len(ws.workers)].addJob(job) 121 } 122 } 123 124 func (ws *runWorkers) work(ctx context.Context) runWorkerResults { 125 ws.batchWriter.Start() 126 127 wg := &sync.WaitGroup{} 128 for _, worker := range ws.workers { 129 worker := worker 130 wg.Add(1) 131 enterrors.GoWrapper(func() { worker.work(ctx, wg) }, ws.logger) 132 133 } 134 135 wg.Wait() 136 137 res := ws.batchWriter.Stop() 138 139 if res.SuccessCount() != *ws.successCount || res.ErrorCount() != *ws.errorCount { 140 ws.ec.Add(errors.New("data save error")) 141 } 142 143 if res.Err() != nil { 144 ws.ec.Add(res.Err()) 145 } 146 147 return runWorkerResults{ 148 successCount: *ws.successCount, 149 errorCount: *ws.errorCount, 150 err: ws.ec.ToError(), 151 } 152 } 153 154 type runWorkerResults struct { 155 successCount int64 156 errorCount int64 157 err error 158 }