github.com/weaviate/weaviate@v1.24.6/usecases/classification/classifier_run.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"runtime"
    18  	"time"
    19  
    20  	"github.com/go-openapi/strfmt"
    21  	"github.com/pkg/errors"
    22  	"github.com/sirupsen/logrus"
    23  	"github.com/weaviate/weaviate/entities/additional"
    24  	"github.com/weaviate/weaviate/entities/models"
    25  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    26  	"github.com/weaviate/weaviate/entities/schema"
    27  	"github.com/weaviate/weaviate/entities/search"
    28  )
    29  
    30  // the contents of this file deal with anything about a classification run
    31  // which is generic, whereas the individual classify_item fns can be found in
    32  // the respective files such as classifier_run_knn.go
    33  
    34  func (c *Classifier) run(params models.Classification,
    35  	filters Filters,
    36  ) {
    37  	ctx, cancel := contextWithTimeout(30 * time.Minute)
    38  	defer cancel()
    39  
    40  	go c.monitorClassification(ctx, cancel, schema.ClassName(params.Class))
    41  
    42  	c.logBegin(params, filters)
    43  	unclassifiedItems, err := c.vectorRepo.GetUnclassified(ctx,
    44  		params.Class, params.ClassifyProperties, filters.Source())
    45  	if err != nil {
    46  		c.failRunWithError(params, errors.Wrap(err, "retrieve to-be-classifieds"))
    47  		return
    48  	}
    49  
    50  	if len(unclassifiedItems) == 0 {
    51  		c.failRunWithError(params,
    52  			fmt.Errorf("no classes to be classified - did you run a previous classification already?"))
    53  		return
    54  	}
    55  	c.logItemsFetched(params, unclassifiedItems)
    56  
    57  	classifyItem, err := c.prepareRun(params, filters, unclassifiedItems)
    58  	if err != nil {
    59  		c.failRunWithError(params, errors.Wrap(err, "prepare classification"))
    60  		return
    61  	}
    62  
    63  	params, err = c.runItems(ctx, classifyItem, params, filters, unclassifiedItems)
    64  	if err != nil {
    65  		c.failRunWithError(params, err)
    66  		return
    67  	}
    68  
    69  	c.succeedRun(params)
    70  }
    71  
    72  func (c *Classifier) monitorClassification(ctx context.Context, cancelFn context.CancelFunc,
    73  	className schema.ClassName,
    74  ) {
    75  	ticker := time.NewTicker(100 * time.Millisecond)
    76  	defer ticker.Stop()
    77  	for {
    78  		select {
    79  		case <-ctx.Done():
    80  			return
    81  		case <-ticker.C:
    82  			schema := c.schemaGetter.GetSchemaSkipAuth()
    83  			class := schema.FindClassByName(className)
    84  			if class == nil {
    85  				cancelFn()
    86  				return
    87  			}
    88  		}
    89  	}
    90  }
    91  
    92  func (c *Classifier) prepareRun(params models.Classification, filters Filters,
    93  	unclassifiedItems []search.Result,
    94  ) (ClassifyItemFn, error) {
    95  	c.logBeginPreparation(params)
    96  	defer c.logFinishPreparation(params)
    97  
    98  	if params.Type == "knn" {
    99  		return c.classifyItemUsingKNN, nil
   100  	}
   101  
   102  	if params.Type == "zeroshot" {
   103  		return c.classifyItemUsingZeroShot, nil
   104  	}
   105  
   106  	if c.modulesProvider != nil {
   107  		classifyItemFn, err := c.modulesProvider.GetClassificationFn(params.Class, params.Type,
   108  			c.getClassifyParams(params, filters, unclassifiedItems))
   109  		if err != nil {
   110  			return nil, errors.Wrapf(err, "cannot classify")
   111  		}
   112  		if classifyItemFn == nil {
   113  			return nil, errors.Errorf("cannot classify: empty classifier for %s", params.Type)
   114  		}
   115  		classification := &moduleClassification{classifyItemFn}
   116  		return classification.classifyFn, nil
   117  	}
   118  
   119  	return nil, errors.Errorf("unsupported type '%s', have no classify item fn for this", params.Type)
   120  }
   121  
   122  func (c *Classifier) getClassifyParams(params models.Classification,
   123  	filters Filters, unclassifiedItems []search.Result,
   124  ) modulecapabilities.ClassifyParams {
   125  	return modulecapabilities.ClassifyParams{
   126  		Schema:            c.schemaGetter.GetSchemaSkipAuth(),
   127  		Params:            params,
   128  		Filters:           filters,
   129  		UnclassifiedItems: unclassifiedItems,
   130  		VectorRepo:        c.vectorClassSearchRepo,
   131  	}
   132  }
   133  
   134  // runItems splits the job list into batches that can be worked on parallelly
   135  // depending on the available CPUs
   136  func (c *Classifier) runItems(ctx context.Context, classifyItem ClassifyItemFn, params models.Classification, filters Filters,
   137  	items []search.Result,
   138  ) (models.Classification, error) {
   139  	workerCount := runtime.GOMAXPROCS(0)
   140  	if len(items) < workerCount {
   141  		workerCount = len(items)
   142  	}
   143  
   144  	workers := newRunWorkers(workerCount, classifyItem, params, filters, c.vectorRepo, c.logger)
   145  	workers.addJobs(items)
   146  	res := workers.work(ctx)
   147  
   148  	params.Meta.Completed = strfmt.DateTime(time.Now())
   149  	params.Meta.CountSucceeded = res.successCount
   150  	params.Meta.CountFailed = res.errorCount
   151  	params.Meta.Count = res.successCount + res.errorCount
   152  
   153  	return params, res.err
   154  }
   155  
   156  func (c *Classifier) succeedRun(params models.Classification) {
   157  	params.Status = models.ClassificationStatusCompleted
   158  	ctx, cancel := contextWithTimeout(2 * time.Second)
   159  	defer cancel()
   160  	err := c.repo.Put(ctx, params)
   161  	if err != nil {
   162  		c.logExecutionError("store succeeded run", err, params)
   163  	}
   164  	c.logFinish(params)
   165  }
   166  
   167  func (c *Classifier) failRunWithError(params models.Classification, err error) {
   168  	params.Status = models.ClassificationStatusFailed
   169  	params.Error = fmt.Sprintf("classification failed: %v", err)
   170  	err = c.repo.Put(context.Background(), params)
   171  	if err != nil {
   172  		c.logExecutionError("store failed run", err, params)
   173  	}
   174  	c.logFinish(params)
   175  }
   176  
   177  func (c *Classifier) extendItemWithObjectMeta(item *search.Result,
   178  	params models.Classification, classified []string,
   179  ) {
   180  	// don't overwrite existing non-classification meta info
   181  	if item.AdditionalProperties == nil {
   182  		item.AdditionalProperties = models.AdditionalProperties{}
   183  	}
   184  
   185  	item.AdditionalProperties["classification"] = additional.Classification{
   186  		ID:               params.ID,
   187  		Scope:            params.ClassifyProperties,
   188  		ClassifiedFields: classified,
   189  		Completed:        strfmt.DateTime(time.Now()),
   190  	}
   191  }
   192  
   193  func contextWithTimeout(d time.Duration) (context.Context, context.CancelFunc) {
   194  	return context.WithTimeout(context.Background(), d)
   195  }
   196  
   197  // Logging helper methods
   198  func (c *Classifier) logBase(params models.Classification, event string) *logrus.Entry {
   199  	return c.logger.WithField("action", "classification_run").
   200  		WithField("event", event).
   201  		WithField("params", params).
   202  		WithField("classification_type", params.Type)
   203  }
   204  
   205  func (c *Classifier) logBegin(params models.Classification, filters Filters) {
   206  	c.logBase(params, "classification_begin").
   207  		WithField("filters", filters).
   208  		Debug("classification started")
   209  }
   210  
   211  func (c *Classifier) logFinish(params models.Classification) {
   212  	c.logBase(params, "classification_finish").
   213  		WithField("status", params.Status).
   214  		Debug("classification finished")
   215  }
   216  
   217  func (c *Classifier) logItemsFetched(params models.Classification, items search.Results) {
   218  	c.logBase(params, "classification_items_fetched").
   219  		WithField("status", params.Status).
   220  		WithField("item_count", len(items)).
   221  		Debug("fetched source items")
   222  }
   223  
   224  func (c *Classifier) logBeginPreparation(params models.Classification) {
   225  	c.logBase(params, "classification_preparation_begin").
   226  		Debug("begin run preparation")
   227  }
   228  
   229  func (c *Classifier) logFinishPreparation(params models.Classification) {
   230  	c.logBase(params, "classification_preparation_finish").
   231  		Debug("finish run preparation")
   232  }
   233  
   234  func (c *Classifier) logExecutionError(event string, err error, params models.Classification) {
   235  	c.logBase(params, event).
   236  		WithError(err).
   237  		Error("classification execution failure")
   238  }