github.com/weaviate/weaviate@v1.24.6/usecases/traverser/grouper/merge_group.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package grouper
    13  
    14  import (
    15  	"fmt"
    16  	"strings"
    17  
    18  	"github.com/go-openapi/strfmt"
    19  	"github.com/weaviate/weaviate/entities/models"
    20  	"github.com/weaviate/weaviate/entities/schema/crossref"
    21  	"github.com/weaviate/weaviate/entities/search"
    22  )
    23  
// valueType classifies a property value so that mergeValueGroups can pick the
// matching merge strategy. The classification itself is done by valueTypeOf.
type valueType int

const (
	// numerical is assigned to float64 values; merged by mergeNumericalProps.
	numerical valueType = iota
	// textual is assigned to string values; merged by mergeTextualProps.
	textual
	// boolean is assigned to bool values; merged by mergeBooleanProps.
	boolean
	// reference is assigned to models.MultipleRef and []interface{} values;
	// merged by mergeReferenceProps.
	reference
	// geo is assigned to *models.GeoCoordinates values; merged by
	// mergeGeoProps.
	geo
	// unknown is assigned to values of any other type; mergeValueGroups skips
	// properties of this type.
	unknown
)
    34  
// valueGroup collects the values that the elements of a group hold for a
// single property, together with the type used to select the merge strategy.
type valueGroup struct {
	// values holds one entry per group element that carries the property.
	values    []interface{}
	// valueType is derived from the first value seen for the property (see
	// makeValueGroups).
	valueType valueType
	// name is the property name.
	name      string
}
    40  
    41  func (g group) flattenMerge() (search.Result, error) {
    42  	values := g.makeValueGroups()
    43  	merged, err := mergeValueGroups(values)
    44  	if err != nil {
    45  		return search.Result{}, fmt.Errorf("merge values: %v", err)
    46  	}
    47  
    48  	vector, err := g.mergeVectors()
    49  	if err != nil {
    50  		return search.Result{}, fmt.Errorf("merge vectors: %v", err)
    51  	}
    52  
    53  	className := g.mergeGetClassName()
    54  
    55  	return search.Result{
    56  		ClassName: className,
    57  		Schema:    merged,
    58  		Vector:    vector,
    59  	}, nil
    60  }
    61  
    62  func (g group) mergeGetClassName() string {
    63  	if len(g.Elements) > 0 {
    64  		return g.Elements[0].ClassName
    65  	}
    66  	return ""
    67  }
    68  
    69  func (g group) makeValueGroups() map[string]valueGroup {
    70  	values := map[string]valueGroup{}
    71  
    72  	for _, elem := range g.Elements {
    73  		if elem.Schema == nil {
    74  			continue
    75  		}
    76  
    77  		for propName, propValue := range elem.Schema.(map[string]interface{}) {
    78  			current, ok := values[propName]
    79  			if !ok {
    80  				current = valueGroup{
    81  					values:    []interface{}{propValue},
    82  					valueType: valueTypeOf(propValue),
    83  					name:      propName,
    84  				}
    85  				values[propName] = current
    86  				continue
    87  			}
    88  
    89  			current.values = append(current.values, propValue)
    90  			values[propName] = current
    91  		}
    92  	}
    93  
    94  	return values
    95  }
    96  
    97  func (g group) mergeVectors() ([]float32, error) {
    98  	amount := len(g.Elements)
    99  	if amount == 0 {
   100  		return nil, nil
   101  	}
   102  
   103  	if amount == 1 {
   104  		return g.Elements[0].Vector, nil
   105  	}
   106  
   107  	dimensions := len(g.Elements[0].Vector)
   108  	out := make([]float32, dimensions)
   109  
   110  	// sum up
   111  	for _, groupElement := range g.Elements {
   112  		if len(groupElement.Vector) != dimensions {
   113  			return nil, fmt.Errorf("vectors have different dimensions")
   114  		}
   115  
   116  		for i, vectorElement := range groupElement.Vector {
   117  			out[i] = out[i] + vectorElement
   118  		}
   119  	}
   120  
   121  	// divide by amount of vectors
   122  	for i := range out {
   123  		out[i] = out[i] / float32(amount)
   124  	}
   125  
   126  	return out, nil
   127  }
   128  
   129  func mergeValueGroups(props map[string]valueGroup) (map[string]interface{}, error) {
   130  	mergedProps := map[string]interface{}{}
   131  
   132  	for propName, group := range props {
   133  		var (
   134  			res interface{}
   135  			err error
   136  		)
   137  		switch group.valueType {
   138  		case textual:
   139  			res, err = mergeTextualProps(group.values)
   140  		case numerical:
   141  			res, err = mergeNumericalProps(group.values)
   142  		case boolean:
   143  			res, err = mergeBooleanProps(group.values)
   144  		case geo:
   145  			res, err = mergeGeoProps(group.values)
   146  		case reference:
   147  			res, err = mergeReferenceProps(group.values)
   148  		case unknown:
   149  			continue
   150  		default:
   151  			err = fmt.Errorf("unrecognized value type")
   152  		}
   153  		if err != nil {
   154  			return nil, fmt.Errorf("prop '%s': %v", propName, err)
   155  		}
   156  
   157  		mergedProps[propName] = res
   158  	}
   159  
   160  	return mergedProps, nil
   161  }
   162  
   163  func valueTypeOf(in interface{}) valueType {
   164  	switch in.(type) {
   165  	case string:
   166  		return textual
   167  	case float64:
   168  		return numerical
   169  	case bool:
   170  		return boolean
   171  	case *models.GeoCoordinates:
   172  		return geo
   173  	// reference properties can be represented as either of these types.
   174  	// see https://github.com/weaviate/weaviate/pull/2320
   175  	case models.MultipleRef, []interface{}:
   176  		return reference
   177  	default:
   178  		return unknown
   179  	}
   180  }
   181  
   182  func mergeTextualProps(in []interface{}) (string, error) {
   183  	var values []string
   184  	seen := make(map[string]struct{}, len(in))
   185  	for i, elem := range in {
   186  		asString, ok := elem.(string)
   187  		if !ok {
   188  			return "", fmt.Errorf("element %d: expected textual element to be string, but got %T", i, elem)
   189  		}
   190  
   191  		if _, ok := seen[asString]; ok {
   192  			// this is a duplicate, don't append it again
   193  			continue
   194  		}
   195  
   196  		seen[asString] = struct{}{}
   197  		values = append(values, asString)
   198  	}
   199  
   200  	if len(values) == 1 {
   201  		return values[0], nil
   202  	}
   203  
   204  	return fmt.Sprintf("%s (%s)", values[0], strings.Join(values[1:], ", ")), nil
   205  }
   206  
   207  func mergeNumericalProps(in []interface{}) (float64, error) {
   208  	var sum float64
   209  	for i, elem := range in {
   210  		asFloat, ok := elem.(float64)
   211  		if !ok {
   212  			return 0, fmt.Errorf("element %d: expected numerical element to be float64, but got %T", i, elem)
   213  		}
   214  
   215  		sum += asFloat
   216  	}
   217  
   218  	return sum / float64(len(in)), nil
   219  }
   220  
   221  func mergeBooleanProps(in []interface{}) (bool, error) {
   222  	var countTrue uint
   223  	var countFalse uint
   224  	for i, elem := range in {
   225  		asBool, ok := elem.(bool)
   226  		if !ok {
   227  			return false, fmt.Errorf("element %d: expected boolean element to be bool, but got %T", i, elem)
   228  		}
   229  
   230  		if asBool {
   231  			countTrue++
   232  		} else {
   233  			countFalse++
   234  		}
   235  	}
   236  
   237  	return countTrue >= countFalse, nil
   238  }
   239  
   240  func mergeGeoProps(in []interface{}) (*models.GeoCoordinates, error) {
   241  	var sumLat float32
   242  	var sumLon float32
   243  
   244  	for i, elem := range in {
   245  		asGeo, ok := elem.(*models.GeoCoordinates)
   246  		if !ok {
   247  			return nil, fmt.Errorf("element %d: expected geo element to be *models.GeoCoordinates, but got %T", i, elem)
   248  		}
   249  
   250  		if asGeo.Latitude != nil {
   251  			sumLat += *asGeo.Latitude
   252  		}
   253  		if asGeo.Longitude != nil {
   254  			sumLon += *asGeo.Longitude
   255  		}
   256  	}
   257  
   258  	return &models.GeoCoordinates{
   259  		Latitude:  ptFloat32(sumLat / float32(len(in))),
   260  		Longitude: ptFloat32(sumLon / float32(len(in))),
   261  	}, nil
   262  }
   263  
   264  func ptFloat32(in float32) *float32 {
   265  	return &in
   266  }
   267  
   268  func mergeReferenceProps(in []interface{}) ([]interface{}, error) {
   269  	var out []interface{}
   270  	seenID := map[string]struct{}{}
   271  
   272  	for i, elem := range in {
   273  		// because reference properties can be represented both as
   274  		// models.MultipleRef and []interface{}, we have to handle
   275  		// parsing both cases accordingly.
   276  		// see https://github.com/weaviate/weaviate/pull/2320
   277  		if asMultiRef, ok := elem.(models.MultipleRef); ok {
   278  			if err := parseRefTypeMultipleRef(asMultiRef, &out, seenID); err != nil {
   279  				return nil, fmt.Errorf("element %d: %w", i, err)
   280  			}
   281  		} else {
   282  			asSlice, ok := elem.([]interface{})
   283  			if !ok {
   284  				return nil, fmt.Errorf(
   285  					"element %d: expected reference values to be slice, but got %T", i, elem)
   286  			}
   287  
   288  			if err := parseRefTypeInterfaceSlice(asSlice, &out, seenID); err != nil {
   289  				return nil, fmt.Errorf("element %d: %w", i, err)
   290  			}
   291  		}
   292  	}
   293  
   294  	return out, nil
   295  }
   296  
   297  func parseRefTypeMultipleRef(refs models.MultipleRef,
   298  	returnRefs *[]interface{}, seenIDs map[string]struct{},
   299  ) error {
   300  	for _, singleRef := range refs {
   301  		parsed, err := crossref.Parse(singleRef.Beacon.String())
   302  		if err != nil {
   303  			return fmt.Errorf("failed to parse beacon %q: %w", singleRef.Beacon.String(), err)
   304  		}
   305  		idString := parsed.TargetID.String()
   306  		if _, ok := seenIDs[idString]; ok {
   307  			// duplicate
   308  			continue
   309  		}
   310  
   311  		*returnRefs = append(*returnRefs, singleRef)
   312  		seenIDs[idString] = struct{}{} // make sure we skip this next time
   313  	}
   314  	return nil
   315  }
   316  
   317  func parseRefTypeInterfaceSlice(refs []interface{},
   318  	returnRefs *[]interface{}, seenIDs map[string]struct{},
   319  ) error {
   320  	for _, singleRef := range refs {
   321  		asRef, ok := singleRef.(search.LocalRef)
   322  		if !ok {
   323  			// don't know what to do with this type, ignore
   324  			continue
   325  		}
   326  
   327  		id, ok := asRef.Fields["id"]
   328  		if !ok {
   329  			return fmt.Errorf("found a search.LocalRef, but 'id' field is missing: %#v", asRef)
   330  		}
   331  
   332  		idString, err := getIDString(id)
   333  		if err != nil {
   334  			return err
   335  		}
   336  
   337  		if _, ok := seenIDs[idString]; ok {
   338  			// duplicate
   339  			continue
   340  		}
   341  
   342  		*returnRefs = append(*returnRefs, asRef)
   343  		seenIDs[idString] = struct{}{} // make sure we skip this next time
   344  	}
   345  	return nil
   346  }
   347  
   348  func getIDString(id interface{}) (string, error) {
   349  	switch v := id.(type) {
   350  	case strfmt.UUID:
   351  		return v.String(), nil
   352  	default:
   353  		return "", fmt.Errorf("found a search.LocalRef, 'id' field type expected to be strfmt.UUID but got %T", v)
   354  	}
   355  }