github.com/weaviate/weaviate@v1.24.6/usecases/traverser/grouper/grouper_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package grouper
    13  
    14  import (
    15  	"testing"
    16  
    17  	"github.com/go-openapi/strfmt"
    18  	"github.com/weaviate/weaviate/entities/schema/crossref"
    19  
    20  	"github.com/sirupsen/logrus/hooks/test"
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/require"
    23  	"github.com/weaviate/weaviate/entities/models"
    24  	"github.com/weaviate/weaviate/entities/search"
    25  )
    26  
    27  func TestGrouper_ModeClosest(t *testing.T) {
    28  	in := []search.Result{
    29  		{
    30  			ClassName: "Foo",
    31  			Vector:    []float32{0.1, 0.1, 0.98},
    32  			Schema: map[string]interface{}{
    33  				"name": "A1",
    34  			},
    35  		},
    36  		{
    37  			ClassName: "Foo",
    38  			Vector:    []float32{0.1, 0.1, 0.96},
    39  			Schema: map[string]interface{}{
    40  				"name": "A2",
    41  			},
    42  		},
    43  		{
    44  			ClassName: "Foo",
    45  			Vector:    []float32{0.1, 0.1, 0.93},
    46  			Schema: map[string]interface{}{
    47  				"name": "A3",
    48  			},
    49  		},
    50  		{
    51  			ClassName: "Foo",
    52  			Vector:    []float32{0.1, 0.98, 0.1},
    53  			Schema: map[string]interface{}{
    54  				"name": "B1",
    55  			},
    56  		},
    57  		{
    58  			ClassName: "Foo",
    59  			Vector:    []float32{0.1, 0.93, 0.1},
    60  			Schema: map[string]interface{}{
    61  				"name": "B2",
    62  			},
    63  		},
    64  		{
    65  			ClassName: "Foo",
    66  			Vector:    []float32{0.1, 0.92, 0.1},
    67  			Schema: map[string]interface{}{
    68  				"name": "B3",
    69  			},
    70  		},
    71  	}
    72  
    73  	expectedOut := []search.Result{
    74  		{
    75  			ClassName: "Foo",
    76  			Vector:    []float32{0.1, 0.1, 0.98},
    77  			Schema: map[string]interface{}{
    78  				"name": "A1",
    79  			},
    80  		},
    81  		{
    82  			ClassName: "Foo",
    83  			Vector:    []float32{0.1, 0.98, 0.1},
    84  			Schema: map[string]interface{}{
    85  				"name": "B1",
    86  			},
    87  		},
    88  	}
    89  
    90  	log, _ := test.NewNullLogger()
    91  	res, err := New(log).Group(in, "closest", 0.2)
    92  	require.Nil(t, err)
    93  	assert.Equal(t, expectedOut, res)
    94  	for i := range res {
    95  		assert.Equal(t, expectedOut[i].ClassName, res[i].ClassName)
    96  	}
    97  }
    98  
    99  func TestGrouper_ModeMerge(t *testing.T) {
   100  	in := []search.Result{
   101  		{
   102  			ClassName: "Foo",
   103  			Vector:    []float32{0.1, 0.1, 0.98},
   104  			Schema: map[string]interface{}{
   105  				"name":    "A1",
   106  				"count":   10.0,
   107  				"illegal": true,
   108  				"location": &models.GeoCoordinates{
   109  					Latitude:  ptFloat32(20),
   110  					Longitude: ptFloat32(20),
   111  				},
   112  				"relatedTo": []interface{}{
   113  					search.LocalRef{
   114  						Class: "Foo",
   115  						Fields: map[string]interface{}{
   116  							"id":  strfmt.UUID("1"),
   117  							"foo": "bar1",
   118  						},
   119  					},
   120  					search.LocalRef{
   121  						Class: "Foo",
   122  						Fields: map[string]interface{}{
   123  							"id":  strfmt.UUID("2"),
   124  							"foo": "bar2",
   125  						},
   126  					},
   127  				},
   128  			},
   129  		},
   130  		{
   131  			ClassName: "Foo",
   132  			Vector:    []float32{0.1, 0.1, 0.96},
   133  			Schema: map[string]interface{}{
   134  				"name":    "A2",
   135  				"count":   11.0,
   136  				"illegal": true,
   137  			},
   138  		},
   139  		{
   140  			ClassName: "Foo",
   141  			Vector:    []float32{0.1, 0.1, 0.96},
   142  			Schema: map[string]interface{}{
   143  				"name":    "A2",
   144  				"count":   11.0,
   145  				"illegal": true,
   146  				"relatedTo": []interface{}{
   147  					search.LocalRef{
   148  						Class: "Foo",
   149  						Fields: map[string]interface{}{
   150  							"id":  strfmt.UUID("3"),
   151  							"foo": "bar3",
   152  						},
   153  					},
   154  				},
   155  			},
   156  		},
   157  		{
   158  			ClassName: "Foo",
   159  			Vector:    []float32{0.1, 0.1, 0.93},
   160  			Schema: map[string]interface{}{
   161  				"name":    "A3",
   162  				"count":   12.0,
   163  				"illegal": false,
   164  				"location": &models.GeoCoordinates{
   165  					Latitude:  ptFloat32(22),
   166  					Longitude: ptFloat32(18),
   167  				},
   168  				"relatedTo": []interface{}{
   169  					search.LocalRef{
   170  						Class: "Foo",
   171  						Fields: map[string]interface{}{
   172  							"id":  strfmt.UUID("2"),
   173  							"foo": "bar2",
   174  						},
   175  					},
   176  				},
   177  			},
   178  		},
   179  		{
   180  			ClassName: "Foo",
   181  			Vector:    []float32{0.1, 0.98, 0.1},
   182  			Schema: map[string]interface{}{
   183  				"name": "B1",
   184  			},
   185  		},
   186  		{
   187  			ClassName: "Foo",
   188  			Vector:    []float32{0.1, 0.93, 0.1},
   189  			Schema: map[string]interface{}{
   190  				"name": "B2",
   191  			},
   192  		},
   193  		{
   194  			ClassName: "Foo",
   195  			Vector:    []float32{0.1, 0.92, 0.1},
   196  			Schema: map[string]interface{}{
   197  				"name": "B3",
   198  			},
   199  		},
   200  	}
   201  
   202  	expectedOut := []search.Result{
   203  		{
   204  			ClassName: "Foo",
   205  			Vector:    []float32{0.1, 0.1, 0.95750004}, // centroid position of all inputs
   206  			Schema: map[string]interface{}{
   207  				"name":    "A1 (A2, A3)", // note that A2 is only contained once, even though its twice in the input set
   208  				"count":   11.0,          // mean of all inputs
   209  				"illegal": true,          // the most common input value, with a bias towards true on equal count
   210  				"location": &models.GeoCoordinates{
   211  					Latitude:  ptFloat32(21),
   212  					Longitude: ptFloat32(19),
   213  				},
   214  				"relatedTo": []interface{}{
   215  					search.LocalRef{
   216  						Class: "Foo",
   217  						Fields: map[string]interface{}{
   218  							"id":  strfmt.UUID("1"),
   219  							"foo": "bar1",
   220  						},
   221  					},
   222  					search.LocalRef{
   223  						Class: "Foo",
   224  						Fields: map[string]interface{}{
   225  							"id":  strfmt.UUID("2"),
   226  							"foo": "bar2",
   227  						},
   228  					},
   229  					search.LocalRef{
   230  						Class: "Foo",
   231  						Fields: map[string]interface{}{
   232  							"id":  strfmt.UUID("3"),
   233  							"foo": "bar3",
   234  						},
   235  					},
   236  				},
   237  			},
   238  		},
   239  		{
   240  			ClassName: "Foo",
   241  			Vector:    []float32{0.1, 0.9433334, 0.1},
   242  			Schema: map[string]interface{}{
   243  				"name": "B1 (B2, B3)",
   244  			},
   245  		},
   246  	}
   247  
   248  	log, _ := test.NewNullLogger()
   249  	res, err := New(log).Group(in, "merge", 0.2)
   250  	require.Nil(t, err)
   251  	assert.Equal(t, expectedOut, res)
   252  	for i := range res {
   253  		assert.Equal(t, expectedOut[i].ClassName, res[i].ClassName)
   254  	}
   255  }
   256  
   257  // Since reference properties can be represented both as models.MultipleRef
   258  // and []interface{}, we need to test for both cases. TestGrouper_ModeMerge
   259  // above tests the case of []interface{}, so this test handles the other case.
   260  // see https://github.com/weaviate/weaviate/pull/2320 for more info
   261  func Test_Grouper_ModeMerge_MultipleRef(t *testing.T) {
   262  	in := []search.Result{
   263  		{
   264  			ClassName: "Foo",
   265  			Vector:    []float32{0.1, 0.1, 0.98},
   266  			Schema: map[string]interface{}{
   267  				"name":    "A1",
   268  				"count":   10.0,
   269  				"illegal": true,
   270  				"location": &models.GeoCoordinates{
   271  					Latitude:  ptFloat32(20),
   272  					Longitude: ptFloat32(20),
   273  				},
   274  				"relatedTo": models.MultipleRef{
   275  					&models.SingleRef{
   276  						Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "3dc4417d-1508-4914-9929-8add49684b9f").String()),
   277  						Class:  "Foo",
   278  					},
   279  					&models.SingleRef{
   280  						Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f1d6df98-33a7-40bb-bcb4-57c1f35d31ab").String()),
   281  						Class:  "Foo",
   282  					},
   283  				},
   284  			},
   285  		},
   286  		{
   287  			ClassName: "Foo",
   288  			Vector:    []float32{0.1, 0.1, 0.96},
   289  			Schema: map[string]interface{}{
   290  				"name":    "A2",
   291  				"count":   11.0,
   292  				"illegal": true,
   293  			},
   294  		},
   295  		{
   296  			ClassName: "Foo",
   297  			Vector:    []float32{0.1, 0.1, 0.96},
   298  			Schema: map[string]interface{}{
   299  				"name":    "A2",
   300  				"count":   11.0,
   301  				"illegal": true,
   302  				"relatedTo": models.MultipleRef{
   303  					&models.SingleRef{
   304  						Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f280a7f7-7fab-46ed-b895-1490512660ae").String()),
   305  						Class:  "Foo",
   306  					},
   307  				},
   308  			},
   309  		},
   310  		{
   311  			ClassName: "Foo",
   312  			Vector:    []float32{0.1, 0.1, 0.93},
   313  			Schema: map[string]interface{}{
   314  				"name":    "A3",
   315  				"count":   12.0,
   316  				"illegal": false,
   317  				"location": &models.GeoCoordinates{
   318  					Latitude:  ptFloat32(22),
   319  					Longitude: ptFloat32(18),
   320  				},
   321  				"relatedTo": models.MultipleRef{
   322  					&models.SingleRef{
   323  						Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f1d6df98-33a7-40bb-bcb4-57c1f35d31ab").String()),
   324  						Class:  "Foo",
   325  					},
   326  				},
   327  			},
   328  		},
   329  		{
   330  			ClassName: "Foo",
   331  			Vector:    []float32{0.1, 0.98, 0.1},
   332  			Schema: map[string]interface{}{
   333  				"name": "B1",
   334  			},
   335  		},
   336  		{
   337  			ClassName: "Foo",
   338  			Vector:    []float32{0.1, 0.93, 0.1},
   339  			Schema: map[string]interface{}{
   340  				"name": "B2",
   341  			},
   342  		},
   343  		{
   344  			ClassName: "Foo",
   345  			Vector:    []float32{0.1, 0.92, 0.1},
   346  			Schema: map[string]interface{}{
   347  				"name": "B3",
   348  			},
   349  		},
   350  	}
   351  
   352  	expectedOut := []search.Result{
   353  		{
   354  			ClassName: "Foo",
   355  			Vector:    []float32{0.1, 0.1, 0.95750004}, // centroid position of all inputs
   356  			Schema: map[string]interface{}{
   357  				"name":    "A1 (A2, A3)", // note that A2 is only contained once, even though its twice in the input set
   358  				"count":   11.0,          // mean of all inputs
   359  				"illegal": true,          // the most common input value, with a bias towards true on equal count
   360  				"location": &models.GeoCoordinates{
   361  					Latitude:  ptFloat32(21),
   362  					Longitude: ptFloat32(19),
   363  				},
   364  				"relatedTo": []interface{}{
   365  					&models.SingleRef{
   366  						Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "3dc4417d-1508-4914-9929-8add49684b9f").String()),
   367  						Class:  "Foo",
   368  					},
   369  					&models.SingleRef{
   370  						Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f1d6df98-33a7-40bb-bcb4-57c1f35d31ab").String()),
   371  						Class:  "Foo",
   372  					},
   373  					&models.SingleRef{
   374  						Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f280a7f7-7fab-46ed-b895-1490512660ae").String()),
   375  						Class:  "Foo",
   376  					},
   377  				},
   378  			},
   379  		},
   380  		{
   381  			ClassName: "Foo",
   382  			Vector:    []float32{0.1, 0.9433334, 0.1},
   383  			Schema: map[string]interface{}{
   384  				"name": "B1 (B2, B3)",
   385  			},
   386  		},
   387  	}
   388  
   389  	log, _ := test.NewNullLogger()
   390  	res, err := New(log).Group(in, "merge", 0.2)
   391  	require.Nil(t, err)
   392  	assert.Equal(t, expectedOut, res)
   393  	for i := range res {
   394  		assert.Equal(t, expectedOut[i].ClassName, res[i].ClassName)
   395  	}
   396  }
   397  
   398  func TestGrouper_ModeMergeFailWithIDTypeOtherThenUUID(t *testing.T) {
   399  	in := []search.Result{
   400  		{
   401  			ClassName: "Foo",
   402  			Vector:    []float32{0.1, 0.1, 0.98},
   403  			Schema: map[string]interface{}{
   404  				"name":    "A1",
   405  				"count":   10.0,
   406  				"illegal": true,
   407  				"location": &models.GeoCoordinates{
   408  					Latitude:  ptFloat32(20),
   409  					Longitude: ptFloat32(20),
   410  				},
   411  				"relatedTo": []interface{}{
   412  					search.LocalRef{
   413  						Class: "Foo",
   414  						Fields: map[string]interface{}{
   415  							"id":  "1",
   416  							"foo": "bar1",
   417  						},
   418  					},
   419  					search.LocalRef{
   420  						Class: "Foo",
   421  						Fields: map[string]interface{}{
   422  							"id":  "2",
   423  							"foo": "bar2",
   424  						},
   425  					},
   426  				},
   427  			},
   428  		},
   429  	}
   430  
   431  	log, _ := test.NewNullLogger()
   432  	res, err := New(log).Group(in, "merge", 0.2)
   433  	require.NotNil(t, err)
   434  	assert.Nil(t, res)
   435  	assert.EqualError(t, err,
   436  		"group 0: merge values: prop 'relatedTo': element 0: "+
   437  			"found a search.LocalRef, 'id' field type expected to be strfmt.UUID but got string")
   438  }