github.com/weaviate/weaviate@v1.24.6/usecases/classification/fakes_for_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"math"
    18  	"sort"
    19  	"sync"
    20  	"time"
    21  
    22  	"github.com/go-openapi/strfmt"
    23  	"github.com/pkg/errors"
    24  	"github.com/weaviate/weaviate/entities/additional"
    25  	"github.com/weaviate/weaviate/entities/dto"
    26  	libfilters "github.com/weaviate/weaviate/entities/filters"
    27  	"github.com/weaviate/weaviate/entities/models"
    28  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    29  	"github.com/weaviate/weaviate/entities/schema"
    30  	"github.com/weaviate/weaviate/entities/search"
    31  	"github.com/weaviate/weaviate/usecases/objects"
    32  	"github.com/weaviate/weaviate/usecases/sharding"
    33  )
    34  
    35  type fakeSchemaGetter struct {
    36  	schema schema.Schema
    37  }
    38  
    39  func (f *fakeSchemaGetter) GetSchemaSkipAuth() schema.Schema {
    40  	return f.schema
    41  }
    42  
    43  func (f *fakeSchemaGetter) CopyShardingState(class string) *sharding.State {
    44  	panic("not implemented")
    45  }
    46  
    47  func (f *fakeSchemaGetter) ShardOwner(class, shard string) (string, error) {
    48  	return shard, nil
    49  }
    50  
    51  func (f *fakeSchemaGetter) ShardReplicas(class, shard string) ([]string, error) {
    52  	return []string{shard}, nil
    53  }
    54  
    55  func (f *fakeSchemaGetter) TenantShard(class, tenant string) (string, string) {
    56  	return tenant, models.TenantActivityStatusHOT
    57  }
    58  func (f *fakeSchemaGetter) ShardFromUUID(class string, uuid []byte) string { return string(uuid) }
    59  
    60  func (f *fakeSchemaGetter) Nodes() []string {
    61  	panic("not implemented")
    62  }
    63  
    64  func (f *fakeSchemaGetter) NodeName() string {
    65  	panic("not implemented")
    66  }
    67  
    68  func (f *fakeSchemaGetter) ClusterHealthScore() int {
    69  	panic("not implemented")
    70  }
    71  
    72  func (f *fakeSchemaGetter) ResolveParentNodes(string, string,
    73  ) (map[string]string, error) {
    74  	panic("not implemented")
    75  }
    76  
    77  type fakeClassificationRepo struct {
    78  	sync.Mutex
    79  	db map[strfmt.UUID]models.Classification
    80  }
    81  
    82  func newFakeClassificationRepo() *fakeClassificationRepo {
    83  	return &fakeClassificationRepo{
    84  		db: map[strfmt.UUID]models.Classification{},
    85  	}
    86  }
    87  
    88  func (f *fakeClassificationRepo) Put(ctx context.Context, class models.Classification) error {
    89  	f.Lock()
    90  	defer f.Unlock()
    91  
    92  	f.db[class.ID] = class
    93  	return nil
    94  }
    95  
    96  func (f *fakeClassificationRepo) Get(ctx context.Context, id strfmt.UUID) (*models.Classification, error) {
    97  	f.Lock()
    98  	defer f.Unlock()
    99  
   100  	class, ok := f.db[id]
   101  	if !ok {
   102  		return nil, nil
   103  	}
   104  
   105  	return &class, nil
   106  }
   107  
   108  func newFakeVectorRepoKNN(unclassified, classified search.Results) *fakeVectorRepoKNN {
   109  	return &fakeVectorRepoKNN{
   110  		unclassified: unclassified,
   111  		classified:   classified,
   112  		db:           map[strfmt.UUID]*models.Object{},
   113  	}
   114  }
   115  
   116  // read requests are specified through unclassified and classified,
   117  // write requests (Put[Kind]) are stored in the db map
   118  type fakeVectorRepoKNN struct {
   119  	sync.Mutex
   120  	unclassified      []search.Result
   121  	classified        []search.Result
   122  	db                map[strfmt.UUID]*models.Object
   123  	errorOnAggregate  error
   124  	batchStorageDelay time.Duration
   125  }
   126  
   127  func (f *fakeVectorRepoKNN) GetUnclassified(ctx context.Context,
   128  	class string, properties []string,
   129  	filter *libfilters.LocalFilter,
   130  ) ([]search.Result, error) {
   131  	f.Lock()
   132  	defer f.Unlock()
   133  	return f.unclassified, nil
   134  }
   135  
   136  func (f *fakeVectorRepoKNN) AggregateNeighbors(ctx context.Context, vector []float32,
   137  	class string, properties []string, k int,
   138  	filter *libfilters.LocalFilter,
   139  ) ([]NeighborRef, error) {
   140  	f.Lock()
   141  	defer f.Unlock()
   142  
   143  	// simulate that this takes some time
   144  	time.Sleep(1 * time.Millisecond)
   145  
   146  	if k != 1 {
   147  		return nil, fmt.Errorf("fake vector repo only supports k=1")
   148  	}
   149  
   150  	results := f.classified
   151  	sort.SliceStable(results, func(i, j int) bool {
   152  		simI, err := cosineSim(results[i].Vector, vector)
   153  		if err != nil {
   154  			panic(err.Error())
   155  		}
   156  
   157  		simJ, err := cosineSim(results[j].Vector, vector)
   158  		if err != nil {
   159  			panic(err.Error())
   160  		}
   161  		return simI > simJ
   162  	})
   163  
   164  	var out []NeighborRef
   165  	schema := results[0].Schema.(map[string]interface{})
   166  	for _, propName := range properties {
   167  		prop, ok := schema[propName]
   168  		if !ok {
   169  			return nil, fmt.Errorf("missing prop %s", propName)
   170  		}
   171  
   172  		refs := prop.(models.MultipleRef)
   173  		if len(refs) != 1 {
   174  			return nil, fmt.Errorf("wrong length %d", len(refs))
   175  		}
   176  
   177  		out = append(out, NeighborRef{
   178  			Beacon:       refs[0].Beacon,
   179  			WinningCount: 1,
   180  			OverallCount: 1,
   181  			LosingCount:  1,
   182  			Property:     propName,
   183  		})
   184  	}
   185  
   186  	return out, f.errorOnAggregate
   187  }
   188  
   189  func (f *fakeVectorRepoKNN) ZeroShotSearch(ctx context.Context, vector []float32,
   190  	class string, properties []string,
   191  	filter *libfilters.LocalFilter,
   192  ) ([]search.Result, error) {
   193  	return []search.Result{}, nil
   194  }
   195  
   196  func (f *fakeVectorRepoKNN) VectorSearch(ctx context.Context,
   197  	params dto.GetParams,
   198  ) ([]search.Result, error) {
   199  	f.Lock()
   200  	defer f.Unlock()
   201  	return nil, fmt.Errorf("vector class search not implemented in fake")
   202  }
   203  
   204  func (f *fakeVectorRepoKNN) BatchPutObjects(ctx context.Context, objects objects.BatchObjects, repl *additional.ReplicationProperties) (objects.BatchObjects, error) {
   205  	f.Lock()
   206  	defer f.Unlock()
   207  
   208  	if f.batchStorageDelay > 0 {
   209  		time.Sleep(f.batchStorageDelay)
   210  	}
   211  
   212  	for _, batchObject := range objects {
   213  		f.db[batchObject.Object.ID] = batchObject.Object
   214  	}
   215  	return objects, nil
   216  }
   217  
   218  func (f *fakeVectorRepoKNN) get(id strfmt.UUID) (*models.Object, bool) {
   219  	f.Lock()
   220  	defer f.Unlock()
   221  	t, ok := f.db[id]
   222  	return t, ok
   223  }
   224  
   225  type fakeAuthorizer struct{}
   226  
   227  func (f *fakeAuthorizer) Authorize(principal *models.Principal, verb, resource string) error {
   228  	return nil
   229  }
   230  
   231  func newFakeVectorRepoContextual(unclassified, targets search.Results) *fakeVectorRepoContextual {
   232  	return &fakeVectorRepoContextual{
   233  		unclassified: unclassified,
   234  		targets:      targets,
   235  		db:           map[strfmt.UUID]*models.Object{},
   236  	}
   237  }
   238  
   239  // read requests are specified through unclassified and classified,
   240  // write requests (Put[Kind]) are stored in the db map
   241  type fakeVectorRepoContextual struct {
   242  	sync.Mutex
   243  	unclassified     []search.Result
   244  	targets          []search.Result
   245  	db               map[strfmt.UUID]*models.Object
   246  	errorOnAggregate error
   247  }
   248  
   249  func (f *fakeVectorRepoContextual) get(id strfmt.UUID) (*models.Object, bool) {
   250  	f.Lock()
   251  	defer f.Unlock()
   252  	t, ok := f.db[id]
   253  	return t, ok
   254  }
   255  
   256  func (f *fakeVectorRepoContextual) GetUnclassified(ctx context.Context,
   257  	class string, properties []string,
   258  	filter *libfilters.LocalFilter,
   259  ) ([]search.Result, error) {
   260  	return f.unclassified, nil
   261  }
   262  
   263  func (f *fakeVectorRepoContextual) AggregateNeighbors(ctx context.Context, vector []float32,
   264  	class string, properties []string, k int,
   265  	filter *libfilters.LocalFilter,
   266  ) ([]NeighborRef, error) {
   267  	panic("not implemented")
   268  }
   269  
   270  func (f *fakeVectorRepoContextual) ZeroShotSearch(ctx context.Context, vector []float32,
   271  	class string, properties []string,
   272  	filter *libfilters.LocalFilter,
   273  ) ([]search.Result, error) {
   274  	panic("not implemented")
   275  }
   276  
   277  func (f *fakeVectorRepoContextual) BatchPutObjects(ctx context.Context, objects objects.BatchObjects, repl *additional.ReplicationProperties) (objects.BatchObjects, error) {
   278  	f.Lock()
   279  	defer f.Unlock()
   280  	for _, batchObject := range objects {
   281  		f.db[batchObject.Object.ID] = batchObject.Object
   282  	}
   283  	return objects, nil
   284  }
   285  
   286  func (f *fakeVectorRepoContextual) VectorSearch(ctx context.Context,
   287  	params dto.GetParams,
   288  ) ([]search.Result, error) {
   289  	if params.SearchVector == nil {
   290  		filteredTargets := matchClassName(f.targets, params.ClassName)
   291  		return filteredTargets, nil
   292  	}
   293  
   294  	// simulate that this takes some time
   295  	time.Sleep(5 * time.Millisecond)
   296  
   297  	filteredTargets := matchClassName(f.targets, params.ClassName)
   298  	results := filteredTargets
   299  	sort.SliceStable(results, func(i, j int) bool {
   300  		simI, err := cosineSim(results[i].Vector, params.SearchVector)
   301  		if err != nil {
   302  			panic(err.Error())
   303  		}
   304  
   305  		simJ, err := cosineSim(results[j].Vector, params.SearchVector)
   306  		if err != nil {
   307  			panic(err.Error())
   308  		}
   309  		return simI > simJ
   310  	})
   311  
   312  	if len(results) == 0 {
   313  		return nil, f.errorOnAggregate
   314  	}
   315  
   316  	out := []search.Result{
   317  		results[0],
   318  	}
   319  
   320  	return out, f.errorOnAggregate
   321  }
   322  
   323  func cosineSim(a, b []float32) (float32, error) {
   324  	if len(a) != len(b) {
   325  		return 0, fmt.Errorf("vectors have different dimensions")
   326  	}
   327  
   328  	var (
   329  		sumProduct float64
   330  		sumASquare float64
   331  		sumBSquare float64
   332  	)
   333  
   334  	for i := range a {
   335  		sumProduct += float64(a[i] * b[i])
   336  		sumASquare += float64(a[i] * a[i])
   337  		sumBSquare += float64(b[i] * b[i])
   338  	}
   339  
   340  	return float32(sumProduct / (math.Sqrt(sumASquare) * math.Sqrt(sumBSquare))), nil
   341  }
   342  
   343  func matchClassName(in []search.Result, className string) []search.Result {
   344  	var out []search.Result
   345  	for _, item := range in {
   346  		if item.ClassName == className {
   347  			out = append(out, item)
   348  		}
   349  	}
   350  
   351  	return out
   352  }
   353  
   354  type fakeModuleClassifyFn struct {
   355  	fakeExactCategoryMappings map[string]string
   356  	fakeMainCategoryMappings  map[string]string
   357  }
   358  
   359  func NewFakeModuleClassifyFn() *fakeModuleClassifyFn {
   360  	return &fakeModuleClassifyFn{
   361  		fakeExactCategoryMappings: map[string]string{
   362  			"75ba35af-6a08-40ae-b442-3bec69b355f9": "1b204f16-7da6-44fd-bbd2-8cc4a7414bc3",
   363  			"a2bbcbdc-76e1-477d-9e72-a6d2cfb50109": "ec500f39-1dc9-4580-9bd1-55a8ea8e37a2",
   364  			"069410c3-4b9e-4f68-8034-32a066cb7997": "ec500f39-1dc9-4580-9bd1-55a8ea8e37a2",
   365  			"06a1e824-889c-4649-97f9-1ed3fa401d8e": "027b708a-31ca-43ea-9001-88bec864c79c",
   366  		},
   367  		fakeMainCategoryMappings: map[string]string{
   368  			"6402e649-b1e0-40ea-b192-a64eab0d5e56": "5a3d909a-4f0d-4168-8f5c-cd3074d1e79a",
   369  			"f850439a-d3cd-4f17-8fbf-5a64405645cd": "39c6abe3-4bbe-4c4e-9e60-ca5e99ec6b4e",
   370  			"069410c3-4b9e-4f68-8034-32a066cb7997": "39c6abe3-4bbe-4c4e-9e60-ca5e99ec6b4e",
   371  		},
   372  	}
   373  }
   374  
   375  func (c *fakeModuleClassifyFn) classifyFn(item search.Result, itemIndex int,
   376  	params models.Classification, filters modulecapabilities.Filters, writer modulecapabilities.Writer,
   377  ) error {
   378  	var classified []string
   379  
   380  	classifiedProp := c.fakeClassification(&item, "exactCategory", c.fakeExactCategoryMappings)
   381  	if len(classifiedProp) > 0 {
   382  		classified = append(classified, classifiedProp)
   383  	}
   384  
   385  	classifiedProp = c.fakeClassification(&item, "mainCategory", c.fakeMainCategoryMappings)
   386  	if len(classifiedProp) > 0 {
   387  		classified = append(classified, classifiedProp)
   388  	}
   389  
   390  	c.extendItemWithObjectMeta(&item, params, classified)
   391  
   392  	err := writer.Store(item)
   393  	if err != nil {
   394  		return fmt.Errorf("store %s/%s: %v", item.ClassName, item.ID, err)
   395  	}
   396  	return nil
   397  }
   398  
   399  func (c *fakeModuleClassifyFn) fakeClassification(item *search.Result, propName string,
   400  	fakes map[string]string,
   401  ) string {
   402  	if target, ok := fakes[item.ID.String()]; ok {
   403  		beacon := "weaviate://localhost/" + target
   404  		item.Schema.(map[string]interface{})[propName] = models.MultipleRef{
   405  			&models.SingleRef{
   406  				Beacon:         strfmt.URI(beacon),
   407  				Classification: nil,
   408  			},
   409  		}
   410  		return propName
   411  	}
   412  	return ""
   413  }
   414  
   415  func (c *fakeModuleClassifyFn) extendItemWithObjectMeta(item *search.Result,
   416  	params models.Classification, classified []string,
   417  ) {
   418  	if item.AdditionalProperties == nil {
   419  		item.AdditionalProperties = models.AdditionalProperties{}
   420  	}
   421  
   422  	item.AdditionalProperties["classification"] = additional.Classification{
   423  		ID:               params.ID,
   424  		Scope:            params.ClassifyProperties,
   425  		ClassifiedFields: classified,
   426  		Completed:        strfmt.DateTime(time.Now()),
   427  	}
   428  }
   429  
   430  type fakeModulesProvider struct {
   431  	fakeModuleClassifyFn *fakeModuleClassifyFn
   432  }
   433  
   434  func NewFakeModulesProvider() *fakeModulesProvider {
   435  	return &fakeModulesProvider{NewFakeModuleClassifyFn()}
   436  }
   437  
   438  func (m *fakeModulesProvider) ParseClassifierSettings(name string,
   439  	params *models.Classification,
   440  ) error {
   441  	return nil
   442  }
   443  
   444  func (m *fakeModulesProvider) GetClassificationFn(className, name string,
   445  	params modulecapabilities.ClassifyParams,
   446  ) (modulecapabilities.ClassifyItemFn, error) {
   447  	if name == "text2vec-contextionary-custom-contextual" {
   448  		return m.fakeModuleClassifyFn.classifyFn, nil
   449  	}
   450  	return nil, errors.Errorf("classifier %s not found", name)
   451  }