github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/classification/fakes_for_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"sort"
    18  	"sync"
    19  	"time"
    20  
    21  	"github.com/go-openapi/strfmt"
    22  	"github.com/weaviate/weaviate/entities/additional"
    23  	"github.com/weaviate/weaviate/entities/dto"
    24  	libfilters "github.com/weaviate/weaviate/entities/filters"
    25  	"github.com/weaviate/weaviate/entities/models"
    26  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    27  	"github.com/weaviate/weaviate/entities/schema"
    28  	"github.com/weaviate/weaviate/entities/search"
    29  	usecasesclassfication "github.com/weaviate/weaviate/usecases/classification"
    30  	"github.com/weaviate/weaviate/usecases/objects"
    31  	"github.com/weaviate/weaviate/usecases/sharding"
    32  )
    33  
    34  type fakeSchemaGetter struct {
    35  	schema schema.Schema
    36  }
    37  
    38  func (f *fakeSchemaGetter) GetSchemaSkipAuth() schema.Schema {
    39  	return f.schema
    40  }
    41  
    42  func (f *fakeSchemaGetter) CopyShardingState(class string) *sharding.State {
    43  	panic("not implemented")
    44  }
    45  
    46  func (f *fakeSchemaGetter) ShardOwner(class, shard string) (string, error)      { return "", nil }
    47  func (f *fakeSchemaGetter) ShardReplicas(class, shard string) ([]string, error) { return nil, nil }
    48  
    49  func (f *fakeSchemaGetter) TenantShard(class, tenant string) (string, string) {
    50  	return tenant, models.TenantActivityStatusHOT
    51  }
    52  func (f *fakeSchemaGetter) ShardFromUUID(class string, uuid []byte) string { return "" }
    53  
    54  func (f *fakeSchemaGetter) Nodes() []string {
    55  	panic("not implemented")
    56  }
    57  
    58  func (f *fakeSchemaGetter) NodeName() string {
    59  	panic("not implemented")
    60  }
    61  
    62  func (f *fakeSchemaGetter) ClusterHealthScore() int {
    63  	panic("not implemented")
    64  }
    65  
    66  func (f *fakeSchemaGetter) ResolveParentNodes(string, string,
    67  ) (map[string]string, error) {
    68  	panic("not implemented")
    69  }
    70  
    71  type fakeClassificationRepo struct {
    72  	sync.Mutex
    73  	db map[strfmt.UUID]models.Classification
    74  }
    75  
    76  func newFakeClassificationRepo() *fakeClassificationRepo {
    77  	return &fakeClassificationRepo{
    78  		db: map[strfmt.UUID]models.Classification{},
    79  	}
    80  }
    81  
    82  func (f *fakeClassificationRepo) Put(ctx context.Context, class models.Classification) error {
    83  	f.Lock()
    84  	defer f.Unlock()
    85  
    86  	f.db[class.ID] = class
    87  	return nil
    88  }
    89  
    90  func (f *fakeClassificationRepo) Get(ctx context.Context, id strfmt.UUID) (*models.Classification, error) {
    91  	f.Lock()
    92  	defer f.Unlock()
    93  
    94  	class, ok := f.db[id]
    95  	if !ok {
    96  		return nil, nil
    97  	}
    98  
    99  	return &class, nil
   100  }
   101  
   102  func newFakeVectorRepoKNN(unclassified, classified search.Results) *fakeVectorRepoKNN {
   103  	return &fakeVectorRepoKNN{
   104  		unclassified: unclassified,
   105  		classified:   classified,
   106  		db:           map[strfmt.UUID]*models.Object{},
   107  	}
   108  }
   109  
   110  // read requests are specified through unclassified and classified,
   111  // write requests (Put[Kind]) are stored in the db map
   112  type fakeVectorRepoKNN struct {
   113  	sync.Mutex
   114  	unclassified      []search.Result
   115  	classified        []search.Result
   116  	db                map[strfmt.UUID]*models.Object
   117  	errorOnAggregate  error
   118  	batchStorageDelay time.Duration
   119  }
   120  
   121  func (f *fakeVectorRepoKNN) GetUnclassified(ctx context.Context,
   122  	class string, properties []string,
   123  	filter *libfilters.LocalFilter,
   124  ) ([]search.Result, error) {
   125  	f.Lock()
   126  	defer f.Unlock()
   127  	return f.unclassified, nil
   128  }
   129  
   130  func (f *fakeVectorRepoKNN) AggregateNeighbors(ctx context.Context, vector []float32,
   131  	class string, properties []string, k int,
   132  	filter *libfilters.LocalFilter,
   133  ) ([]usecasesclassfication.NeighborRef, error) {
   134  	f.Lock()
   135  	defer f.Unlock()
   136  
   137  	// simulate that this takes some time
   138  	time.Sleep(1 * time.Millisecond)
   139  
   140  	if k != 1 {
   141  		return nil, fmt.Errorf("fake vector repo only supports k=1")
   142  	}
   143  
   144  	results := f.classified
   145  	sort.SliceStable(results, func(i, j int) bool {
   146  		simI, err := cosineSim(results[i].Vector, vector)
   147  		if err != nil {
   148  			panic(err.Error())
   149  		}
   150  
   151  		simJ, err := cosineSim(results[j].Vector, vector)
   152  		if err != nil {
   153  			panic(err.Error())
   154  		}
   155  		return simI > simJ
   156  	})
   157  
   158  	var out []usecasesclassfication.NeighborRef
   159  	schema := results[0].Schema.(map[string]interface{})
   160  	for _, propName := range properties {
   161  		prop, ok := schema[propName]
   162  		if !ok {
   163  			return nil, fmt.Errorf("missing prop %s", propName)
   164  		}
   165  
   166  		refs := prop.(models.MultipleRef)
   167  		if len(refs) != 1 {
   168  			return nil, fmt.Errorf("wrong length %d", len(refs))
   169  		}
   170  
   171  		out = append(out, usecasesclassfication.NeighborRef{
   172  			Beacon:       refs[0].Beacon,
   173  			WinningCount: 1,
   174  			OverallCount: 1,
   175  			LosingCount:  1,
   176  			Property:     propName,
   177  		})
   178  	}
   179  
   180  	return out, f.errorOnAggregate
   181  }
   182  
   183  func (f *fakeVectorRepoKNN) ZeroShotSearch(ctx context.Context, vector []float32,
   184  	class string, properties []string,
   185  	filter *libfilters.LocalFilter,
   186  ) ([]search.Result, error) {
   187  	panic("not implemented")
   188  }
   189  
   190  func (f *fakeVectorRepoKNN) VectorSearch(ctx context.Context,
   191  	params dto.GetParams,
   192  ) ([]search.Result, error) {
   193  	f.Lock()
   194  	defer f.Unlock()
   195  	return nil, fmt.Errorf("vector class search not implemented in fake")
   196  }
   197  
   198  func (f *fakeVectorRepoKNN) BatchPutObjects(ctx context.Context, objects objects.BatchObjects, repl *additional.ReplicationProperties) (objects.BatchObjects, error) {
   199  	f.Lock()
   200  	defer f.Unlock()
   201  
   202  	if f.batchStorageDelay > 0 {
   203  		time.Sleep(f.batchStorageDelay)
   204  	}
   205  
   206  	for _, batchObject := range objects {
   207  		f.db[batchObject.Object.ID] = batchObject.Object
   208  	}
   209  	return objects, nil
   210  }
   211  
   212  func (f *fakeVectorRepoKNN) get(id strfmt.UUID) (*models.Object, bool) {
   213  	f.Lock()
   214  	defer f.Unlock()
   215  	t, ok := f.db[id]
   216  	return t, ok
   217  }
   218  
   219  type fakeAuthorizer struct{}
   220  
   221  func (f *fakeAuthorizer) Authorize(principal *models.Principal, verb, resource string) error {
   222  	return nil
   223  }
   224  
   225  func newFakeVectorRepoContextual(unclassified, targets search.Results) *fakeVectorRepoContextual {
   226  	return &fakeVectorRepoContextual{
   227  		unclassified: unclassified,
   228  		targets:      targets,
   229  		db:           map[strfmt.UUID]*models.Object{},
   230  	}
   231  }
   232  
   233  // read requests are specified through unclassified and classified,
   234  // write requests (Put[Kind]) are stored in the db map
   235  type fakeVectorRepoContextual struct {
   236  	sync.Mutex
   237  	unclassified     []search.Result
   238  	targets          []search.Result
   239  	db               map[strfmt.UUID]*models.Object
   240  	errorOnAggregate error
   241  }
   242  
   243  func (f *fakeVectorRepoContextual) get(id strfmt.UUID) (*models.Object, bool) {
   244  	f.Lock()
   245  	defer f.Unlock()
   246  	t, ok := f.db[id]
   247  	return t, ok
   248  }
   249  
   250  func (f *fakeVectorRepoContextual) GetUnclassified(ctx context.Context,
   251  	class string, properties []string,
   252  	filter *libfilters.LocalFilter,
   253  ) ([]search.Result, error) {
   254  	return f.unclassified, nil
   255  }
   256  
   257  func (f *fakeVectorRepoContextual) AggregateNeighbors(ctx context.Context, vector []float32,
   258  	class string, properties []string, k int,
   259  	filter *libfilters.LocalFilter,
   260  ) ([]usecasesclassfication.NeighborRef, error) {
   261  	panic("not implemented")
   262  }
   263  
   264  func (f *fakeVectorRepoContextual) ZeroShotSearch(ctx context.Context, vector []float32,
   265  	class string, properties []string,
   266  	filter *libfilters.LocalFilter,
   267  ) ([]search.Result, error) {
   268  	panic("not implemented")
   269  }
   270  
   271  func (f *fakeVectorRepoContextual) BatchPutObjects(ctx context.Context, objects objects.BatchObjects, repl *additional.ReplicationProperties) (objects.BatchObjects, error) {
   272  	f.Lock()
   273  	defer f.Unlock()
   274  	for _, batchObject := range objects {
   275  		f.db[batchObject.Object.ID] = batchObject.Object
   276  	}
   277  	return objects, nil
   278  }
   279  
   280  func (f *fakeVectorRepoContextual) VectorSearch(ctx context.Context,
   281  	params dto.GetParams,
   282  ) ([]search.Result, error) {
   283  	if params.SearchVector == nil {
   284  		filteredTargets := matchClassName(f.targets, params.ClassName)
   285  		return filteredTargets, nil
   286  	}
   287  
   288  	// simulate that this takes some time
   289  	time.Sleep(5 * time.Millisecond)
   290  
   291  	filteredTargets := matchClassName(f.targets, params.ClassName)
   292  	results := filteredTargets
   293  	sort.SliceStable(results, func(i, j int) bool {
   294  		simI, err := cosineSim(results[i].Vector, params.SearchVector)
   295  		if err != nil {
   296  			panic(err.Error())
   297  		}
   298  
   299  		simJ, err := cosineSim(results[j].Vector, params.SearchVector)
   300  		if err != nil {
   301  			panic(err.Error())
   302  		}
   303  		return simI > simJ
   304  	})
   305  
   306  	if len(results) == 0 {
   307  		return nil, f.errorOnAggregate
   308  	}
   309  
   310  	out := []search.Result{
   311  		results[0],
   312  	}
   313  
   314  	return out, f.errorOnAggregate
   315  }
   316  
   317  func matchClassName(in []search.Result, className string) []search.Result {
   318  	var out []search.Result
   319  	for _, item := range in {
   320  		if item.ClassName == className {
   321  			out = append(out, item)
   322  		}
   323  	}
   324  
   325  	return out
   326  }
   327  
   328  type fakeModulesProvider struct {
   329  	contextualClassifier modulecapabilities.Classifier
   330  }
   331  
   332  func (fmp *fakeModulesProvider) VectorFromInput(ctx context.Context, className string, input string) ([]float32, error) {
   333  	panic("not implemented")
   334  }
   335  
   336  func NewFakeModulesProvider(vectorizer *fakeVectorizer) *fakeModulesProvider {
   337  	return &fakeModulesProvider{New(vectorizer)}
   338  }
   339  
   340  func (fmp *fakeModulesProvider) ParseClassifierSettings(name string,
   341  	params *models.Classification,
   342  ) error {
   343  	return fmp.contextualClassifier.ParseClassifierSettings(params)
   344  }
   345  
   346  func (fmp *fakeModulesProvider) GetClassificationFn(className, name string,
   347  	params modulecapabilities.ClassifyParams,
   348  ) (modulecapabilities.ClassifyItemFn, error) {
   349  	return fmp.contextualClassifier.ClassifyFn(params)
   350  }