github.com/weaviate/weaviate@v1.24.6/usecases/traverser/traverser_aggregate_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package traverser
    13  
    14  import (
    15  	"context"
    16  	"testing"
    17  
    18  	"github.com/sirupsen/logrus/hooks/test"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  	"github.com/weaviate/weaviate/entities/aggregation"
    22  	"github.com/weaviate/weaviate/entities/models"
    23  	"github.com/weaviate/weaviate/entities/schema"
    24  	"github.com/weaviate/weaviate/entities/searchparams"
    25  	"github.com/weaviate/weaviate/usecases/config"
    26  )
    27  
    28  func Test_Traverser_Aggregate(t *testing.T) {
    29  	principal := &models.Principal{}
    30  	logger, _ := test.NewNullLogger()
    31  	locks := &fakeLocks{}
    32  	authorizer := &fakeAuthorizer{}
    33  	vectorRepo := &fakeVectorRepo{}
    34  	explorer := &fakeExplorer{}
    35  	schemaGetter := &fakeSchemaGetter{aggregateTestSchema}
    36  
    37  	traverser := NewTraverser(&config.WeaviateConfig{}, locks, logger, authorizer,
    38  		vectorRepo, explorer, schemaGetter, nil, nil, -1)
    39  
    40  	t.Run("with aggregation only", func(t *testing.T) {
    41  		params := aggregation.Params{
    42  			ClassName: "MyClass",
    43  			Properties: []aggregation.ParamProperty{
    44  				{
    45  					Name:        "label",
    46  					Aggregators: []aggregation.Aggregator{aggregation.NewTopOccurrencesAggregator(nil)},
    47  				},
    48  				{
    49  					Name:        "number",
    50  					Aggregators: []aggregation.Aggregator{aggregation.SumAggregator},
    51  				},
    52  				{
    53  					Name:        "int",
    54  					Aggregators: []aggregation.Aggregator{aggregation.SumAggregator},
    55  				},
    56  				{
    57  					Name:        "date",
    58  					Aggregators: []aggregation.Aggregator{aggregation.NewTopOccurrencesAggregator(nil)},
    59  				},
    60  			},
    61  		}
    62  
    63  		agg := aggregation.Result{
    64  			Groups: []aggregation.Group{
    65  				{
    66  					Properties: map[string]aggregation.Property{
    67  						"label": {
    68  							TextAggregation: aggregation.Text{
    69  								Items: []aggregation.TextOccurrence{
    70  									{
    71  										Value:  "Foo",
    72  										Occurs: 200,
    73  									},
    74  								},
    75  							},
    76  							Type: aggregation.PropertyTypeText,
    77  						},
    78  						"date": {
    79  							TextAggregation: aggregation.Text{
    80  								Items: []aggregation.TextOccurrence{
    81  									{
    82  										Value:  "Bar",
    83  										Occurs: 100,
    84  									},
    85  								},
    86  							},
    87  							Type: aggregation.PropertyTypeText,
    88  						},
    89  						"number": {
    90  							Type: aggregation.PropertyTypeNumerical,
    91  							NumericalAggregations: map[string]interface{}{
    92  								"sum": 200,
    93  							},
    94  						},
    95  						"int": {
    96  							Type: aggregation.PropertyTypeNumerical,
    97  							NumericalAggregations: map[string]interface{}{
    98  								"sum": 100,
    99  							},
   100  						},
   101  					},
   102  				},
   103  			},
   104  		}
   105  
   106  		vectorRepo.On("Aggregate", params).Return(&agg, nil)
   107  		res, err := traverser.Aggregate(context.Background(), principal, &params)
   108  		require.Nil(t, err)
   109  		assert.Equal(t, &agg, res)
   110  	})
   111  
   112  	t.Run("with a mix of aggregation and type inspection", func(t *testing.T) {
   113  		params := aggregation.Params{
   114  			ClassName: "MyClass",
   115  			Properties: []aggregation.ParamProperty{
   116  				{
   117  					Name: "label",
   118  					Aggregators: []aggregation.Aggregator{
   119  						aggregation.TypeAggregator,
   120  						aggregation.NewTopOccurrencesAggregator(nil),
   121  					},
   122  				},
   123  				{
   124  					Name: "number",
   125  					Aggregators: []aggregation.Aggregator{
   126  						aggregation.TypeAggregator,
   127  						aggregation.SumAggregator,
   128  					},
   129  				},
   130  				{
   131  					Name: "int",
   132  					Aggregators: []aggregation.Aggregator{
   133  						aggregation.TypeAggregator,
   134  						aggregation.SumAggregator,
   135  					},
   136  				},
   137  				{
   138  					Name: "date",
   139  					Aggregators: []aggregation.Aggregator{
   140  						aggregation.TypeAggregator,
   141  						aggregation.NewTopOccurrencesAggregator(nil),
   142  					},
   143  				},
   144  				{
   145  					Name:        "a ref",
   146  					Aggregators: []aggregation.Aggregator{aggregation.TypeAggregator},
   147  				},
   148  			},
   149  		}
   150  
   151  		agg := aggregation.Result{
   152  			Groups: []aggregation.Group{
   153  				{
   154  					Properties: map[string]aggregation.Property{
   155  						"label": {
   156  							TextAggregation: aggregation.Text{
   157  								Items: []aggregation.TextOccurrence{
   158  									{
   159  										Value:  "Foo",
   160  										Occurs: 200,
   161  									},
   162  								},
   163  							},
   164  							Type: aggregation.PropertyTypeText,
   165  						},
   166  						"date": {
   167  							TextAggregation: aggregation.Text{
   168  								Items: []aggregation.TextOccurrence{
   169  									{
   170  										Value:  "Bar",
   171  										Occurs: 100,
   172  									},
   173  								},
   174  							},
   175  							Type: aggregation.PropertyTypeText,
   176  						},
   177  						"number": {
   178  							Type: aggregation.PropertyTypeNumerical,
   179  							NumericalAggregations: map[string]interface{}{
   180  								"sum": 200,
   181  							},
   182  						},
   183  						"int": {
   184  							Type: aggregation.PropertyTypeNumerical,
   185  							NumericalAggregations: map[string]interface{}{
   186  								"sum": 100,
   187  							},
   188  						},
   189  					},
   190  				},
   191  			},
   192  		}
   193  
   194  		expectedResult := aggregation.Result{
   195  			Groups: []aggregation.Group{
   196  				{
   197  					Properties: map[string]aggregation.Property{
   198  						"label": {
   199  							TextAggregation: aggregation.Text{
   200  								Items: []aggregation.TextOccurrence{
   201  									{
   202  										Value:  "Foo",
   203  										Occurs: 200,
   204  									},
   205  								},
   206  							},
   207  							Type:       aggregation.PropertyTypeText,
   208  							SchemaType: string(schema.DataTypeText),
   209  						},
   210  						"date": {
   211  							TextAggregation: aggregation.Text{
   212  								Items: []aggregation.TextOccurrence{
   213  									{
   214  										Value:  "Bar",
   215  										Occurs: 100,
   216  									},
   217  								},
   218  							},
   219  							SchemaType: string(schema.DataTypeDate),
   220  							Type:       aggregation.PropertyTypeText,
   221  						},
   222  						"number": {
   223  							Type:       aggregation.PropertyTypeNumerical,
   224  							SchemaType: string(schema.DataTypeNumber),
   225  							NumericalAggregations: map[string]interface{}{
   226  								"sum": 200,
   227  							},
   228  						},
   229  						"int": {
   230  							Type:       aggregation.PropertyTypeNumerical,
   231  							SchemaType: string(schema.DataTypeInt),
   232  							NumericalAggregations: map[string]interface{}{
   233  								"sum": 100,
   234  							},
   235  						},
   236  						"a ref": {
   237  							Type: aggregation.PropertyTypeReference,
   238  							ReferenceAggregation: aggregation.Reference{
   239  								PointingTo: []string{"AnotherClass"},
   240  							},
   241  							SchemaType: string(schema.DataTypeCRef),
   242  						},
   243  					},
   244  				},
   245  			},
   246  		}
   247  
   248  		vectorRepo.On("Aggregate", params).Return(&agg, nil)
   249  		res, err := traverser.Aggregate(context.Background(), principal, &params)
   250  		require.Nil(t, err)
   251  		assert.Equal(t, &expectedResult, res)
   252  	})
   253  
   254  	t.Run("with hybrid search", func(t *testing.T) {
   255  		params := aggregation.Params{
   256  			ClassName: "MyClass",
   257  			Properties: []aggregation.ParamProperty{
   258  				{
   259  					Name:        "label",
   260  					Aggregators: []aggregation.Aggregator{aggregation.NewTopOccurrencesAggregator(nil)},
   261  				},
   262  				{
   263  					Name:        "number",
   264  					Aggregators: []aggregation.Aggregator{aggregation.SumAggregator},
   265  				},
   266  				{
   267  					Name:        "int",
   268  					Aggregators: []aggregation.Aggregator{aggregation.SumAggregator},
   269  				},
   270  				{
   271  					Name:        "date",
   272  					Aggregators: []aggregation.Aggregator{aggregation.NewTopOccurrencesAggregator(nil)},
   273  				},
   274  			},
   275  			IncludeMetaCount: true,
   276  			Hybrid: &searchparams.HybridSearch{
   277  				Type:   "hybrid",
   278  				Alpha:  0.5,
   279  				Query:  "some query",
   280  				Vector: []float32{1, 2, 3},
   281  			},
   282  		}
   283  
   284  		agg := aggregation.Result{
   285  			Groups: []aggregation.Group{
   286  				{
   287  					Properties: map[string]aggregation.Property{
   288  						"label": {
   289  							TextAggregation: aggregation.Text{
   290  								Items: []aggregation.TextOccurrence{
   291  									{
   292  										Value:  "Foo",
   293  										Occurs: 200,
   294  									},
   295  								},
   296  							},
   297  							Type: aggregation.PropertyTypeText,
   298  						},
   299  						"date": {
   300  							TextAggregation: aggregation.Text{
   301  								Items: []aggregation.TextOccurrence{
   302  									{
   303  										Value:  "Bar",
   304  										Occurs: 100,
   305  									},
   306  								},
   307  							},
   308  							Type: aggregation.PropertyTypeText,
   309  						},
   310  						"number": {
   311  							Type: aggregation.PropertyTypeNumerical,
   312  							NumericalAggregations: map[string]interface{}{
   313  								"sum": 200,
   314  							},
   315  						},
   316  						"int": {
   317  							Type: aggregation.PropertyTypeNumerical,
   318  							NumericalAggregations: map[string]interface{}{
   319  								"sum": 100,
   320  							},
   321  						},
   322  					},
   323  				},
   324  			},
   325  		}
   326  
   327  		vectorRepo.On("Aggregate", params).Return(&agg, nil)
   328  		res, err := traverser.Aggregate(context.Background(), principal, &params)
   329  		require.Nil(t, err)
   330  		assert.Equal(t, &agg, res)
   331  		t.Logf("res: %+v", res)
   332  	})
   333  }
   334  
   335  var aggregateTestSchema = schema.Schema{
   336  	Objects: &models.Schema{
   337  		Classes: []*models.Class{
   338  			{
   339  				Class: "AnotherClass",
   340  			},
   341  			{
   342  				Class: "MyClass",
   343  				Properties: []*models.Property{
   344  					{
   345  						Name:         "label",
   346  						DataType:     schema.DataTypeText.PropString(),
   347  						Tokenization: models.PropertyTokenizationWhitespace,
   348  					},
   349  					{
   350  						Name:     "number",
   351  						DataType: []string{string(schema.DataTypeNumber)},
   352  					},
   353  					{
   354  						Name:     "int",
   355  						DataType: []string{string(schema.DataTypeInt)},
   356  					},
   357  					{
   358  						Name:     "date",
   359  						DataType: []string{string(schema.DataTypeDate)},
   360  					},
   361  					{
   362  						Name:     "a ref",
   363  						DataType: []string{"AnotherClass"},
   364  					},
   365  				},
   366  			},
   367  		},
   368  	},
   369  }