github.com/weaviate/weaviate@v1.24.6/usecases/classification/classifier.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"fmt"
    18  	"time"
    19  
    20  	enterrors "github.com/weaviate/weaviate/entities/errors"
    21  
    22  	"github.com/go-openapi/strfmt"
    23  	"github.com/google/uuid"
    24  	"github.com/pkg/errors"
    25  	"github.com/sirupsen/logrus"
    26  	"github.com/weaviate/weaviate/adapters/handlers/rest/filterext"
    27  	"github.com/weaviate/weaviate/entities/additional"
    28  	"github.com/weaviate/weaviate/entities/dto"
    29  	libfilters "github.com/weaviate/weaviate/entities/filters"
    30  	"github.com/weaviate/weaviate/entities/models"
    31  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    32  	"github.com/weaviate/weaviate/entities/search"
    33  	"github.com/weaviate/weaviate/usecases/objects"
    34  	schemaUC "github.com/weaviate/weaviate/usecases/schema"
    35  	libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer"
    36  )
    37  
// classificationFilters is the concrete filter set produced by
// extractFilters. It carries the parsed where-filters of a classification
// request. Each field is nil when no filters were supplied at all (the zero
// value is returned in that case).
type classificationFilters struct {
	// source is parsed from params.Filters.SourceWhere.
	source *libfilters.LocalFilter
	// target is parsed from params.Filters.TargetWhere.
	target *libfilters.LocalFilter
	// trainingSet is parsed from params.Filters.TrainingSetWhere.
	trainingSet *libfilters.LocalFilter
}
    43  
// Source returns the parsed sourceWhere filter. The result is nil on the
// zero value of classificationFilters.
func (f classificationFilters) Source() *libfilters.LocalFilter {
	return f.source
}
    47  
// Target returns the parsed targetWhere filter. The result is nil on the
// zero value of classificationFilters.
func (f classificationFilters) Target() *libfilters.LocalFilter {
	return f.target
}
    51  
// TrainingSet returns the parsed trainingSetWhere filter. The result is nil
// on the zero value of classificationFilters.
func (f classificationFilters) TrainingSet() *libfilters.LocalFilter {
	return f.trainingSet
}
    55  
// distancer is the signature of a function that takes two float32 vectors and
// yields a float32 result or an error. New installs
// libvectorizer.NormalizedDistance as the implementation.
type distancer func(a, b []float32) (float32, error)
    57  
// Classifier is the use-case object behind classifications: Schedule creates
// and starts one, Get looks one up. All fields are set by New.
type Classifier struct {
	// schemaGetter supplies the schema used for request and filter validation.
	schemaGetter schemaUC.SchemaGetter
	// repo stores the classification records themselves (see Repo).
	repo Repo
	// vectorRepo is the object/vector storage handed to New.
	vectorRepo vectorRepo
	// vectorClassSearchRepo is built from vectorRepo by
	// newVectorClassSearchRepo in New.
	vectorClassSearchRepo modulecapabilities.VectorClassSearchRepo
	// authorizer is consulted first by both Schedule and Get.
	authorizer authorizer
	// distancer defaults to libvectorizer.NormalizedDistance.
	distancer distancer
	// modulesProvider parses settings of non-knn classification types; it is
	// nil-checked before use in parseAndSetDefaults.
	modulesProvider ModulesProvider
	// logger is passed to enterrors.GoWrapper when a run is started.
	logger logrus.FieldLogger
}
    68  
// authorizer is the permission check this package depends on. Schedule and
// Get call Authorize before doing anything else and return its error
// unchanged when it is non-nil.
type authorizer interface {
	// Authorize reports, via a non-nil error, that principal must not perform
	// verb on resource. This file uses the verbs "create" and "get" on the
	// resource "classifications/*".
	Authorize(principal *models.Principal, verb, resource string) error
}
    72  
// ModulesProvider is the hook through which modules contribute classification
// types other than the built-in "knn" type.
type ModulesProvider interface {
	// ParseClassifierSettings is called by parseAndSetDefaults for every
	// non-knn request, with the classification type as name. It receives the
	// request by pointer; a non-nil error aborts scheduling.
	ParseClassifierSettings(name string,
		params *models.Classification) error
	// GetClassificationFn returns the per-item classification function for
	// the given class and classifier name.
	// NOTE(review): not called in this part of the file — semantics inferred
	// from the signature only; confirm against the caller.
	GetClassificationFn(className, name string,
		params modulecapabilities.ClassifyParams) (modulecapabilities.ClassifyItemFn, error)
}
    79  
    80  func New(sg schemaUC.SchemaGetter, cr Repo, vr vectorRepo, authorizer authorizer,
    81  	logger logrus.FieldLogger, modulesProvider ModulesProvider,
    82  ) *Classifier {
    83  	return &Classifier{
    84  		logger:                logger,
    85  		schemaGetter:          sg,
    86  		repo:                  cr,
    87  		vectorRepo:            vr,
    88  		authorizer:            authorizer,
    89  		distancer:             libvectorizer.NormalizedDistance,
    90  		vectorClassSearchRepo: newVectorClassSearchRepo(vr),
    91  		modulesProvider:       modulesProvider,
    92  	}
    93  }
    94  
