github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/shard_group_by.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package db 13 14 import ( 15 "context" 16 "encoding/binary" 17 "encoding/json" 18 "fmt" 19 20 "github.com/weaviate/weaviate/adapters/repos/db/helpers" 21 "github.com/weaviate/weaviate/adapters/repos/db/lsmkv" 22 "github.com/weaviate/weaviate/entities/additional" 23 "github.com/weaviate/weaviate/entities/models" 24 "github.com/weaviate/weaviate/entities/schema" 25 "github.com/weaviate/weaviate/entities/searchparams" 26 "github.com/weaviate/weaviate/entities/storobj" 27 ) 28 29 func (s *Shard) groupResults(ctx context.Context, ids []uint64, 30 dists []float32, groupBy *searchparams.GroupBy, 31 additional additional.Properties, 32 ) ([]*storobj.Object, []float32, error) { 33 objsBucket := s.store.Bucket(helpers.ObjectsBucketLSM) 34 className := s.index.Config.ClassName 35 sch := s.index.getSchema.GetSchemaSkipAuth() 36 prop, err := sch.GetProperty(className, schema.PropertyName(groupBy.Property)) 37 if err != nil { 38 return nil, nil, fmt.Errorf("%w: unrecognized property: %s", 39 err, groupBy.Property) 40 } 41 dt, err := sch.FindPropertyDataType(prop.DataType) 42 if err != nil { 43 return nil, nil, fmt.Errorf("%w: unrecognized data type for property: %s", 44 err, groupBy.Property) 45 } 46 47 return newGrouper(ids, dists, groupBy, objsBucket, dt, additional).Do(ctx) 48 } 49 50 type grouper struct { 51 ids []uint64 52 dists []float32 53 groupBy *searchparams.GroupBy 54 additional additional.Properties 55 propertyDataType schema.PropertyDataType 56 objBucket *lsmkv.Bucket 57 } 58 59 func newGrouper(ids []uint64, dists []float32, 60 groupBy *searchparams.GroupBy, objBucket *lsmkv.Bucket, 61 propertyDataType schema.PropertyDataType, 62 additional additional.Properties, 63 ) *grouper { 64 return &grouper{ 65 ids: ids, 66 dists: dists, 67 groupBy: groupBy, 68 objBucket: objBucket, 69 propertyDataType: propertyDataType, 70 additional: additional, 71 } 72 } 73 74 func (g *grouper) Do(ctx context.Context) ([]*storobj.Object, []float32, error) { 75 docIDBytes := make([]byte, 8) 76 77 groupsOrdered := []string{} 78 groups := map[string][]uint64{} 79 docIDObject := map[uint64]*storobj.Object{} 80 docIDDistance := map[uint64]float32{} 81 82 DOCS_LOOP: 83 for i, docID := range g.ids { 84 binary.LittleEndian.PutUint64(docIDBytes, docID) 85 objData, err := g.objBucket.GetBySecondary(0, docIDBytes) 86 if err != nil { 87 return nil, nil, fmt.Errorf("%w: could not get obj by doc id %d", err, docID) 88 } 89 if objData == nil { 90 continue 91 } 92 value, ok, _ := storobj.ParseAndExtractProperty(objData, g.groupBy.Property) 93 if !ok { 94 continue 95 } 96 97 values, err := g.getValues(value) 98 if err != nil { 99 return nil, nil, err 100 } 101 102 for _, val := range values { 103 current, groupExists := groups[val] 104 if len(current) >= g.groupBy.ObjectsPerGroup { 105 continue 106 } 107 108 if !groupExists && len(groups) >= g.groupBy.Groups { 109 continue DOCS_LOOP 110 } 111 112 groups[val] = append(current, docID) 113 114 if !groupExists { 115 // this group doesn't exist add it to the ordered list 116 groupsOrdered = append(groupsOrdered, val) 117 } 118 119 if _, ok := docIDObject[docID]; !ok { 120 // whole object, might be that we only need value and ID to be extracted 121 unmarshalled, err := storobj.FromBinaryOptional(objData, g.additional) 122 if err != nil { 123 return nil, nil, fmt.Errorf("%w: unmarshal data object at position %d", err, i) 124 } 125 docIDObject[docID] = unmarshalled 126 docIDDistance[docID] = g.dists[i] 127 } 128 } 129 } 130 131 objs := make([]*storobj.Object, len(groupsOrdered)) 132 dists := make([]float32, len(groupsOrdered)) 133 objIDs := []uint64{} 134 for i, val := range groupsOrdered { 135 docIDs := groups[val] 136 unmarshalled, err := g.getUnmarshalled(docIDs[0], docIDObject, objIDs) 137 if err != nil { 138 return nil, nil, err 139 } 140 dist := docIDDistance[docIDs[0]] 141 objIDs = append(objIDs, docIDs[0]) 142 hits := make([]map[string]interface{}, len(docIDs)) 143 for j, docID := range docIDs { 144 props := map[string]interface{}{} 145 for k, v := range docIDObject[docID].Properties().(map[string]interface{}) { 146 props[k] = v 147 } 148 props["_additional"] = &additional.GroupHitAdditional{ 149 ID: docIDObject[docID].ID(), 150 Distance: docIDDistance[docID], 151 Vector: docIDObject[docID].Vector, 152 } 153 hits[j] = props 154 } 155 group := &additional.Group{ 156 ID: i, 157 GroupedBy: &additional.GroupedBy{ 158 Value: val, 159 Path: []string{g.groupBy.Property}, 160 }, 161 Count: len(hits), 162 Hits: hits, 163 MinDistance: docIDDistance[docIDs[0]], 164 MaxDistance: docIDDistance[docIDs[len(docIDs)-1]], 165 } 166 167 // add group 168 if unmarshalled.AdditionalProperties() == nil { 169 unmarshalled.Object.Additional = models.AdditionalProperties{} 170 } 171 unmarshalled.AdditionalProperties()["group"] = group 172 173 objs[i] = unmarshalled 174 dists[i] = dist 175 } 176 177 return objs, dists, nil 178 } 179 180 func (g *grouper) getUnmarshalled(docID uint64, 181 docIDObject map[uint64]*storobj.Object, 182 objIDs []uint64, 183 ) (*storobj.Object, error) { 184 containsDocID := false 185 for i := range objIDs { 186 if objIDs[i] == docID { 187 containsDocID = true 188 break 189 } 190 } 191 if containsDocID { 192 // we have already added this object containing a group to the result array 193 // and we need to unmarshall it again so that a group won't get overridden 194 docIDBytes := make([]byte, 8) 195 binary.LittleEndian.PutUint64(docIDBytes, docID) 196 objData, err := g.objBucket.GetBySecondary(0, docIDBytes) 197 if err != nil { 198 return nil, fmt.Errorf("%w: could not get obj by doc id %d", err, docID) 199 } 200 unmarshalled, err := storobj.FromBinaryOptional(objData, g.additional) 201 if err != nil { 202 return nil, fmt.Errorf("%w: unmarshal data object doc id %d", err, docID) 203 } 204 return unmarshalled, nil 205 } 206 return docIDObject[docID], nil 207 } 208 209 func (g *grouper) getValues(values []string) ([]string, error) { 210 if len(values) == 0 { 211 return []string{""}, nil 212 } 213 if g.propertyDataType.IsReference() { 214 beacons := make([]string, len(values)) 215 for i := range values { 216 if values[i] != "" { 217 var ref models.SingleRef 218 err := json.Unmarshal([]byte(values[i]), &ref) 219 if err != nil { 220 return nil, fmt.Errorf("%w: unmarshal grouped by value %s at position %d", 221 err, values[i], i) 222 } 223 beacons[i] = ref.Beacon.String() 224 } 225 } 226 return beacons, nil 227 } 228 return values, nil 229 }