github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/classification/fakes_for_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "context" 16 "fmt" 17 "sort" 18 "sync" 19 "time" 20 21 "github.com/go-openapi/strfmt" 22 "github.com/weaviate/weaviate/entities/additional" 23 "github.com/weaviate/weaviate/entities/dto" 24 libfilters "github.com/weaviate/weaviate/entities/filters" 25 "github.com/weaviate/weaviate/entities/models" 26 "github.com/weaviate/weaviate/entities/modulecapabilities" 27 "github.com/weaviate/weaviate/entities/schema" 28 "github.com/weaviate/weaviate/entities/search" 29 usecasesclassfication "github.com/weaviate/weaviate/usecases/classification" 30 "github.com/weaviate/weaviate/usecases/objects" 31 "github.com/weaviate/weaviate/usecases/sharding" 32 ) 33 34 type fakeSchemaGetter struct { 35 schema schema.Schema 36 } 37 38 func (f *fakeSchemaGetter) GetSchemaSkipAuth() schema.Schema { 39 return f.schema 40 } 41 42 func (f *fakeSchemaGetter) CopyShardingState(class string) *sharding.State { 43 panic("not implemented") 44 } 45 46 func (f *fakeSchemaGetter) ShardOwner(class, shard string) (string, error) { return "", nil } 47 func (f *fakeSchemaGetter) ShardReplicas(class, shard string) ([]string, error) { return nil, nil } 48 49 func (f *fakeSchemaGetter) TenantShard(class, tenant string) (string, string) { 50 return tenant, models.TenantActivityStatusHOT 51 } 52 func (f *fakeSchemaGetter) ShardFromUUID(class string, uuid []byte) string { return "" } 53 54 func (f *fakeSchemaGetter) Nodes() []string { 55 panic("not implemented") 56 } 57 58 func (f *fakeSchemaGetter) NodeName() string { 59 panic("not implemented") 60 } 61 62 func (f *fakeSchemaGetter) ClusterHealthScore() int { 63 panic("not implemented") 64 } 65 66 func (f *fakeSchemaGetter) ResolveParentNodes(string, string, 67 ) (map[string]string, error) { 68 panic("not implemented") 69 } 70 71 type fakeClassificationRepo struct { 72 sync.Mutex 73 db map[strfmt.UUID]models.Classification 74 } 75 76 func newFakeClassificationRepo() *fakeClassificationRepo { 77 return &fakeClassificationRepo{ 78 db: map[strfmt.UUID]models.Classification{}, 79 } 80 } 81 82 func (f *fakeClassificationRepo) Put(ctx context.Context, class models.Classification) error { 83 f.Lock() 84 defer f.Unlock() 85 86 f.db[class.ID] = class 87 return nil 88 } 89 90 func (f *fakeClassificationRepo) Get(ctx context.Context, id strfmt.UUID) (*models.Classification, error) { 91 f.Lock() 92 defer f.Unlock() 93 94 class, ok := f.db[id] 95 if !ok { 96 return nil, nil 97 } 98 99 return &class, nil 100 } 101 102 func newFakeVectorRepoKNN(unclassified, classified search.Results) *fakeVectorRepoKNN { 103 return &fakeVectorRepoKNN{ 104 unclassified: unclassified, 105 classified: classified, 106 db: map[strfmt.UUID]*models.Object{}, 107 } 108 } 109 110 // read requests are specified through unclassified and classified, 111 // write requests (Put[Kind]) are stored in the db map 112 type fakeVectorRepoKNN struct { 113 sync.Mutex 114 unclassified []search.Result 115 classified []search.Result 116 db map[strfmt.UUID]*models.Object 117 errorOnAggregate error 118 batchStorageDelay time.Duration 119 } 120 121 func (f *fakeVectorRepoKNN) GetUnclassified(ctx context.Context, 122 class string, properties []string, 123 filter *libfilters.LocalFilter, 124 ) ([]search.Result, error) { 125 f.Lock() 126 defer f.Unlock() 127 return f.unclassified, nil 128 } 129 130 func (f *fakeVectorRepoKNN) AggregateNeighbors(ctx context.Context, vector []float32, 131 class string, properties []string, k int, 132 filter *libfilters.LocalFilter, 133 ) ([]usecasesclassfication.NeighborRef, error) { 134 f.Lock() 135 defer f.Unlock() 136 137 // simulate that this takes some time 138 time.Sleep(1 * time.Millisecond) 139 140 if k != 1 { 141 return nil, fmt.Errorf("fake vector repo only supports k=1") 142 } 143 144 results := f.classified 145 sort.SliceStable(results, func(i, j int) bool { 146 simI, err := cosineSim(results[i].Vector, vector) 147 if err != nil { 148 panic(err.Error()) 149 } 150 151 simJ, err := cosineSim(results[j].Vector, vector) 152 if err != nil { 153 panic(err.Error()) 154 } 155 return simI > simJ 156 }) 157 158 var out []usecasesclassfication.NeighborRef 159 schema := results[0].Schema.(map[string]interface{}) 160 for _, propName := range properties { 161 prop, ok := schema[propName] 162 if !ok { 163 return nil, fmt.Errorf("missing prop %s", propName) 164 } 165 166 refs := prop.(models.MultipleRef) 167 if len(refs) != 1 { 168 return nil, fmt.Errorf("wrong length %d", len(refs)) 169 } 170 171 out = append(out, usecasesclassfication.NeighborRef{ 172 Beacon: refs[0].Beacon, 173 WinningCount: 1, 174 OverallCount: 1, 175 LosingCount: 1, 176 Property: propName, 177 }) 178 } 179 180 return out, f.errorOnAggregate 181 } 182 183 func (f *fakeVectorRepoKNN) ZeroShotSearch(ctx context.Context, vector []float32, 184 class string, properties []string, 185 filter *libfilters.LocalFilter, 186 ) ([]search.Result, error) { 187 panic("not implemented") 188 } 189 190 func (f *fakeVectorRepoKNN) VectorSearch(ctx context.Context, 191 params dto.GetParams, 192 ) ([]search.Result, error) { 193 f.Lock() 194 defer f.Unlock() 195 return nil, fmt.Errorf("vector class search not implemented in fake") 196 } 197 198 func (f *fakeVectorRepoKNN) BatchPutObjects(ctx context.Context, objects objects.BatchObjects, repl *additional.ReplicationProperties) (objects.BatchObjects, error) { 199 f.Lock() 200 defer f.Unlock() 201 202 if f.batchStorageDelay > 0 { 203 time.Sleep(f.batchStorageDelay) 204 } 205 206 for _, batchObject := range objects { 207 f.db[batchObject.Object.ID] = batchObject.Object 208 } 209 return objects, nil 210 } 211 212 func (f *fakeVectorRepoKNN) get(id strfmt.UUID) (*models.Object, bool) { 213 f.Lock() 214 defer f.Unlock() 215 t, ok := f.db[id] 216 return t, ok 217 } 218 219 type fakeAuthorizer struct{} 220 221 func (f *fakeAuthorizer) Authorize(principal *models.Principal, verb, resource string) error { 222 return nil 223 } 224 225 func newFakeVectorRepoContextual(unclassified, targets search.Results) *fakeVectorRepoContextual { 226 return &fakeVectorRepoContextual{ 227 unclassified: unclassified, 228 targets: targets, 229 db: map[strfmt.UUID]*models.Object{}, 230 } 231 } 232 233 // read requests are specified through unclassified and classified, 234 // write requests (Put[Kind]) are stored in the db map 235 type fakeVectorRepoContextual struct { 236 sync.Mutex 237 unclassified []search.Result 238 targets []search.Result 239 db map[strfmt.UUID]*models.Object 240 errorOnAggregate error 241 } 242 243 func (f *fakeVectorRepoContextual) get(id strfmt.UUID) (*models.Object, bool) { 244 f.Lock() 245 defer f.Unlock() 246 t, ok := f.db[id] 247 return t, ok 248 } 249 250 func (f *fakeVectorRepoContextual) GetUnclassified(ctx context.Context, 251 class string, properties []string, 252 filter *libfilters.LocalFilter, 253 ) ([]search.Result, error) { 254 return f.unclassified, nil 255 } 256 257 func (f *fakeVectorRepoContextual) AggregateNeighbors(ctx context.Context, vector []float32, 258 class string, properties []string, k int, 259 filter *libfilters.LocalFilter, 260 ) ([]usecasesclassfication.NeighborRef, error) { 261 panic("not implemented") 262 } 263 264 func (f *fakeVectorRepoContextual) ZeroShotSearch(ctx context.Context, vector []float32, 265 class string, properties []string, 266 filter *libfilters.LocalFilter, 267 ) ([]search.Result, error) { 268 panic("not implemented") 269 } 270 271 func (f *fakeVectorRepoContextual) BatchPutObjects(ctx context.Context, objects objects.BatchObjects, repl *additional.ReplicationProperties) (objects.BatchObjects, error) { 272 f.Lock() 273 defer f.Unlock() 274 for _, batchObject := range objects { 275 f.db[batchObject.Object.ID] = batchObject.Object 276 } 277 return objects, nil 278 } 279 280 func (f *fakeVectorRepoContextual) VectorSearch(ctx context.Context, 281 params dto.GetParams, 282 ) ([]search.Result, error) { 283 if params.SearchVector == nil { 284 filteredTargets := matchClassName(f.targets, params.ClassName) 285 return filteredTargets, nil 286 } 287 288 // simulate that this takes some time 289 time.Sleep(5 * time.Millisecond) 290 291 filteredTargets := matchClassName(f.targets, params.ClassName) 292 results := filteredTargets 293 sort.SliceStable(results, func(i, j int) bool { 294 simI, err := cosineSim(results[i].Vector, params.SearchVector) 295 if err != nil { 296 panic(err.Error()) 297 } 298 299 simJ, err := cosineSim(results[j].Vector, params.SearchVector) 300 if err != nil { 301 panic(err.Error()) 302 } 303 return simI > simJ 304 }) 305 306 if len(results) == 0 { 307 return nil, f.errorOnAggregate 308 } 309 310 out := []search.Result{ 311 results[0], 312 } 313 314 return out, f.errorOnAggregate 315 } 316 317 func matchClassName(in []search.Result, className string) []search.Result { 318 var out []search.Result 319 for _, item := range in { 320 if item.ClassName == className { 321 out = append(out, item) 322 } 323 } 324 325 return out 326 } 327 328 type fakeModulesProvider struct { 329 contextualClassifier modulecapabilities.Classifier 330 } 331 332 func (fmp *fakeModulesProvider) VectorFromInput(ctx context.Context, className string, input string) ([]float32, error) { 333 panic("not implemented") 334 } 335 336 func NewFakeModulesProvider(vectorizer *fakeVectorizer) *fakeModulesProvider { 337 return &fakeModulesProvider{New(vectorizer)} 338 } 339 340 func (fmp *fakeModulesProvider) ParseClassifierSettings(name string, 341 params *models.Classification, 342 ) error { 343 return fmp.contextualClassifier.ParseClassifierSettings(params) 344 } 345 346 func (fmp *fakeModulesProvider) GetClassificationFn(className, name string, 347 params modulecapabilities.ClassifyParams, 348 ) (modulecapabilities.ClassifyItemFn, error) { 349 return fmp.contextualClassifier.ClassifyFn(params) 350 }