github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/classification.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package db
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"math"
    18  
    19  	"github.com/go-openapi/strfmt"
    20  	"github.com/pkg/errors"
    21  	"github.com/weaviate/weaviate/entities/additional"
    22  	"github.com/weaviate/weaviate/entities/dto"
    23  	"github.com/weaviate/weaviate/entities/filters"
    24  	libfilters "github.com/weaviate/weaviate/entities/filters"
    25  	"github.com/weaviate/weaviate/entities/models"
    26  	"github.com/weaviate/weaviate/entities/schema"
    27  	"github.com/weaviate/weaviate/entities/search"
    28  	"github.com/weaviate/weaviate/usecases/classification"
    29  	"github.com/weaviate/weaviate/usecases/vectorizer"
    30  )
    31  
    32  // TODO: why is this logic in the persistence package? This is business-logic,
    33  // move out of here!
    34  func (db *DB) GetUnclassified(ctx context.Context, class string,
    35  	properties []string, filter *libfilters.LocalFilter,
    36  ) ([]search.Result, error) {
    37  	mergedFilter := mergeUserFilterWithRefCountFilter(filter, class, properties,
    38  		libfilters.OperatorEqual, 0)
    39  	res, err := db.Search(ctx, dto.GetParams{
    40  		ClassName: class,
    41  		Filters:   mergedFilter,
    42  		Pagination: &libfilters.Pagination{
    43  			Limit: 10000, // TODO: gh-1219 increase
    44  		},
    45  		AdditionalProperties: additional.Properties{
    46  			Classification: true,
    47  			Vector:         true,
    48  			ModuleParams: map[string]interface{}{
    49  				"interpretation": true,
    50  			},
    51  		},
    52  	})
    53  
    54  	return res, err
    55  }
    56  
    57  // TODO: why is this logic in the persistence package? This is business-logic,
    58  // move out of here!
    59  func (db *DB) ZeroShotSearch(ctx context.Context, vector []float32,
    60  	class string, properties []string,
    61  	filter *libfilters.LocalFilter,
    62  ) ([]search.Result, error) {
    63  	res, err := db.VectorSearch(ctx, dto.GetParams{
    64  		ClassName:    class,
    65  		SearchVector: vector,
    66  		Pagination: &filters.Pagination{
    67  			Limit: 1,
    68  		},
    69  		Filters: filter,
    70  		AdditionalProperties: additional.Properties{
    71  			Vector: true,
    72  		},
    73  	})
    74  
    75  	return res, err
    76  }
    77  
    78  // TODO: why is this logic in the persistence package? This is business-logic,
    79  // move out of here!
    80  func (db *DB) AggregateNeighbors(ctx context.Context, vector []float32,
    81  	class string, properties []string, k int,
    82  	filter *libfilters.LocalFilter,
    83  ) ([]classification.NeighborRef, error) {
    84  	mergedFilter := mergeUserFilterWithRefCountFilter(filter, class, properties,
    85  		libfilters.OperatorGreaterThan, 0)
    86  	res, err := db.VectorSearch(ctx, dto.GetParams{
    87  		ClassName:    class,
    88  		SearchVector: vector,
    89  		Pagination: &filters.Pagination{
    90  			Limit: k,
    91  		},
    92  		Filters: mergedFilter,
    93  		AdditionalProperties: additional.Properties{
    94  			Vector: true,
    95  		},
    96  	})
    97  	if err != nil {
    98  		return nil, errors.Wrap(err, "aggregate neighbors: search neighbors")
    99  	}
   100  
   101  	return NewKnnAggregator(res, vector).Aggregate(k, properties)
   102  }
   103  
   104  // TODO: this is business logic, move out of here
   105  type KnnAggregator struct {
   106  	input        search.Results
   107  	sourceVector []float32
   108  }
   109  
   110  func NewKnnAggregator(input search.Results, sourceVector []float32) *KnnAggregator {
   111  	return &KnnAggregator{input: input, sourceVector: sourceVector}
   112  }
   113  
   114  func (a *KnnAggregator) Aggregate(k int, properties []string) ([]classification.NeighborRef, error) {
   115  	neighbors, err := a.extractBeacons(properties)
   116  	if err != nil {
   117  		return nil, errors.Wrap(err, "aggregate: extract beacons from neighbors")
   118  	}
   119  
   120  	return a.aggregateBeacons(neighbors)
   121  }
   122  
   123  func (a *KnnAggregator) extractBeacons(properties []string) (neighborProps, error) {
   124  	neighbors := neighborProps{}
   125  	for i, elem := range a.input {
   126  		schemaMap, ok := elem.Schema.(map[string]interface{})
   127  		if !ok {
   128  			return nil, fmt.Errorf("expecteded element[%d].Schema to be map, got: %T", i, elem.Schema)
   129  		}
   130  
   131  		for _, prop := range properties {
   132  			refProp, ok := schemaMap[prop]
   133  			if !ok {
   134  				return nil, fmt.Errorf("expecteded element[%d].Schema to have property %q, but didn't", i, prop)
   135  			}
   136  
   137  			refTyped, ok := refProp.(models.MultipleRef)
   138  			if !ok {
   139  				return nil, fmt.Errorf("expecteded element[%d].Schema.%s to be models.MultipleRef, got: %T", i, prop, refProp)
   140  			}
   141  
   142  			if len(refTyped) != 1 {
   143  				return nil, fmt.Errorf("a knn training data object needs to have exactly one label: "+
   144  					"expecteded element[%d].Schema.%s to have exactly one reference, got: %d",
   145  					i, prop, len(refTyped))
   146  			}
   147  
   148  			distance, err := vectorizer.NormalizedDistance(a.sourceVector, elem.Vector)
   149  			if err != nil {
   150  				return nil, errors.Wrap(err, "calculate distance between source and candidate")
   151  			}
   152  
   153  			beacon := refTyped[0].Beacon.String()
   154  			neighborProp := neighbors[prop]
   155  			if neighborProp.beacons == nil {
   156  				neighborProp.beacons = neighborBeacons{}
   157  			}
   158  			neighborProp.beacons[beacon] = append(neighborProp.beacons[beacon], distance)
   159  			neighbors[prop] = neighborProp
   160  		}
   161  	}
   162  
   163  	return neighbors, nil
   164  }
   165  
   166  func (a *KnnAggregator) aggregateBeacons(props neighborProps) ([]classification.NeighborRef, error) {
   167  	var out []classification.NeighborRef
   168  	for propName, prop := range props {
   169  		var winningBeacon string
   170  		var winningCount int
   171  		var totalCount int
   172  
   173  		for beacon, distances := range prop.beacons {
   174  			totalCount += len(distances)
   175  			if len(distances) > winningCount {
   176  				winningBeacon = beacon
   177  				winningCount = len(distances)
   178  			}
   179  		}
   180  
   181  		distances := a.distances(prop.beacons, winningBeacon)
   182  		out = append(out, classification.NeighborRef{
   183  			Beacon:       strfmt.URI(winningBeacon),
   184  			WinningCount: winningCount,
   185  			OverallCount: totalCount,
   186  			LosingCount:  totalCount - winningCount,
   187  			Property:     propName,
   188  			Distances:    distances,
   189  		})
   190  	}
   191  
   192  	return out, nil
   193  }
   194  
   195  func (a *KnnAggregator) distances(beacons neighborBeacons,
   196  	winner string,
   197  ) classification.NeighborRefDistances {
   198  	out := classification.NeighborRefDistances{}
   199  
   200  	var winningDistances []float32
   201  	var losingDistances []float32
   202  
   203  	for beacon, distances := range beacons {
   204  		if beacon == winner {
   205  			winningDistances = distances
   206  		} else {
   207  			losingDistances = append(losingDistances, distances...)
   208  		}
   209  	}
   210  
   211  	if len(losingDistances) > 0 {
   212  		mean := mean(losingDistances)
   213  		out.MeanLosingDistance = &mean
   214  
   215  		closest := min(losingDistances)
   216  		out.ClosestLosingDistance = &closest
   217  	}
   218  
   219  	out.ClosestOverallDistance = min(append(winningDistances, losingDistances...))
   220  	out.ClosestWinningDistance = min(winningDistances)
   221  	out.MeanWinningDistance = mean(winningDistances)
   222  
   223  	return out
   224  }
   225  
   226  type neighborProps map[string]neighborProp
   227  
   228  type neighborProp struct {
   229  	beacons neighborBeacons
   230  }
   231  
   232  type neighborBeacons map[string][]float32
   233  
   234  func mergeUserFilterWithRefCountFilter(userFilter *libfilters.LocalFilter, className string,
   235  	properties []string, op libfilters.Operator, refCount int,
   236  ) *libfilters.LocalFilter {
   237  	countFilters := make([]libfilters.Clause, len(properties))
   238  	for i, prop := range properties {
   239  		countFilters[i] = libfilters.Clause{
   240  			Operator: op,
   241  			Value: &libfilters.Value{
   242  				Type:  schema.DataTypeInt,
   243  				Value: refCount,
   244  			},
   245  			On: &libfilters.Path{
   246  				Class:    schema.ClassName(className),
   247  				Property: schema.PropertyName(prop),
   248  			},
   249  		}
   250  	}
   251  
   252  	var countRootClause libfilters.Clause
   253  	if len(countFilters) == 1 {
   254  		countRootClause = countFilters[0]
   255  	} else {
   256  		countRootClause = libfilters.Clause{
   257  			Operands: countFilters,
   258  			Operator: libfilters.OperatorAnd,
   259  		}
   260  	}
   261  
   262  	rootFilter := &libfilters.LocalFilter{}
   263  	if userFilter == nil {
   264  		rootFilter.Root = &countRootClause
   265  	} else {
   266  		rootFilter.Root = &libfilters.Clause{
   267  			Operator: libfilters.OperatorAnd, // so we can AND the refcount requirements and whatever custom filters, the user has
   268  			Operands: []libfilters.Clause{*userFilter.Root, countRootClause},
   269  		}
   270  	}
   271  
   272  	return rootFilter
   273  }
   274  
   275  func mean(in []float32) float32 {
   276  	sum := float32(0)
   277  	for _, v := range in {
   278  		sum += v
   279  	}
   280  
   281  	return sum / float32(len(in))
   282  }
   283  
   284  func min(in []float32) float32 {
   285  	min := float32(math.MaxFloat32)
   286  	for _, dist := range in {
   287  		if dist < min {
   288  			min = dist
   289  		}
   290  	}
   291  
   292  	return min
   293  }