github.com/weaviate/weaviate@v1.24.6/usecases/classification/classifier_run_zeroshot.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "time" 16 17 "github.com/pkg/errors" 18 "github.com/weaviate/weaviate/entities/models" 19 "github.com/weaviate/weaviate/entities/schema" 20 "github.com/weaviate/weaviate/entities/schema/crossref" 21 "github.com/weaviate/weaviate/entities/search" 22 ) 23 24 func (c *Classifier) classifyItemUsingZeroShot(item search.Result, itemIndex int, 25 params models.Classification, filters Filters, writer Writer, 26 ) error { 27 ctx, cancel := contextWithTimeout(2 * time.Second) 28 defer cancel() 29 30 properties := params.ClassifyProperties 31 32 s := c.schemaGetter.GetSchemaSkipAuth() 33 class := s.GetClass(schema.ClassName(item.ClassName)) 34 35 classifyProp := []string{} 36 for _, prop := range properties { 37 for _, classProp := range class.Properties { 38 if classProp.Name == prop { 39 classifyProp = append(classifyProp, classProp.DataType...) 40 } 41 } 42 } 43 44 var classified []string 45 for _, className := range classifyProp { 46 for _, prop := range properties { 47 res, err := c.vectorRepo.ZeroShotSearch(ctx, item.Vector, className, 48 params.ClassifyProperties, filters.Target()) 49 if err != nil { 50 return errors.Wrap(err, "zeroshot: search") 51 } 52 53 if len(res) > 0 { 54 cref := crossref.NewLocalhost(res[0].ClassName, res[0].ID) 55 item.Schema.(map[string]interface{})[prop] = models.MultipleRef{ 56 &models.SingleRef{ 57 Beacon: cref.SingleRef().Beacon, 58 Classification: &models.ReferenceMetaClassification{}, 59 }, 60 } 61 classified = append(classified, prop) 62 } 63 } 64 } 65 66 c.extendItemWithObjectMeta(&item, params, classified) 67 err := writer.Store(item) 68 if err != nil { 69 return errors.Errorf("store %s/%s: %v", item.ClassName, item.ID, err) 70 } 71 72 return nil 73 }