github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/classification.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package db 13 14 import ( 15 "context" 16 "fmt" 17 "math" 18 19 "github.com/go-openapi/strfmt" 20 "github.com/pkg/errors" 21 "github.com/weaviate/weaviate/entities/additional" 22 "github.com/weaviate/weaviate/entities/dto" 23 "github.com/weaviate/weaviate/entities/filters" 24 libfilters "github.com/weaviate/weaviate/entities/filters" 25 "github.com/weaviate/weaviate/entities/models" 26 "github.com/weaviate/weaviate/entities/schema" 27 "github.com/weaviate/weaviate/entities/search" 28 "github.com/weaviate/weaviate/usecases/classification" 29 "github.com/weaviate/weaviate/usecases/vectorizer" 30 ) 31 32 // TODO: why is this logic in the persistence package? This is business-logic, 33 // move out of here! 34 func (db *DB) GetUnclassified(ctx context.Context, class string, 35 properties []string, filter *libfilters.LocalFilter, 36 ) ([]search.Result, error) { 37 mergedFilter := mergeUserFilterWithRefCountFilter(filter, class, properties, 38 libfilters.OperatorEqual, 0) 39 res, err := db.Search(ctx, dto.GetParams{ 40 ClassName: class, 41 Filters: mergedFilter, 42 Pagination: &libfilters.Pagination{ 43 Limit: 10000, // TODO: gh-1219 increase 44 }, 45 AdditionalProperties: additional.Properties{ 46 Classification: true, 47 Vector: true, 48 ModuleParams: map[string]interface{}{ 49 "interpretation": true, 50 }, 51 }, 52 }) 53 54 return res, err 55 } 56 57 // TODO: why is this logic in the persistence package? This is business-logic, 58 // move out of here! 59 func (db *DB) ZeroShotSearch(ctx context.Context, vector []float32, 60 class string, properties []string, 61 filter *libfilters.LocalFilter, 62 ) ([]search.Result, error) { 63 res, err := db.VectorSearch(ctx, dto.GetParams{ 64 ClassName: class, 65 SearchVector: vector, 66 Pagination: &filters.Pagination{ 67 Limit: 1, 68 }, 69 Filters: filter, 70 AdditionalProperties: additional.Properties{ 71 Vector: true, 72 }, 73 }) 74 75 return res, err 76 } 77 78 // TODO: why is this logic in the persistence package? This is business-logic, 79 // move out of here! 80 func (db *DB) AggregateNeighbors(ctx context.Context, vector []float32, 81 class string, properties []string, k int, 82 filter *libfilters.LocalFilter, 83 ) ([]classification.NeighborRef, error) { 84 mergedFilter := mergeUserFilterWithRefCountFilter(filter, class, properties, 85 libfilters.OperatorGreaterThan, 0) 86 res, err := db.VectorSearch(ctx, dto.GetParams{ 87 ClassName: class, 88 SearchVector: vector, 89 Pagination: &filters.Pagination{ 90 Limit: k, 91 }, 92 Filters: mergedFilter, 93 AdditionalProperties: additional.Properties{ 94 Vector: true, 95 }, 96 }) 97 if err != nil { 98 return nil, errors.Wrap(err, "aggregate neighbors: search neighbors") 99 } 100 101 return NewKnnAggregator(res, vector).Aggregate(k, properties) 102 } 103 104 // TODO: this is business logic, move out of here 105 type KnnAggregator struct { 106 input search.Results 107 sourceVector []float32 108 } 109 110 func NewKnnAggregator(input search.Results, sourceVector []float32) *KnnAggregator { 111 return &KnnAggregator{input: input, sourceVector: sourceVector} 112 } 113 114 func (a *KnnAggregator) Aggregate(k int, properties []string) ([]classification.NeighborRef, error) { 115 neighbors, err := a.extractBeacons(properties) 116 if err != nil { 117 return nil, errors.Wrap(err, "aggregate: extract beacons from neighbors") 118 } 119 120 return a.aggregateBeacons(neighbors) 121 } 122 123 func (a *KnnAggregator) extractBeacons(properties []string) (neighborProps, error) { 124 neighbors := neighborProps{} 125 for i, elem := range a.input { 126 schemaMap, ok := elem.Schema.(map[string]interface{}) 127 if !ok { 128 return nil, fmt.Errorf("expecteded element[%d].Schema to be map, got: %T", i, elem.Schema) 129 } 130 131 for _, prop := range properties { 132 refProp, ok := schemaMap[prop] 133 if !ok { 134 return nil, fmt.Errorf("expecteded element[%d].Schema to have property %q, but didn't", i, prop) 135 } 136 137 refTyped, ok := refProp.(models.MultipleRef) 138 if !ok { 139 return nil, fmt.Errorf("expecteded element[%d].Schema.%s to be models.MultipleRef, got: %T", i, prop, refProp) 140 } 141 142 if len(refTyped) != 1 { 143 return nil, fmt.Errorf("a knn training data object needs to have exactly one label: "+ 144 "expecteded element[%d].Schema.%s to have exactly one reference, got: %d", 145 i, prop, len(refTyped)) 146 } 147 148 distance, err := vectorizer.NormalizedDistance(a.sourceVector, elem.Vector) 149 if err != nil { 150 return nil, errors.Wrap(err, "calculate distance between source and candidate") 151 } 152 153 beacon := refTyped[0].Beacon.String() 154 neighborProp := neighbors[prop] 155 if neighborProp.beacons == nil { 156 neighborProp.beacons = neighborBeacons{} 157 } 158 neighborProp.beacons[beacon] = append(neighborProp.beacons[beacon], distance) 159 neighbors[prop] = neighborProp 160 } 161 } 162 163 return neighbors, nil 164 } 165 166 func (a *KnnAggregator) aggregateBeacons(props neighborProps) ([]classification.NeighborRef, error) { 167 var out []classification.NeighborRef 168 for propName, prop := range props { 169 var winningBeacon string 170 var winningCount int 171 var totalCount int 172 173 for beacon, distances := range prop.beacons { 174 totalCount += len(distances) 175 if len(distances) > winningCount { 176 winningBeacon = beacon 177 winningCount = len(distances) 178 } 179 } 180 181 distances := a.distances(prop.beacons, winningBeacon) 182 out = append(out, classification.NeighborRef{ 183 Beacon: strfmt.URI(winningBeacon), 184 WinningCount: winningCount, 185 OverallCount: totalCount, 186 LosingCount: totalCount - winningCount, 187 Property: propName, 188 Distances: distances, 189 }) 190 } 191 192 return out, nil 193 } 194 195 func (a *KnnAggregator) distances(beacons neighborBeacons, 196 winner string, 197 ) classification.NeighborRefDistances { 198 out := classification.NeighborRefDistances{} 199 200 var winningDistances []float32 201 var losingDistances []float32 202 203 for beacon, distances := range beacons { 204 if beacon == winner { 205 winningDistances = distances 206 } else { 207 losingDistances = append(losingDistances, distances...) 208 } 209 } 210 211 if len(losingDistances) > 0 { 212 mean := mean(losingDistances) 213 out.MeanLosingDistance = &mean 214 215 closest := min(losingDistances) 216 out.ClosestLosingDistance = &closest 217 } 218 219 out.ClosestOverallDistance = min(append(winningDistances, losingDistances...)) 220 out.ClosestWinningDistance = min(winningDistances) 221 out.MeanWinningDistance = mean(winningDistances) 222 223 return out 224 } 225 226 type neighborProps map[string]neighborProp 227 228 type neighborProp struct { 229 beacons neighborBeacons 230 } 231 232 type neighborBeacons map[string][]float32 233 234 func mergeUserFilterWithRefCountFilter(userFilter *libfilters.LocalFilter, className string, 235 properties []string, op libfilters.Operator, refCount int, 236 ) *libfilters.LocalFilter { 237 countFilters := make([]libfilters.Clause, len(properties)) 238 for i, prop := range properties { 239 countFilters[i] = libfilters.Clause{ 240 Operator: op, 241 Value: &libfilters.Value{ 242 Type: schema.DataTypeInt, 243 Value: refCount, 244 }, 245 On: &libfilters.Path{ 246 Class: schema.ClassName(className), 247 Property: schema.PropertyName(prop), 248 }, 249 } 250 } 251 252 var countRootClause libfilters.Clause 253 if len(countFilters) == 1 { 254 countRootClause = countFilters[0] 255 } else { 256 countRootClause = libfilters.Clause{ 257 Operands: countFilters, 258 Operator: libfilters.OperatorAnd, 259 } 260 } 261 262 rootFilter := &libfilters.LocalFilter{} 263 if userFilter == nil { 264 rootFilter.Root = &countRootClause 265 } else { 266 rootFilter.Root = &libfilters.Clause{ 267 Operator: libfilters.OperatorAnd, // so we can AND the refcount requirements and whatever custom filters, the user has 268 Operands: []libfilters.Clause{*userFilter.Root, countRootClause}, 269 } 270 } 271 272 return rootFilter 273 } 274 275 func mean(in []float32) float32 { 276 sum := float32(0) 277 for _, v := range in { 278 sum += v 279 } 280 281 return sum / float32(len(in)) 282 } 283 284 func min(in []float32) float32 { 285 min := float32(math.MaxFloat32) 286 for _, dist := range in { 287 if dist < min { 288 min = dist 289 } 290 } 291 292 return min 293 }