github.com/weaviate/weaviate@v1.24.6/usecases/traverser/grouper/merge_group.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package grouper 13 14 import ( 15 "fmt" 16 "strings" 17 18 "github.com/go-openapi/strfmt" 19 "github.com/weaviate/weaviate/entities/models" 20 "github.com/weaviate/weaviate/entities/schema/crossref" 21 "github.com/weaviate/weaviate/entities/search" 22 ) 23 24 type valueType int 25 26 const ( 27 numerical valueType = iota 28 textual 29 boolean 30 reference 31 geo 32 unknown 33 ) 34 35 type valueGroup struct { 36 values []interface{} 37 valueType valueType 38 name string 39 } 40 41 func (g group) flattenMerge() (search.Result, error) { 42 values := g.makeValueGroups() 43 merged, err := mergeValueGroups(values) 44 if err != nil { 45 return search.Result{}, fmt.Errorf("merge values: %v", err) 46 } 47 48 vector, err := g.mergeVectors() 49 if err != nil { 50 return search.Result{}, fmt.Errorf("merge vectors: %v", err) 51 } 52 53 className := g.mergeGetClassName() 54 55 return search.Result{ 56 ClassName: className, 57 Schema: merged, 58 Vector: vector, 59 }, nil 60 } 61 62 func (g group) mergeGetClassName() string { 63 if len(g.Elements) > 0 { 64 return g.Elements[0].ClassName 65 } 66 return "" 67 } 68 69 func (g group) makeValueGroups() map[string]valueGroup { 70 values := map[string]valueGroup{} 71 72 for _, elem := range g.Elements { 73 if elem.Schema == nil { 74 continue 75 } 76 77 for propName, propValue := range elem.Schema.(map[string]interface{}) { 78 current, ok := values[propName] 79 if !ok { 80 current = valueGroup{ 81 values: []interface{}{propValue}, 82 valueType: valueTypeOf(propValue), 83 name: propName, 84 } 85 values[propName] = current 86 continue 87 } 88 89 current.values = append(current.values, propValue) 90 values[propName] = current 91 } 92 } 93 94 return values 95 } 96 97 func (g group) mergeVectors() ([]float32, error) { 98 amount := len(g.Elements) 99 if amount == 0 { 100 return nil, nil 101 } 102 103 if amount == 1 { 104 return g.Elements[0].Vector, nil 105 } 106 107 dimensions := len(g.Elements[0].Vector) 108 out := make([]float32, dimensions) 109 110 // sum up 111 for _, groupElement := range g.Elements { 112 if len(groupElement.Vector) != dimensions { 113 return nil, fmt.Errorf("vectors have different dimensions") 114 } 115 116 for i, vectorElement := range groupElement.Vector { 117 out[i] = out[i] + vectorElement 118 } 119 } 120 121 // divide by amount of vectors 122 for i := range out { 123 out[i] = out[i] / float32(amount) 124 } 125 126 return out, nil 127 } 128 129 func mergeValueGroups(props map[string]valueGroup) (map[string]interface{}, error) { 130 mergedProps := map[string]interface{}{} 131 132 for propName, group := range props { 133 var ( 134 res interface{} 135 err error 136 ) 137 switch group.valueType { 138 case textual: 139 res, err = mergeTextualProps(group.values) 140 case numerical: 141 res, err = mergeNumericalProps(group.values) 142 case boolean: 143 res, err = mergeBooleanProps(group.values) 144 case geo: 145 res, err = mergeGeoProps(group.values) 146 case reference: 147 res, err = mergeReferenceProps(group.values) 148 case unknown: 149 continue 150 default: 151 err = fmt.Errorf("unrecognized value type") 152 } 153 if err != nil { 154 return nil, fmt.Errorf("prop '%s': %v", propName, err) 155 } 156 157 mergedProps[propName] = res 158 } 159 160 return mergedProps, nil 161 } 162 163 func valueTypeOf(in interface{}) valueType { 164 switch in.(type) { 165 case string: 166 return textual 167 case float64: 168 return numerical 169 case bool: 170 return boolean 171 case *models.GeoCoordinates: 172 return geo 173 // reference properties can be represented as either of these types. 174 // see https://github.com/weaviate/weaviate/pull/2320 175 case models.MultipleRef, []interface{}: 176 return reference 177 default: 178 return unknown 179 } 180 } 181 182 func mergeTextualProps(in []interface{}) (string, error) { 183 var values []string 184 seen := make(map[string]struct{}, len(in)) 185 for i, elem := range in { 186 asString, ok := elem.(string) 187 if !ok { 188 return "", fmt.Errorf("element %d: expected textual element to be string, but got %T", i, elem) 189 } 190 191 if _, ok := seen[asString]; ok { 192 // this is a duplicate, don't append it again 193 continue 194 } 195 196 seen[asString] = struct{}{} 197 values = append(values, asString) 198 } 199 200 if len(values) == 1 { 201 return values[0], nil 202 } 203 204 return fmt.Sprintf("%s (%s)", values[0], strings.Join(values[1:], ", ")), nil 205 } 206 207 func mergeNumericalProps(in []interface{}) (float64, error) { 208 var sum float64 209 for i, elem := range in { 210 asFloat, ok := elem.(float64) 211 if !ok { 212 return 0, fmt.Errorf("element %d: expected numerical element to be float64, but got %T", i, elem) 213 } 214 215 sum += asFloat 216 } 217 218 return sum / float64(len(in)), nil 219 } 220 221 func mergeBooleanProps(in []interface{}) (bool, error) { 222 var countTrue uint 223 var countFalse uint 224 for i, elem := range in { 225 asBool, ok := elem.(bool) 226 if !ok { 227 return false, fmt.Errorf("element %d: expected boolean element to be bool, but got %T", i, elem) 228 } 229 230 if asBool { 231 countTrue++ 232 } else { 233 countFalse++ 234 } 235 } 236 237 return countTrue >= countFalse, nil 238 } 239 240 func mergeGeoProps(in []interface{}) (*models.GeoCoordinates, error) { 241 var sumLat float32 242 var sumLon float32 243 244 for i, elem := range in { 245 asGeo, ok := elem.(*models.GeoCoordinates) 246 if !ok { 247 return nil, fmt.Errorf("element %d: expected geo element to be *models.GeoCoordinates, but got %T", i, elem) 248 } 249 250 if asGeo.Latitude != nil { 251 sumLat += *asGeo.Latitude 252 } 253 if asGeo.Longitude != nil { 254 sumLon += *asGeo.Longitude 255 } 256 } 257 258 return &models.GeoCoordinates{ 259 Latitude: ptFloat32(sumLat / float32(len(in))), 260 Longitude: ptFloat32(sumLon / float32(len(in))), 261 }, nil 262 } 263 264 func ptFloat32(in float32) *float32 { 265 return &in 266 } 267 268 func mergeReferenceProps(in []interface{}) ([]interface{}, error) { 269 var out []interface{} 270 seenID := map[string]struct{}{} 271 272 for i, elem := range in { 273 // because reference properties can be represented both as 274 // models.MultipleRef and []interface{}, we have to handle 275 // parsing both cases accordingly. 276 // see https://github.com/weaviate/weaviate/pull/2320 277 if asMultiRef, ok := elem.(models.MultipleRef); ok { 278 if err := parseRefTypeMultipleRef(asMultiRef, &out, seenID); err != nil { 279 return nil, fmt.Errorf("element %d: %w", i, err) 280 } 281 } else { 282 asSlice, ok := elem.([]interface{}) 283 if !ok { 284 return nil, fmt.Errorf( 285 "element %d: expected reference values to be slice, but got %T", i, elem) 286 } 287 288 if err := parseRefTypeInterfaceSlice(asSlice, &out, seenID); err != nil { 289 return nil, fmt.Errorf("element %d: %w", i, err) 290 } 291 } 292 } 293 294 return out, nil 295 } 296 297 func parseRefTypeMultipleRef(refs models.MultipleRef, 298 returnRefs *[]interface{}, seenIDs map[string]struct{}, 299 ) error { 300 for _, singleRef := range refs { 301 parsed, err := crossref.Parse(singleRef.Beacon.String()) 302 if err != nil { 303 return fmt.Errorf("failed to parse beacon %q: %w", singleRef.Beacon.String(), err) 304 } 305 idString := parsed.TargetID.String() 306 if _, ok := seenIDs[idString]; ok { 307 // duplicate 308 continue 309 } 310 311 *returnRefs = append(*returnRefs, singleRef) 312 seenIDs[idString] = struct{}{} // make sure we skip this next time 313 } 314 return nil 315 } 316 317 func parseRefTypeInterfaceSlice(refs []interface{}, 318 returnRefs *[]interface{}, seenIDs map[string]struct{}, 319 ) error { 320 for _, singleRef := range refs { 321 asRef, ok := singleRef.(search.LocalRef) 322 if !ok { 323 // don't know what to do with this type, ignore 324 continue 325 } 326 327 id, ok := asRef.Fields["id"] 328 if !ok { 329 return fmt.Errorf("found a search.LocalRef, but 'id' field is missing: %#v", asRef) 330 } 331 332 idString, err := getIDString(id) 333 if err != nil { 334 return err 335 } 336 337 if _, ok := seenIDs[idString]; ok { 338 // duplicate 339 continue 340 } 341 342 *returnRefs = append(*returnRefs, asRef) 343 seenIDs[idString] = struct{}{} // make sure we skip this next time 344 } 345 return nil 346 } 347 348 func getIDString(id interface{}) (string, error) { 349 switch v := id.(type) { 350 case strfmt.UUID: 351 return v.String(), nil 352 default: 353 return "", fmt.Errorf("found a search.LocalRef, 'id' field type expected to be strfmt.UUID but got %T", v) 354 } 355 }