github.com/weaviate/weaviate@v1.24.6/usecases/classification/classifier_run_worker.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  import (
    15  	"context"
    16  	"sync"
    17  	"sync/atomic"
    18  
    19  	"github.com/sirupsen/logrus"
    20  	enterrors "github.com/weaviate/weaviate/entities/errors"
    21  
    22  	"github.com/pkg/errors"
    23  	"github.com/weaviate/weaviate/entities/errorcompounder"
    24  	"github.com/weaviate/weaviate/entities/models"
    25  	"github.com/weaviate/weaviate/entities/search"
    26  )
    27  
// runWorker processes a private slice of jobs on a single goroutine. All
// workers of a pool share the success/error counters, the error compounder,
// and the batch writer, so those fields are pointers/interfaces handed down
// from the parent runWorkers.
type runWorker struct {
	jobs         []search.Result                      // items assigned to this worker (round-robin, see runWorkers.addJobs)
	successCount *int64                               // shared counter, incremented atomically per successfully classified item
	errorCount   *int64                               // shared counter, incremented atomically per failed item
	ec           *errorcompounder.SafeErrorCompounder // shared, thread-safe collector of per-item errors
	classify     ClassifyItemFn                       // classification callback invoked once per job
	batchWriter  Writer                               // shared writer that persists classification results
	params       models.Classification
	filters      Filters
	id           int // this worker's index within the pool, 0-based
	workerCount  int // total workers in the pool; used to reconstruct original item indices
}
    40  
// addJob appends one search result to this worker's job queue. It is a
// plain append with no locking, so it must not be called concurrently with
// itself or with work().
func (w *runWorker) addJob(job search.Result) {
	w.jobs = append(w.jobs, job)
}
    44  
    45  func (w *runWorker) work(ctx context.Context, wg *sync.WaitGroup) {
    46  	defer wg.Done()
    47  
    48  	for i, item := range w.jobs {
    49  		// check if the whole classification operation has been cancelled
    50  		// if yes, then abort the classifier worker
    51  		if err := ctx.Err(); err != nil {
    52  			w.ec.Add(err)
    53  			atomic.AddInt64(w.errorCount, 1)
    54  			break
    55  		}
    56  		originalIndex := (i * w.workerCount) + w.id
    57  		err := w.classify(item, originalIndex, w.params, w.filters, w.batchWriter)
    58  		if err != nil {
    59  			w.ec.Add(err)
    60  			atomic.AddInt64(w.errorCount, 1)
    61  		} else {
    62  			atomic.AddInt64(w.successCount, 1)
    63  		}
    64  	}
    65  }
    66  
    67  func newRunWorker(id int, workerCount int, rw *runWorkers) *runWorker {
    68  	return &runWorker{
    69  		successCount: rw.successCount,
    70  		errorCount:   rw.errorCount,
    71  		ec:           rw.ec,
    72  		params:       rw.params,
    73  		filters:      rw.filters,
    74  		classify:     rw.classify,
    75  		batchWriter:  rw.batchWriter,
    76  		id:           id,
    77  		workerCount:  workerCount,
    78  	}
    79  }
    80  
// runWorkers is a fixed-size pool of runWorker goroutines plus the state
// they share: atomic success/error counters, a thread-safe error
// compounder, and a single batch writer that persists all results.
type runWorkers struct {
	workers      []*runWorker
	successCount *int64                               // total successfully classified items across all workers
	errorCount   *int64                               // total failed items across all workers
	ec           *errorcompounder.SafeErrorCompounder // aggregates errors from all workers and the batch writer
	classify     ClassifyItemFn
	params       models.Classification
	filters      Filters
	batchWriter  Writer
	logger       logrus.FieldLogger
}
    92  
    93  func newRunWorkers(amount int, classifyFn ClassifyItemFn,
    94  	params models.Classification, filters Filters, vectorRepo vectorRepo, logger logrus.FieldLogger,
    95  ) *runWorkers {
    96  	var successCount int64
    97  	var errorCount int64
    98  
    99  	rw := &runWorkers{
   100  		workers:      make([]*runWorker, amount),
   101  		successCount: &successCount,
   102  		errorCount:   &errorCount,
   103  		ec:           &errorcompounder.SafeErrorCompounder{},
   104  		classify:     classifyFn,
   105  		params:       params,
   106  		filters:      filters,
   107  		batchWriter:  newBatchWriter(vectorRepo, logger),
   108  		logger:       logger,
   109  	}
   110  
   111  	for i := 0; i < amount; i++ {
   112  		rw.workers[i] = newRunWorker(i, amount, rw)
   113  	}
   114  
   115  	return rw
   116  }
   117  
   118  func (ws *runWorkers) addJobs(jobs []search.Result) {
   119  	for i, job := range jobs {
   120  		ws.workers[i%len(ws.workers)].addJob(job)
   121  	}
   122  }
   123  
   124  func (ws *runWorkers) work(ctx context.Context) runWorkerResults {
   125  	ws.batchWriter.Start()
   126  
   127  	wg := &sync.WaitGroup{}
   128  	for _, worker := range ws.workers {
   129  		worker := worker
   130  		wg.Add(1)
   131  		enterrors.GoWrapper(func() { worker.work(ctx, wg) }, ws.logger)
   132  
   133  	}
   134  
   135  	wg.Wait()
   136  
   137  	res := ws.batchWriter.Stop()
   138  
   139  	if res.SuccessCount() != *ws.successCount || res.ErrorCount() != *ws.errorCount {
   140  		ws.ec.Add(errors.New("data save error"))
   141  	}
   142  
   143  	if res.Err() != nil {
   144  		ws.ec.Add(res.Err())
   145  	}
   146  
   147  	return runWorkerResults{
   148  		successCount: *ws.successCount,
   149  		errorCount:   *ws.errorCount,
   150  		err:          ws.ec.ToError(),
   151  	}
   152  }
   153  
// runWorkerResults is the aggregate outcome of one classification run:
// how many items succeeded, how many failed, and all collected errors
// compounded into a single error (nil if none occurred).
type runWorkerResults struct {
	successCount int64
	errorCount   int64
	err          error
}