// Repo manages classification state. It should be consistent and is not used
// to store the actual data object vectors; see VectorRepo for those.
type Repo interface {
	// Put stores the given classification. Schedule calls it once the
	// request has an ID, a status and a start time.
	Put(ctx context.Context, classification models.Classification) error
	// Get looks up a classification by its ID. Classifier.Get returns its
	// result unchanged.
	Get(ctx context.Context, id strfmt.UUID) (*models.Classification, error)
}
   101  
// VectorRepo is the access to stored objects and their vectors that the
// classification logic relies on.
//
// NOTE(review): none of these methods is called in this part of the file; the
// per-method notes below are inferred from names and signatures only and
// should be confirmed against the implementation.
type VectorRepo interface {
	// GetUnclassified presumably returns the objects of class that still
	// lack a value for the given properties, restricted by filter.
	GetUnclassified(ctx context.Context, class string,
		properties []string, filter *libfilters.LocalFilter) ([]search.Result, error)
	// AggregateNeighbors presumably aggregates the given ref properties over
	// the k nearest neighbors of vector within class (see NeighborRef).
	AggregateNeighbors(ctx context.Context, vector []float32,
		class string, properties []string, k int,
		filter *libfilters.LocalFilter) ([]NeighborRef, error)
	// VectorSearch runs a search described by params.
	VectorSearch(ctx context.Context, params dto.GetParams) ([]search.Result, error)
	// ZeroShotSearch presumably searches class by vector for zero-shot
	// classification, restricted by filter.
	ZeroShotSearch(ctx context.Context, vector []float32,
		class string, properties []string,
		filter *libfilters.LocalFilter) ([]search.Result, error)
}
   113  
// vectorRepo is the package-internal superset of VectorRepo: everything in
// VectorRepo plus the ability to write objects in batches. It is the type New
// accepts.
type vectorRepo interface {
	VectorRepo
	// BatchPutObjects writes a batch of objects and returns a batch together
	// with an error.
	// NOTE(review): not called in this part of the file; whether the returned
	// batch carries per-object errors should be confirmed at the call site.
	BatchPutObjects(ctx context.Context, objects objects.BatchObjects,
		repl *additional.ReplicationProperties) (objects.BatchObjects, error)
}
   119  
