github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/aggregator/shard_combiner.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package aggregator
    13  
    14  import (
    15  	"sort"
    16  	"time"
    17  
    18  	"github.com/weaviate/weaviate/entities/aggregation"
    19  )
    20  
    21  type ShardCombiner struct{}
    22  
    23  func NewShardCombiner() *ShardCombiner {
    24  	return &ShardCombiner{}
    25  }
    26  
    27  func (sc *ShardCombiner) Do(results []*aggregation.Result) *aggregation.Result {
    28  	allResultsAreNil := true
    29  	firstNonNilRes := 0
    30  	for i, res := range results {
    31  		if res == nil || len(res.Groups) < 1 {
    32  			continue
    33  		}
    34  		allResultsAreNil = false
    35  		firstNonNilRes = i
    36  	}
    37  
    38  	if allResultsAreNil {
    39  		return &aggregation.Result{}
    40  	}
    41  
    42  	if results[firstNonNilRes].Groups[0].GroupedBy == nil {
    43  		return sc.combineUngrouped(results)
    44  	}
    45  
    46  	return sc.combineGrouped(results)
    47  }
    48  
    49  func (sc *ShardCombiner) combineUngrouped(results []*aggregation.Result) *aggregation.Result {
    50  	combined := aggregation.Result{
    51  		Groups: make([]aggregation.Group, 1),
    52  	}
    53  
    54  	for _, shard := range results {
    55  		if len(shard.Groups) == 0 { // not every shard has results
    56  			continue
    57  		}
    58  		sc.mergeIntoCombinedGroupAtPos(combined.Groups, 0, shard.Groups[0])
    59  	}
    60  
    61  	sc.finalizeGroup(&combined.Groups[0])
    62  	return &combined
    63  }
    64  
    65  func (sc *ShardCombiner) combineGrouped(results []*aggregation.Result) *aggregation.Result {
    66  	combined := aggregation.Result{}
    67  
    68  	for _, shard := range results {
    69  		for _, shardGroup := range shard.Groups {
    70  			pos := getPosOfGroup(combined.Groups, shardGroup.GroupedBy.Value)
    71  			if pos < 0 {
    72  				combined.Groups = append(combined.Groups, shardGroup)
    73  			} else {
    74  				sc.mergeIntoCombinedGroupAtPos(combined.Groups, pos, shardGroup)
    75  			}
    76  		}
    77  	}
    78  
    79  	for i := range combined.Groups {
    80  		sc.finalizeGroup(&combined.Groups[i])
    81  	}
    82  
    83  	sort.Slice(combined.Groups, func(a, b int) bool {
    84  		return combined.Groups[a].Count > combined.Groups[b].Count
    85  	})
    86  	return &combined
    87  }
    88  
    89  func (sc *ShardCombiner) mergeIntoCombinedGroupAtPos(combinedGroups []aggregation.Group,
    90  	pos int, shardGroup aggregation.Group,
    91  ) {
    92  	combinedGroups[pos].Count += shardGroup.Count
    93  
    94  	for propName, prop := range shardGroup.Properties {
    95  		if combinedGroups[pos].Properties == nil {
    96  			combinedGroups[pos].Properties = map[string]aggregation.Property{}
    97  		}
    98  
    99  		combinedProp := combinedGroups[pos].Properties[propName]
   100  
   101  		combinedProp.Type = prop.Type
   102  
   103  		switch prop.Type {
   104  		case aggregation.PropertyTypeNumerical:
   105  			if combinedProp.NumericalAggregations == nil {
   106  				combinedProp.NumericalAggregations = map[string]interface{}{}
   107  			}
   108  			sc.mergeNumericalProp(
   109  				combinedProp.NumericalAggregations, prop.NumericalAggregations)
   110  		case aggregation.PropertyTypeDate:
   111  			if combinedProp.DateAggregations == nil {
   112  				combinedProp.DateAggregations = map[string]interface{}{}
   113  			}
   114  			sc.mergeDateProp(
   115  				combinedProp.DateAggregations, prop.DateAggregations)
   116  		case aggregation.PropertyTypeBoolean:
   117  			sc.mergeBooleanProp(
   118  				&combinedProp.BooleanAggregation, &prop.BooleanAggregation)
   119  		case aggregation.PropertyTypeText:
   120  			sc.mergeTextProp(
   121  				&combinedProp.TextAggregation, &prop.TextAggregation)
   122  		case aggregation.PropertyTypeReference:
   123  			sc.mergeRefProp(
   124  				&combinedProp.ReferenceAggregation, &prop.ReferenceAggregation)
   125  		default:
   126  			panic("unknown prop type: " + prop.Type)
   127  		}
   128  		combinedGroups[pos].Properties[propName] = combinedProp
   129  
   130  	}
   131  }
   132  
   133  func (sc *ShardCombiner) mergeDateProp(first, second map[string]interface{}) {
   134  	if len(second) == 0 {
   135  		return
   136  	}
   137  
   138  	// add all values from the second map to the first one. This is needed to compute median and mode correctly
   139  	for propType := range second {
   140  		switch propType {
   141  		case "_dateAggregator":
   142  			dateAggSource := second[propType].(*dateAggregator)
   143  			if dateAggCombined, ok := first[propType]; ok {
   144  				dateAggCombinedTyped := dateAggCombined.(*dateAggregator)
   145  				for _, pair := range dateAggSource.pairs {
   146  					for i := uint64(0); i < pair.count; i++ {
   147  						dateAggCombinedTyped.AddTimestamp(pair.value.rfc3339)
   148  					}
   149  				}
   150  				dateAggCombinedTyped.buildPairsFromCounts()
   151  				first[propType] = dateAggCombinedTyped
   152  
   153  			} else {
   154  				first[propType] = second[propType]
   155  			}
   156  		}
   157  	}
   158  
   159  	for propType, value := range second {
   160  		switch propType {
   161  		case "count":
   162  			if val, ok := first[propType]; ok {
   163  				first[propType] = val.(int64) + value.(int64)
   164  			} else {
   165  				first[propType] = value
   166  			}
   167  		case "mode":
   168  			dateAggCombined := first["_dateAggregator"].(*dateAggregator)
   169  			first[propType] = dateAggCombined.Mode()
   170  		case "median":
   171  			dateAggCombined := first["_dateAggregator"].(*dateAggregator)
   172  			first[propType] = dateAggCombined.Median()
   173  		case "minimum":
   174  			val, ok := first["minimum"]
   175  			if !ok {
   176  				first["minimum"] = value
   177  			} else {
   178  				source1Time, _ := time.Parse(time.RFC3339, val.(string))
   179  				source2Time, _ := time.Parse(time.RFC3339, value.(string))
   180  				if source2Time.Before(source1Time) {
   181  					first["minimum"] = value
   182  				}
   183  			}
   184  		case "maximum":
   185  			val, ok := first["maximum"]
   186  			if !ok {
   187  				first["maximum"] = value
   188  			} else {
   189  				source1Time, _ := time.Parse(time.RFC3339, val.(string))
   190  				source2Time, _ := time.Parse(time.RFC3339, value.(string))
   191  				if source2Time.After(source1Time) {
   192  					first["maximum"] = value
   193  				}
   194  			}
   195  		case "_dateAggregator":
   196  			continue
   197  		default:
   198  			panic("unknown map entry: " + propType)
   199  		}
   200  	}
   201  }
   202  
   203  func (sc *ShardCombiner) mergeNumericalProp(first, second map[string]interface{}) {
   204  	if len(second) == 0 {
   205  		return
   206  	}
   207  
   208  	// add all values from the second map to the first one. This is needed to compute median, mean and mode correctly
   209  	for propType := range second {
   210  		switch propType {
   211  		case "_numericalAggregator":
   212  			numAggSecondTyped := second[propType].(*numericalAggregator)
   213  			if numAggFirst, ok := first[propType]; ok {
   214  				numAggFirstTyped := numAggFirst.(*numericalAggregator)
   215  				for _, pair := range numAggSecondTyped.pairs {
   216  					for i := uint64(0); i < pair.count; i++ {
   217  						numAggFirstTyped.AddFloat64(pair.value)
   218  					}
   219  				}
   220  				numAggFirstTyped.buildPairsFromCounts()
   221  				first[propType] = numAggFirstTyped
   222  			} else {
   223  				first[propType] = second[propType]
   224  			}
   225  		}
   226  	}
   227  
   228  	for propType, value := range second {
   229  		switch propType {
   230  		case "count", "sum":
   231  			if val, ok := first[propType]; ok {
   232  				first[propType] = val.(float64) + value.(float64)
   233  			} else {
   234  				first[propType] = value
   235  			}
   236  		case "mode":
   237  			numAggFirst := first["_numericalAggregator"].(*numericalAggregator)
   238  			first[propType] = numAggFirst.Mode()
   239  		case "mean":
   240  			numAggFirst := first["_numericalAggregator"].(*numericalAggregator)
   241  			first[propType] = numAggFirst.Mean()
   242  		case "median":
   243  			numAggFirst := first["_numericalAggregator"].(*numericalAggregator)
   244  			first[propType] = numAggFirst.Median()
   245  		case "minimum":
   246  			if _, ok := first["minimum"]; !ok || value.(float64) < first["minimum"].(float64) {
   247  				first["minimum"] = value
   248  			}
   249  		case "maximum":
   250  			if _, ok := first["maximum"]; !ok || value.(float64) > first["maximum"].(float64) {
   251  				first["maximum"] = value
   252  			}
   253  		case "_numericalAggregator":
   254  			continue
   255  		default:
   256  			panic("unknown map entry: " + propType)
   257  		}
   258  	}
   259  }
   260  
   261  func (sc *ShardCombiner) finalizeDateProp(combined map[string]interface{}) {
   262  	delete(combined, "_dateAggregator")
   263  }
   264  
   265  func (sc *ShardCombiner) finalizeNumerical(combined map[string]interface{}) {
   266  	delete(combined, "_numericalAggregator")
   267  }
   268  
   269  func (sc *ShardCombiner) mergeBooleanProp(combined, source *aggregation.Boolean) {
   270  	combined.Count += source.Count
   271  	combined.TotalFalse += source.TotalFalse
   272  	combined.TotalTrue += source.TotalTrue
   273  }
   274  
   275  func (sc *ShardCombiner) finalizeBoolean(combined *aggregation.Boolean) {
   276  	combined.PercentageFalse = float64(combined.TotalFalse) / float64(combined.Count)
   277  	combined.PercentageTrue = float64(combined.TotalTrue) / float64(combined.Count)
   278  }
   279  
   280  func (sc *ShardCombiner) mergeTextProp(first, second *aggregation.Text) {
   281  	first.Count += second.Count
   282  
   283  	for _, textOcc := range second.Items {
   284  		pos := getPosOfTextOcc(first.Items, textOcc.Value)
   285  		if pos < 0 {
   286  			first.Items = append(first.Items, textOcc)
   287  		} else {
   288  			first.Items[pos].Occurs += textOcc.Occurs
   289  		}
   290  	}
   291  }
   292  
   293  func (sc *ShardCombiner) mergeRefProp(first, second *aggregation.Reference) {
   294  	first.PointingTo = append(first.PointingTo, second.PointingTo...)
   295  }
   296  
   297  func (sc *ShardCombiner) finalizeText(combined *aggregation.Text) {
   298  	sort.Slice(combined.Items, func(a, b int) bool {
   299  		return combined.Items[a].Occurs > combined.Items[b].Occurs
   300  	})
   301  }
   302  
   303  func getPosOfTextOcc(haystack []aggregation.TextOccurrence, needle string) int {
   304  	for i, elem := range haystack {
   305  		if elem.Value == needle {
   306  			return i
   307  		}
   308  	}
   309  
   310  	return -1
   311  }
   312  
   313  func (sc *ShardCombiner) finalizeGroup(group *aggregation.Group) {
   314  	for propName, prop := range group.Properties {
   315  		switch prop.Type {
   316  		case aggregation.PropertyTypeNumerical:
   317  			sc.finalizeNumerical(prop.NumericalAggregations)
   318  		case aggregation.PropertyTypeBoolean:
   319  			sc.finalizeBoolean(&prop.BooleanAggregation)
   320  		case aggregation.PropertyTypeText:
   321  			sc.finalizeText(&prop.TextAggregation)
   322  		case aggregation.PropertyTypeDate:
   323  			sc.finalizeDateProp(prop.DateAggregations)
   324  		case aggregation.PropertyTypeReference:
   325  			continue
   326  		default:
   327  			panic("Unknown prop type: " + prop.Type)
   328  		}
   329  		group.Properties[propName] = prop
   330  	}
   331  }
   332  
   333  func getPosOfGroup(haystack []aggregation.Group, needle interface{}) int {
   334  	for i, elem := range haystack {
   335  		if elem.GroupedBy.Value == needle {
   336  			return i
   337  		}
   338  	}
   339  
   340  	return -1
   341  }