github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/aggregator/shard_combiner.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package aggregator 13 14 import ( 15 "sort" 16 "time" 17 18 "github.com/weaviate/weaviate/entities/aggregation" 19 ) 20 21 type ShardCombiner struct{} 22 23 func NewShardCombiner() *ShardCombiner { 24 return &ShardCombiner{} 25 } 26 27 func (sc *ShardCombiner) Do(results []*aggregation.Result) *aggregation.Result { 28 allResultsAreNil := true 29 firstNonNilRes := 0 30 for i, res := range results { 31 if res == nil || len(res.Groups) < 1 { 32 continue 33 } 34 allResultsAreNil = false 35 firstNonNilRes = i 36 } 37 38 if allResultsAreNil { 39 return &aggregation.Result{} 40 } 41 42 if results[firstNonNilRes].Groups[0].GroupedBy == nil { 43 return sc.combineUngrouped(results) 44 } 45 46 return sc.combineGrouped(results) 47 } 48 49 func (sc *ShardCombiner) combineUngrouped(results []*aggregation.Result) *aggregation.Result { 50 combined := aggregation.Result{ 51 Groups: make([]aggregation.Group, 1), 52 } 53 54 for _, shard := range results { 55 if len(shard.Groups) == 0 { // not every shard has results 56 continue 57 } 58 sc.mergeIntoCombinedGroupAtPos(combined.Groups, 0, shard.Groups[0]) 59 } 60 61 sc.finalizeGroup(&combined.Groups[0]) 62 return &combined 63 } 64 65 func (sc *ShardCombiner) combineGrouped(results []*aggregation.Result) *aggregation.Result { 66 combined := aggregation.Result{} 67 68 for _, shard := range results { 69 for _, shardGroup := range shard.Groups { 70 pos := getPosOfGroup(combined.Groups, shardGroup.GroupedBy.Value) 71 if pos < 0 { 72 combined.Groups = append(combined.Groups, shardGroup) 73 } else { 74 sc.mergeIntoCombinedGroupAtPos(combined.Groups, pos, shardGroup) 75 } 76 } 77 } 78 79 for i := range combined.Groups { 80 sc.finalizeGroup(&combined.Groups[i]) 81 } 82 83 sort.Slice(combined.Groups, func(a, b int) bool { 84 return combined.Groups[a].Count > combined.Groups[b].Count 85 }) 86 return &combined 87 } 88 89 func (sc *ShardCombiner) mergeIntoCombinedGroupAtPos(combinedGroups []aggregation.Group, 90 pos int, shardGroup aggregation.Group, 91 ) { 92 combinedGroups[pos].Count += shardGroup.Count 93 94 for propName, prop := range shardGroup.Properties { 95 if combinedGroups[pos].Properties == nil { 96 combinedGroups[pos].Properties = map[string]aggregation.Property{} 97 } 98 99 combinedProp := combinedGroups[pos].Properties[propName] 100 101 combinedProp.Type = prop.Type 102 103 switch prop.Type { 104 case aggregation.PropertyTypeNumerical: 105 if combinedProp.NumericalAggregations == nil { 106 combinedProp.NumericalAggregations = map[string]interface{}{} 107 } 108 sc.mergeNumericalProp( 109 combinedProp.NumericalAggregations, prop.NumericalAggregations) 110 case aggregation.PropertyTypeDate: 111 if combinedProp.DateAggregations == nil { 112 combinedProp.DateAggregations = map[string]interface{}{} 113 } 114 sc.mergeDateProp( 115 combinedProp.DateAggregations, prop.DateAggregations) 116 case aggregation.PropertyTypeBoolean: 117 sc.mergeBooleanProp( 118 &combinedProp.BooleanAggregation, &prop.BooleanAggregation) 119 case aggregation.PropertyTypeText: 120 sc.mergeTextProp( 121 &combinedProp.TextAggregation, &prop.TextAggregation) 122 case aggregation.PropertyTypeReference: 123 sc.mergeRefProp( 124 &combinedProp.ReferenceAggregation, &prop.ReferenceAggregation) 125 default: 126 panic("unknown prop type: " + prop.Type) 127 } 128 combinedGroups[pos].Properties[propName] = combinedProp 129 130 } 131 } 132 133 func (sc *ShardCombiner) mergeDateProp(first, second map[string]interface{}) { 134 if len(second) == 0 { 135 return 136 } 137 138 // add all values from the second map to the first one. This is needed to compute median and mode correctly 139 for propType := range second { 140 switch propType { 141 case "_dateAggregator": 142 dateAggSource := second[propType].(*dateAggregator) 143 if dateAggCombined, ok := first[propType]; ok { 144 dateAggCombinedTyped := dateAggCombined.(*dateAggregator) 145 for _, pair := range dateAggSource.pairs { 146 for i := uint64(0); i < pair.count; i++ { 147 dateAggCombinedTyped.AddTimestamp(pair.value.rfc3339) 148 } 149 } 150 dateAggCombinedTyped.buildPairsFromCounts() 151 first[propType] = dateAggCombinedTyped 152 153 } else { 154 first[propType] = second[propType] 155 } 156 } 157 } 158 159 for propType, value := range second { 160 switch propType { 161 case "count": 162 if val, ok := first[propType]; ok { 163 first[propType] = val.(int64) + value.(int64) 164 } else { 165 first[propType] = value 166 } 167 case "mode": 168 dateAggCombined := first["_dateAggregator"].(*dateAggregator) 169 first[propType] = dateAggCombined.Mode() 170 case "median": 171 dateAggCombined := first["_dateAggregator"].(*dateAggregator) 172 first[propType] = dateAggCombined.Median() 173 case "minimum": 174 val, ok := first["minimum"] 175 if !ok { 176 first["minimum"] = value 177 } else { 178 source1Time, _ := time.Parse(time.RFC3339, val.(string)) 179 source2Time, _ := time.Parse(time.RFC3339, value.(string)) 180 if source2Time.Before(source1Time) { 181 first["minimum"] = value 182 } 183 } 184 case "maximum": 185 val, ok := first["maximum"] 186 if !ok { 187 first["maximum"] = value 188 } else { 189 source1Time, _ := time.Parse(time.RFC3339, val.(string)) 190 source2Time, _ := time.Parse(time.RFC3339, value.(string)) 191 if source2Time.After(source1Time) { 192 first["maximum"] = value 193 } 194 } 195 case "_dateAggregator": 196 continue 197 default: 198 panic("unknown map entry: " + propType) 199 } 200 } 201 } 202 203 func (sc *ShardCombiner) mergeNumericalProp(first, second map[string]interface{}) { 204 if len(second) == 0 { 205 return 206 } 207 208 // add all values from the second map to the first one. This is needed to compute median, mean and mode correctly 209 for propType := range second { 210 switch propType { 211 case "_numericalAggregator": 212 numAggSecondTyped := second[propType].(*numericalAggregator) 213 if numAggFirst, ok := first[propType]; ok { 214 numAggFirstTyped := numAggFirst.(*numericalAggregator) 215 for _, pair := range numAggSecondTyped.pairs { 216 for i := uint64(0); i < pair.count; i++ { 217 numAggFirstTyped.AddFloat64(pair.value) 218 } 219 } 220 numAggFirstTyped.buildPairsFromCounts() 221 first[propType] = numAggFirstTyped 222 } else { 223 first[propType] = second[propType] 224 } 225 } 226 } 227 228 for propType, value := range second { 229 switch propType { 230 case "count", "sum": 231 if val, ok := first[propType]; ok { 232 first[propType] = val.(float64) + value.(float64) 233 } else { 234 first[propType] = value 235 } 236 case "mode": 237 numAggFirst := first["_numericalAggregator"].(*numericalAggregator) 238 first[propType] = numAggFirst.Mode() 239 case "mean": 240 numAggFirst := first["_numericalAggregator"].(*numericalAggregator) 241 first[propType] = numAggFirst.Mean() 242 case "median": 243 numAggFirst := first["_numericalAggregator"].(*numericalAggregator) 244 first[propType] = numAggFirst.Median() 245 case "minimum": 246 if _, ok := first["minimum"]; !ok || value.(float64) < first["minimum"].(float64) { 247 first["minimum"] = value 248 } 249 case "maximum": 250 if _, ok := first["maximum"]; !ok || value.(float64) > first["maximum"].(float64) { 251 first["maximum"] = value 252 } 253 case "_numericalAggregator": 254 continue 255 default: 256 panic("unknown map entry: " + propType) 257 } 258 } 259 } 260 261 func (sc *ShardCombiner) finalizeDateProp(combined map[string]interface{}) { 262 delete(combined, "_dateAggregator") 263 } 264 265 func (sc *ShardCombiner) finalizeNumerical(combined map[string]interface{}) { 266 delete(combined, "_numericalAggregator") 267 } 268 269 func (sc *ShardCombiner) mergeBooleanProp(combined, source *aggregation.Boolean) { 270 combined.Count += source.Count 271 combined.TotalFalse += source.TotalFalse 272 combined.TotalTrue += source.TotalTrue 273 } 274 275 func (sc *ShardCombiner) finalizeBoolean(combined *aggregation.Boolean) { 276 combined.PercentageFalse = float64(combined.TotalFalse) / float64(combined.Count) 277 combined.PercentageTrue = float64(combined.TotalTrue) / float64(combined.Count) 278 } 279 280 func (sc *ShardCombiner) mergeTextProp(first, second *aggregation.Text) { 281 first.Count += second.Count 282 283 for _, textOcc := range second.Items { 284 pos := getPosOfTextOcc(first.Items, textOcc.Value) 285 if pos < 0 { 286 first.Items = append(first.Items, textOcc) 287 } else { 288 first.Items[pos].Occurs += textOcc.Occurs 289 } 290 } 291 } 292 293 func (sc *ShardCombiner) mergeRefProp(first, second *aggregation.Reference) { 294 first.PointingTo = append(first.PointingTo, second.PointingTo...) 295 } 296 297 func (sc *ShardCombiner) finalizeText(combined *aggregation.Text) { 298 sort.Slice(combined.Items, func(a, b int) bool { 299 return combined.Items[a].Occurs > combined.Items[b].Occurs 300 }) 301 } 302 303 func getPosOfTextOcc(haystack []aggregation.TextOccurrence, needle string) int { 304 for i, elem := range haystack { 305 if elem.Value == needle { 306 return i 307 } 308 } 309 310 return -1 311 } 312 313 func (sc *ShardCombiner) finalizeGroup(group *aggregation.Group) { 314 for propName, prop := range group.Properties { 315 switch prop.Type { 316 case aggregation.PropertyTypeNumerical: 317 sc.finalizeNumerical(prop.NumericalAggregations) 318 case aggregation.PropertyTypeBoolean: 319 sc.finalizeBoolean(&prop.BooleanAggregation) 320 case aggregation.PropertyTypeText: 321 sc.finalizeText(&prop.TextAggregation) 322 case aggregation.PropertyTypeDate: 323 sc.finalizeDateProp(prop.DateAggregations) 324 case aggregation.PropertyTypeReference: 325 continue 326 default: 327 panic("Unknown prop type: " + prop.Type) 328 } 329 group.Properties[propName] = prop 330 } 331 } 332 333 func getPosOfGroup(haystack []aggregation.Group, needle interface{}) int { 334 for i, elem := range haystack { 335 if elem.GroupedBy.Value == needle { 336 return i 337 } 338 } 339 340 return -1 341 }