github.com/weaviate/weaviate@v1.24.6/usecases/traverser/grouper/grouper.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package grouper
    13  
    14  import (
    15  	"fmt"
    16  
    17  	"github.com/sirupsen/logrus"
    18  	"github.com/weaviate/weaviate/entities/search"
    19  	"github.com/weaviate/weaviate/usecases/vectorizer"
    20  )
    21  
    22  // Grouper groups or merges search results by how related they are
    23  type Grouper struct {
    24  	logger logrus.FieldLogger
    25  }
    26  
    27  // NewGrouper creates a Grouper UC from the specified configuration
    28  func New(logger logrus.FieldLogger) *Grouper {
    29  	return &Grouper{logger: logger}
    30  }
    31  
    32  // Group using the applied strategy and force
    33  func (g *Grouper) Group(in []search.Result, strategy string,
    34  	force float32,
    35  ) ([]search.Result, error) {
    36  	groups := groups{logger: g.logger}
    37  
    38  	for _, current := range in {
    39  		pos, ok := groups.hasMatch(current.Vector, force)
    40  		if !ok {
    41  			groups.new(current)
    42  		} else {
    43  			groups.Elements[pos].add(current)
    44  		}
    45  	}
    46  
    47  	return groups.flatten(strategy)
    48  }
    49  
    50  type group struct {
    51  	Elements []search.Result `json:"elements"`
    52  }
    53  
    54  func (g *group) add(item search.Result) {
    55  	g.Elements = append(g.Elements, item)
    56  }
    57  
    58  func (g group) matches(vector []float32, force float32) bool {
    59  	// iterate over all group Elements and consider it a match if any matches
    60  	for _, elem := range g.Elements {
    61  		dist, err := vectorizer.NormalizedDistance(vector, elem.Vector)
    62  		if err != nil {
    63  			// TODO: log error
    64  			// we don't expect to ever see this error, so we don't need to handle it
    65  			// explicitly, however, let's still log it in case that the above
    66  			// assumption is wrong
    67  			continue
    68  		}
    69  
    70  		if dist < force {
    71  			return true
    72  		}
    73  	}
    74  
    75  	return false
    76  }
    77  
    78  type groups struct {
    79  	Elements []group `json:"elements"`
    80  	logger   logrus.FieldLogger
    81  }
    82  
    83  func (gs groups) hasMatch(vector []float32, force float32) (int, bool) {
    84  	for pos, group := range gs.Elements {
    85  		if group.matches(vector, force) {
    86  			return pos, true
    87  		}
    88  	}
    89  	return -1, false
    90  }
    91  
    92  func (gs *groups) new(item search.Result) {
    93  	gs.Elements = append(gs.Elements, group{Elements: []search.Result{item}})
    94  }
    95  
    96  func (gs groups) flatten(strategy string) (out []search.Result, err error) {
    97  	gs.logger.WithField("object", "grouping_before_flatten").
    98  		WithField("strategy", strategy).
    99  		WithField("groups", gs.Elements).
   100  		Debug("group before flattening")
   101  
   102  	switch strategy {
   103  	case "closest":
   104  		out, err = gs.flattenClosest()
   105  	case "merge":
   106  		out, err = gs.flattenMerge()
   107  	default:
   108  		return nil, fmt.Errorf("unrecognized grouping strategy '%s'", strategy)
   109  	}
   110  	if err != nil {
   111  		return
   112  	}
   113  
   114  	gs.logger.WithField("object", "grouping_after_flatten").
   115  		WithField("strategy", strategy).
   116  		WithField("groups", gs.Elements).
   117  		Debug("group after flattening")
   118  
   119  	return out, nil
   120  }
   121  
   122  func (gs groups) flattenClosest() ([]search.Result, error) {
   123  	out := make([]search.Result, len(gs.Elements))
   124  	for i, group := range gs.Elements {
   125  		out[i] = group.Elements[0] // hard-code "closest" strategy for now
   126  	}
   127  
   128  	return out, nil
   129  }
   130  
   131  func (gs groups) flattenMerge() ([]search.Result, error) {
   132  	out := make([]search.Result, len(gs.Elements))
   133  	for i, group := range gs.Elements {
   134  		merged, err := group.flattenMerge()
   135  		if err != nil {
   136  			return nil, fmt.Errorf("group %d: %v", i, err)
   137  		}
   138  
   139  		out[i] = merged
   140  	}
   141  
   142  	return out, nil
   143  }