// NeighborRef is the result of an aggregation of the ref properties of k
// neighbors. It is the element type returned by VectorRepo.AggregateNeighbors.
type NeighborRef struct {
	// Property indicates which property was aggregated
	Property string

	// The beacon of the most common (kNN) reference
	Beacon strfmt.URI

	// NOTE(review): the three counts are not read or written in this part of
	// the file. By name they look like the number of neighbors considered,
	// the number that agree with Beacon, and the number that do not —
	// confirm against the code that fills them in.
	OverallCount int
	WinningCount int
	LosingCount  int

	// Distances holds the distance information that accompanies the
	// aggregation; its type is declared elsewhere in the package.
	Distances NeighborRefDistances
}
   135  
   136  func (c *Classifier) Schedule(ctx context.Context, principal *models.Principal, params models.Classification) (*models.Classification, error) {
   137  	err := c.authorizer.Authorize(principal, "create", "classifications/*")
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	err = c.parseAndSetDefaults(&params)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	err = NewValidator(c.schemaGetter, params).Do()
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  
   152  	if err := c.assignNewID(&params); err != nil {
   153  		return nil, fmt.Errorf("classification: assign id: %v", err)
   154  	}
   155  
   156  	params.Status = models.ClassificationStatusRunning
   157  	params.Meta = &models.ClassificationMeta{
   158  		Started: strfmt.DateTime(time.Now()),
   159  	}
   160  
   161  	if err := c.repo.Put(ctx, params); err != nil {
   162  		return nil, fmt.Errorf("classification: put: %v", err)
   163  	}
   164  
   165  	// asynchronously trigger the classification
   166  	filters, err := c.extractFilters(params)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	enterrors.GoWrapper(func() { c.run(params, filters) }, c.logger)
   172  
   173  	return &params, nil
   174  }
   175  
   176  func (c *Classifier) extractFilters(params models.Classification) (Filters, error) {
   177  	if params.Filters == nil {
   178  		return classificationFilters{}, nil
   179  	}
   180  
   181  	source, err := filterext.Parse(params.Filters.SourceWhere, params.Class)
   182  	if err != nil {
   183  		return classificationFilters{}, fmt.Errorf("field 'sourceWhere': %v", err)
   184  	}
   185  
   186  	trainingSet, err := filterext.Parse(params.Filters.TrainingSetWhere, params.Class)
   187  	if err != nil {
   188  		return classificationFilters{}, fmt.Errorf("field 'trainingSetWhere': %v", err)
   189  	}
   190  
   191  	target, err := filterext.Parse(params.Filters.TargetWhere, params.Class)
   192  	if err != nil {
   193  		return classificationFilters{}, fmt.Errorf("field 'targetWhere': %v", err)
   194  	}
   195  
   196  	filters := classificationFilters{
   197  		source:      source,
   198  		trainingSet: trainingSet,
   199  		target:      target,
   200  	}
   201  
   202  	if err = c.validateFilters(&params, &filters); err != nil {
   203  		return nil, err
   204  	}
   205  
   206  	return filters, nil
   207  }
   208  
   209  func (c *Classifier) validateFilters(params *models.Classification, filters *classificationFilters) (err error) {
   210  	if params.Type == TypeKNN {
   211  		if err = c.validateFilter(filters.Source()); err != nil {
   212  			return fmt.Errorf("invalid sourceWhere: %s", err)
   213  		}
   214  		if err = c.validateFilter(filters.TrainingSet()); err != nil {
   215  			return fmt.Errorf("invalid trainingSetWhere: %s", err)
   216  		}
   217  	}
   218  
   219  	if params.Type == TypeContextual || params.Type == TypeZeroShot {
   220  		if err = c.validateFilter(filters.Source()); err != nil {
   221  			return fmt.Errorf("invalid sourceWhere: %s", err)
   222  		}
   223  		if err = c.validateFilter(filters.Target()); err != nil {
   224  			return fmt.Errorf("invalid targetWhere: %s", err)
   225  		}
   226  	}
   227  
   228  	return
   229  }
   230  
   231  func (c *Classifier) validateFilter(filter *libfilters.LocalFilter) error {
   232  	if filter == nil {
   233  		return nil
   234  	}
   235  	return libfilters.ValidateFilters(c.schemaGetter.GetSchemaSkipAuth(), filter)
   236  }
   237  
   238  func (c *Classifier) assignNewID(params *models.Classification) error {
   239  	id, err := uuid.NewRandom()
   240  	if err != nil {
   241  		return err
   242  	}
   243  
   244  	params.ID = strfmt.UUID(id.String())
   245  	return nil
   246  }
   247  
   248  func (c *Classifier) Get(ctx context.Context, principal *models.Principal, id strfmt.UUID) (*models.Classification, error) {
   249  	err := c.authorizer.Authorize(principal, "get", "classifications/*")
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  
   254  	return c.repo.Get(ctx, id)
   255  }
   256  
   257  func (c *Classifier) parseAndSetDefaults(params *models.Classification) error {
   258  	if params.Type == "" {
   259  		defaultType := "knn"
   260  		params.Type = defaultType
   261  	}
   262  
   263  	if params.Type == "knn" {
   264  		if err := c.parseKNNSettings(params); err != nil {
   265  			return errors.Wrapf(err, "parse knn specific settings")
   266  		}
   267  		return nil
   268  	}
   269  
   270  	if c.modulesProvider != nil {
   271  		if err := c.modulesProvider.ParseClassifierSettings(params.Type, params); err != nil {
   272  			return errors.Wrapf(err, "parse %s specific settings", params.Type)
   273  		}
   274  		return nil
   275  	}
   276  
   277  	return nil
   278  }
   279  
   280  func (c *Classifier) parseKNNSettings(params *models.Classification) error {
   281  	raw := params.Settings
   282  	settings := &ParamsKNN{}
   283  	if raw == nil {
   284  		settings.SetDefaults()
   285  		params.Settings = settings
   286  		return nil
   287  	}
   288  
   289  	asMap, ok := raw.(map[string]interface{})
   290  	if !ok {
   291  		return errors.Errorf("settings must be an object got %T", raw)
   292  	}
   293  
   294  	v, err := extractNumberFromMap(asMap, "k")
   295  	if err != nil {
   296  		return err
   297  	}
   298  	settings.K = v
   299  
   300  	settings.SetDefaults()
   301  	params.Settings = settings
   302  
   303  	return nil
   304  }
   305  
   306  type ParamsKNN struct {
   307  	K *int32 `json:"k"`
   308  }
   309  
   310  func (params *ParamsKNN) SetDefaults() {
   311  	if params.K == nil {
   312  		defaultK := int32(3)
   313  		params.K = &defaultK
   314  	}
   315  }
   316  
   317  func extractNumberFromMap(in map[string]interface{}, field string) (*int32, error) {
   318  	unparsed, present := in[field]
   319  	if present {
   320  		parsed, ok := unparsed.(json.Number)
   321  		if !ok {
   322  			return nil, errors.Errorf("settings.%s must be number, got %T",
   323  				field, unparsed)
   324  		}
   325  
   326  		asInt64, err := parsed.Int64()
   327  		if err != nil {
   328  			return nil, errors.Wrapf(err, "settings.%s", field)
   329  		}
   330  
   331  		asInt32 := int32(asInt64)
   332  		return &asInt32, nil
   333  	}
   334  
   335  	return nil, nil
   336  }