github.com/weaviate/weaviate@v1.24.6/usecases/classification/fakes_for_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "context" 16 "fmt" 17 "math" 18 "sort" 19 "sync" 20 "time" 21 22 "github.com/go-openapi/strfmt" 23 "github.com/pkg/errors" 24 "github.com/weaviate/weaviate/entities/additional" 25 "github.com/weaviate/weaviate/entities/dto" 26 libfilters "github.com/weaviate/weaviate/entities/filters" 27 "github.com/weaviate/weaviate/entities/models" 28 "github.com/weaviate/weaviate/entities/modulecapabilities" 29 "github.com/weaviate/weaviate/entities/schema" 30 "github.com/weaviate/weaviate/entities/search" 31 "github.com/weaviate/weaviate/usecases/objects" 32 "github.com/weaviate/weaviate/usecases/sharding" 33 ) 34 35 type fakeSchemaGetter struct { 36 schema schema.Schema 37 } 38 39 func (f *fakeSchemaGetter) GetSchemaSkipAuth() schema.Schema { 40 return f.schema 41 } 42 43 func (f *fakeSchemaGetter) CopyShardingState(class string) *sharding.State { 44 panic("not implemented") 45 } 46 47 func (f *fakeSchemaGetter) ShardOwner(class, shard string) (string, error) { 48 return shard, nil 49 } 50 51 func (f *fakeSchemaGetter) ShardReplicas(class, shard string) ([]string, error) { 52 return []string{shard}, nil 53 } 54 55 func (f *fakeSchemaGetter) TenantShard(class, tenant string) (string, string) { 56 return tenant, models.TenantActivityStatusHOT 57 } 58 func (f *fakeSchemaGetter) ShardFromUUID(class string, uuid []byte) string { return string(uuid) } 59 60 func (f *fakeSchemaGetter) Nodes() []string { 61 panic("not implemented") 62 } 63 64 func (f *fakeSchemaGetter) NodeName() string { 65 panic("not implemented") 66 } 67 68 func (f *fakeSchemaGetter) ClusterHealthScore() int { 69 panic("not implemented") 70 } 71 72 func (f *fakeSchemaGetter) ResolveParentNodes(string, string, 73 ) (map[string]string, error) { 74 panic("not implemented") 75 } 76 77 type fakeClassificationRepo struct { 78 sync.Mutex 79 db map[strfmt.UUID]models.Classification 80 } 81 82 func newFakeClassificationRepo() *fakeClassificationRepo { 83 return &fakeClassificationRepo{ 84 db: map[strfmt.UUID]models.Classification{}, 85 } 86 } 87 88 func (f *fakeClassificationRepo) Put(ctx context.Context, class models.Classification) error { 89 f.Lock() 90 defer f.Unlock() 91 92 f.db[class.ID] = class 93 return nil 94 } 95 96 func (f *fakeClassificationRepo) Get(ctx context.Context, id strfmt.UUID) (*models.Classification, error) { 97 f.Lock() 98 defer f.Unlock() 99 100 class, ok := f.db[id] 101 if !ok { 102 return nil, nil 103 } 104 105 return &class, nil 106 } 107 108 func newFakeVectorRepoKNN(unclassified, classified search.Results) *fakeVectorRepoKNN { 109 return &fakeVectorRepoKNN{ 110 unclassified: unclassified, 111 classified: classified, 112 db: map[strfmt.UUID]*models.Object{}, 113 } 114 } 115 116 // read requests are specified through unclassified and classified, 117 // write requests (Put[Kind]) are stored in the db map 118 type fakeVectorRepoKNN struct { 119 sync.Mutex 120 unclassified []search.Result 121 classified []search.Result 122 db map[strfmt.UUID]*models.Object 123 errorOnAggregate error 124 batchStorageDelay time.Duration 125 } 126 127 func (f *fakeVectorRepoKNN) GetUnclassified(ctx context.Context, 128 class string, properties []string, 129 filter *libfilters.LocalFilter, 130 ) ([]search.Result, error) { 131 f.Lock() 132 defer f.Unlock() 133 return f.unclassified, nil 134 } 135 136 func (f *fakeVectorRepoKNN) AggregateNeighbors(ctx context.Context, vector []float32, 137 class string, properties []string, k int, 138 filter *libfilters.LocalFilter, 139 ) ([]NeighborRef, error) { 140 f.Lock() 141 defer f.Unlock() 142 143 // simulate that this takes some time 144 time.Sleep(1 * time.Millisecond) 145 146 if k != 1 { 147 return nil, fmt.Errorf("fake vector repo only supports k=1") 148 } 149 150 results := f.classified 151 sort.SliceStable(results, func(i, j int) bool { 152 simI, err := cosineSim(results[i].Vector, vector) 153 if err != nil { 154 panic(err.Error()) 155 } 156 157 simJ, err := cosineSim(results[j].Vector, vector) 158 if err != nil { 159 panic(err.Error()) 160 } 161 return simI > simJ 162 }) 163 164 var out []NeighborRef 165 schema := results[0].Schema.(map[string]interface{}) 166 for _, propName := range properties { 167 prop, ok := schema[propName] 168 if !ok { 169 return nil, fmt.Errorf("missing prop %s", propName) 170 } 171 172 refs := prop.(models.MultipleRef) 173 if len(refs) != 1 { 174 return nil, fmt.Errorf("wrong length %d", len(refs)) 175 } 176 177 out = append(out, NeighborRef{ 178 Beacon: refs[0].Beacon, 179 WinningCount: 1, 180 OverallCount: 1, 181 LosingCount: 1, 182 Property: propName, 183 }) 184 } 185 186 return out, f.errorOnAggregate 187 } 188 189 func (f *fakeVectorRepoKNN) ZeroShotSearch(ctx context.Context, vector []float32, 190 class string, properties []string, 191 filter *libfilters.LocalFilter, 192 ) ([]search.Result, error) { 193 return []search.Result{}, nil 194 } 195 196 func (f *fakeVectorRepoKNN) VectorSearch(ctx context.Context, 197 params dto.GetParams, 198 ) ([]search.Result, error) { 199 f.Lock() 200 defer f.Unlock() 201 return nil, fmt.Errorf("vector class search not implemented in fake") 202 } 203 204 func (f *fakeVectorRepoKNN) BatchPutObjects(ctx context.Context, objects objects.BatchObjects, repl *additional.ReplicationProperties) (objects.BatchObjects, error) { 205 f.Lock() 206 defer f.Unlock() 207 208 if f.batchStorageDelay > 0 { 209 time.Sleep(f.batchStorageDelay) 210 } 211 212 for _, batchObject := range objects { 213 f.db[batchObject.Object.ID] = batchObject.Object 214 } 215 return objects, nil 216 } 217 218 func (f *fakeVectorRepoKNN) get(id strfmt.UUID) (*models.Object, bool) { 219 f.Lock() 220 defer f.Unlock() 221 t, ok := f.db[id] 222 return t, ok 223 } 224 225 type fakeAuthorizer struct{} 226 227 func (f *fakeAuthorizer) Authorize(principal *models.Principal, verb, resource string) error { 228 return nil 229 } 230 231 func newFakeVectorRepoContextual(unclassified, targets search.Results) *fakeVectorRepoContextual { 232 return &fakeVectorRepoContextual{ 233 unclassified: unclassified, 234 targets: targets, 235 db: map[strfmt.UUID]*models.Object{}, 236 } 237 } 238 239 // read requests are specified through unclassified and classified, 240 // write requests (Put[Kind]) are stored in the db map 241 type fakeVectorRepoContextual struct { 242 sync.Mutex 243 unclassified []search.Result 244 targets []search.Result 245 db map[strfmt.UUID]*models.Object 246 errorOnAggregate error 247 } 248 249 func (f *fakeVectorRepoContextual) get(id strfmt.UUID) (*models.Object, bool) { 250 f.Lock() 251 defer f.Unlock() 252 t, ok := f.db[id] 253 return t, ok 254 } 255 256 func (f *fakeVectorRepoContextual) GetUnclassified(ctx context.Context, 257 class string, properties []string, 258 filter *libfilters.LocalFilter, 259 ) ([]search.Result, error) { 260 return f.unclassified, nil 261 } 262 263 func (f *fakeVectorRepoContextual) AggregateNeighbors(ctx context.Context, vector []float32, 264 class string, properties []string, k int, 265 filter *libfilters.LocalFilter, 266 ) ([]NeighborRef, error) { 267 panic("not implemented") 268 } 269 270 func (f *fakeVectorRepoContextual) ZeroShotSearch(ctx context.Context, vector []float32, 271 class string, properties []string, 272 filter *libfilters.LocalFilter, 273 ) ([]search.Result, error) { 274 panic("not implemented") 275 } 276 277 func (f *fakeVectorRepoContextual) BatchPutObjects(ctx context.Context, objects objects.BatchObjects, repl *additional.ReplicationProperties) (objects.BatchObjects, error) { 278 f.Lock() 279 defer f.Unlock() 280 for _, batchObject := range objects { 281 f.db[batchObject.Object.ID] = batchObject.Object 282 } 283 return objects, nil 284 } 285 286 func (f *fakeVectorRepoContextual) VectorSearch(ctx context.Context, 287 params dto.GetParams, 288 ) ([]search.Result, error) { 289 if params.SearchVector == nil { 290 filteredTargets := matchClassName(f.targets, params.ClassName) 291 return filteredTargets, nil 292 } 293 294 // simulate that this takes some time 295 time.Sleep(5 * time.Millisecond) 296 297 filteredTargets := matchClassName(f.targets, params.ClassName) 298 results := filteredTargets 299 sort.SliceStable(results, func(i, j int) bool { 300 simI, err := cosineSim(results[i].Vector, params.SearchVector) 301 if err != nil { 302 panic(err.Error()) 303 } 304 305 simJ, err := cosineSim(results[j].Vector, params.SearchVector) 306 if err != nil { 307 panic(err.Error()) 308 } 309 return simI > simJ 310 }) 311 312 if len(results) == 0 { 313 return nil, f.errorOnAggregate 314 } 315 316 out := []search.Result{ 317 results[0], 318 } 319 320 return out, f.errorOnAggregate 321 } 322 323 func cosineSim(a, b []float32) (float32, error) { 324 if len(a) != len(b) { 325 return 0, fmt.Errorf("vectors have different dimensions") 326 } 327 328 var ( 329 sumProduct float64 330 sumASquare float64 331 sumBSquare float64 332 ) 333 334 for i := range a { 335 sumProduct += float64(a[i] * b[i]) 336 sumASquare += float64(a[i] * a[i]) 337 sumBSquare += float64(b[i] * b[i]) 338 } 339 340 return float32(sumProduct / (math.Sqrt(sumASquare) * math.Sqrt(sumBSquare))), nil 341 } 342 343 func matchClassName(in []search.Result, className string) []search.Result { 344 var out []search.Result 345 for _, item := range in { 346 if item.ClassName == className { 347 out = append(out, item) 348 } 349 } 350 351 return out 352 } 353 354 type fakeModuleClassifyFn struct { 355 fakeExactCategoryMappings map[string]string 356 fakeMainCategoryMappings map[string]string 357 } 358 359 func NewFakeModuleClassifyFn() *fakeModuleClassifyFn { 360 return &fakeModuleClassifyFn{ 361 fakeExactCategoryMappings: map[string]string{ 362 "75ba35af-6a08-40ae-b442-3bec69b355f9": "1b204f16-7da6-44fd-bbd2-8cc4a7414bc3", 363 "a2bbcbdc-76e1-477d-9e72-a6d2cfb50109": "ec500f39-1dc9-4580-9bd1-55a8ea8e37a2", 364 "069410c3-4b9e-4f68-8034-32a066cb7997": "ec500f39-1dc9-4580-9bd1-55a8ea8e37a2", 365 "06a1e824-889c-4649-97f9-1ed3fa401d8e": "027b708a-31ca-43ea-9001-88bec864c79c", 366 }, 367 fakeMainCategoryMappings: map[string]string{ 368 "6402e649-b1e0-40ea-b192-a64eab0d5e56": "5a3d909a-4f0d-4168-8f5c-cd3074d1e79a", 369 "f850439a-d3cd-4f17-8fbf-5a64405645cd": "39c6abe3-4bbe-4c4e-9e60-ca5e99ec6b4e", 370 "069410c3-4b9e-4f68-8034-32a066cb7997": "39c6abe3-4bbe-4c4e-9e60-ca5e99ec6b4e", 371 }, 372 } 373 } 374 375 func (c *fakeModuleClassifyFn) classifyFn(item search.Result, itemIndex int, 376 params models.Classification, filters modulecapabilities.Filters, writer modulecapabilities.Writer, 377 ) error { 378 var classified []string 379 380 classifiedProp := c.fakeClassification(&item, "exactCategory", c.fakeExactCategoryMappings) 381 if len(classifiedProp) > 0 { 382 classified = append(classified, classifiedProp) 383 } 384 385 classifiedProp = c.fakeClassification(&item, "mainCategory", c.fakeMainCategoryMappings) 386 if len(classifiedProp) > 0 { 387 classified = append(classified, classifiedProp) 388 } 389 390 c.extendItemWithObjectMeta(&item, params, classified) 391 392 err := writer.Store(item) 393 if err != nil { 394 return fmt.Errorf("store %s/%s: %v", item.ClassName, item.ID, err) 395 } 396 return nil 397 } 398 399 func (c *fakeModuleClassifyFn) fakeClassification(item *search.Result, propName string, 400 fakes map[string]string, 401 ) string { 402 if target, ok := fakes[item.ID.String()]; ok { 403 beacon := "weaviate://localhost/" + target 404 item.Schema.(map[string]interface{})[propName] = models.MultipleRef{ 405 &models.SingleRef{ 406 Beacon: strfmt.URI(beacon), 407 Classification: nil, 408 }, 409 } 410 return propName 411 } 412 return "" 413 } 414 415 func (c *fakeModuleClassifyFn) extendItemWithObjectMeta(item *search.Result, 416 params models.Classification, classified []string, 417 ) { 418 if item.AdditionalProperties == nil { 419 item.AdditionalProperties = models.AdditionalProperties{} 420 } 421 422 item.AdditionalProperties["classification"] = additional.Classification{ 423 ID: params.ID, 424 Scope: params.ClassifyProperties, 425 ClassifiedFields: classified, 426 Completed: strfmt.DateTime(time.Now()), 427 } 428 } 429 430 type fakeModulesProvider struct { 431 fakeModuleClassifyFn *fakeModuleClassifyFn 432 } 433 434 func NewFakeModulesProvider() *fakeModulesProvider { 435 return &fakeModulesProvider{NewFakeModuleClassifyFn()} 436 } 437 438 func (m *fakeModulesProvider) ParseClassifierSettings(name string, 439 params *models.Classification, 440 ) error { 441 return nil 442 } 443 444 func (m *fakeModulesProvider) GetClassificationFn(className, name string, 445 params modulecapabilities.ClassifyParams, 446 ) (modulecapabilities.ClassifyItemFn, error) { 447 if name == "text2vec-contextionary-custom-contextual" { 448 return m.fakeModuleClassifyFn.classifyFn, nil 449 } 450 return nil, errors.Errorf("classifier %s not found", name) 451 }