github.com/weaviate/weaviate@v1.24.6/test/helper/journey/group_by_journey_tests.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package journey
    13  
    14  import (
    15  	"fmt"
    16  	"testing"
    17  
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  	"github.com/weaviate/weaviate/test/helper"
    21  	graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql"
    22  	"github.com/weaviate/weaviate/test/helper/sample-schema/documents"
    23  )
    24  
    25  func GroupBySingleAndMultiShardTests(t *testing.T, weaviateEndpoint string) {
    26  	if weaviateEndpoint != "" {
    27  		helper.SetupClient(weaviateEndpoint)
    28  	}
    29  	// helper methods
    30  	getGroup := func(value interface{}) map[string]interface{} {
    31  		group := value.(map[string]interface{})["_additional"].(map[string]interface{})["group"].(map[string]interface{})
    32  		return group
    33  	}
    34  	getGroupHits := func(group map[string]interface{}) (string, []string) {
    35  		result := []string{}
    36  		hits := group["hits"].([]interface{})
    37  		for _, hit := range hits {
    38  			additional := hit.(map[string]interface{})["_additional"].(map[string]interface{})
    39  			result = append(result, additional["id"].(string))
    40  		}
    41  		groupedBy := group["groupedBy"].(map[string]interface{})
    42  		groupedByValue := groupedBy["value"].(string)
    43  		return groupedByValue, result
    44  	}
    45  	// test methods
    46  	create := func(t *testing.T, multishard bool) {
    47  		for _, class := range documents.ClassesContextionaryVectorizer(multishard) {
    48  			helper.CreateClass(t, class)
    49  		}
    50  		for _, obj := range documents.Objects() {
    51  			helper.CreateObject(t, obj)
    52  			helper.AssertGetObjectEventually(t, obj.Class, obj.ID)
    53  		}
    54  	}
    55  	groupBy := func(t *testing.T, groupsCount, objectsPerGroup int) {
    56  		query := `
    57  		{
    58  			Get{
    59  				Passage(
    60  					nearObject:{
    61  						id: "00000000-0000-0000-0000-000000000001"
    62  					}
    63  					groupBy:{
    64  						path:["ofDocument"]
    65  						groups:%v
    66  						objectsPerGroup:%v
    67  					}
    68  				){
    69  					_additional{
    70  						id
    71  						group{
    72  							groupedBy{value}
    73  							count
    74  							maxDistance
    75  							minDistance
    76  							hits {
    77  								_additional{
    78  									id
    79  									distance
    80  								}
    81  							}
    82  						}
    83  					}
    84  				}
    85  			}
    86  		}
    87  		`
    88  		result := graphqlhelper.AssertGraphQL(t, helper.RootAuth, fmt.Sprintf(query, groupsCount, objectsPerGroup))
    89  		groups := result.Get("Get", "Passage").AsSlice()
    90  
    91  		require.Len(t, groups, groupsCount)
    92  
    93  		expectedResults := map[string][]string{}
    94  
    95  		groupedBy1 := `weaviate://localhost/Document/00000000-0000-0000-0000-000000000011`
    96  		expectedGroup1 := []string{
    97  			documents.PassageIDs[0].String(),
    98  			documents.PassageIDs[5].String(),
    99  			documents.PassageIDs[4].String(),
   100  			documents.PassageIDs[3].String(),
   101  			documents.PassageIDs[2].String(),
   102  			documents.PassageIDs[1].String(),
   103  		}
   104  		expectedResults[groupedBy1] = expectedGroup1
   105  
   106  		groupedBy2 := `weaviate://localhost/Document/00000000-0000-0000-0000-000000000012`
   107  		expectedGroup2 := []string{
   108  			documents.PassageIDs[6].String(),
   109  			documents.PassageIDs[7].String(),
   110  		}
   111  		expectedResults[groupedBy2] = expectedGroup2
   112  
   113  		groupsOrder := []string{groupedBy1, groupedBy2}
   114  
   115  		for i, current := range groups {
   116  			group := getGroup(current)
   117  			groupedBy, ids := getGroupHits(group)
   118  			assert.Equal(t, groupsOrder[i], groupedBy)
   119  			for j := range ids {
   120  				assert.Equal(t, expectedResults[groupedBy][j], ids[j])
   121  			}
   122  		}
   123  	}
   124  	delete := func(t *testing.T) {
   125  		helper.DeleteClass(t, documents.Passage)
   126  		helper.DeleteClass(t, documents.Document)
   127  	}
   128  	// tests
   129  	tests := []struct {
   130  		name                    string
   131  		multishard              bool
   132  		groups, objectsPerGroup int
   133  	}{
   134  		{
   135  			name:            "single shard - 2 groups 10 objects per group",
   136  			multishard:      false,
   137  			groups:          2,
   138  			objectsPerGroup: 10,
   139  		},
   140  		{
   141  			name:            "multi shard -  2 groups 10 objects per group",
   142  			multishard:      true,
   143  			groups:          2,
   144  			objectsPerGroup: 10,
   145  		},
   146  		{
   147  			name:            "single shard -  1 groups 1 objects per group",
   148  			multishard:      false,
   149  			groups:          1,
   150  			objectsPerGroup: 1,
   151  		},
   152  		{
   153  			name:            "multi shard -  1 groups 1 objects per group",
   154  			multishard:      true,
   155  			groups:          1,
   156  			objectsPerGroup: 1,
   157  		},
   158  	}
   159  	for _, tt := range tests {
   160  		t.Run(tt.name, func(t *testing.T) {
   161  			t.Run("create", func(t *testing.T) {
   162  				create(t, tt.multishard)
   163  			})
   164  			t.Run("group by", func(t *testing.T) {
   165  				groupBy(t, tt.groups, tt.objectsPerGroup)
   166  			})
   167  			t.Run("delete", func(t *testing.T) {
   168  				delete(t)
   169  			})
   170  		})
   171  	}
   172  }