github.com/weaviate/weaviate@v1.24.6/test/helper/journey/group_by_journey_tests.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package journey 13 14 import ( 15 "fmt" 16 "testing" 17 18 "github.com/stretchr/testify/assert" 19 "github.com/stretchr/testify/require" 20 "github.com/weaviate/weaviate/test/helper" 21 graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql" 22 "github.com/weaviate/weaviate/test/helper/sample-schema/documents" 23 ) 24 25 func GroupBySingleAndMultiShardTests(t *testing.T, weaviateEndpoint string) { 26 if weaviateEndpoint != "" { 27 helper.SetupClient(weaviateEndpoint) 28 } 29 // helper methods 30 getGroup := func(value interface{}) map[string]interface{} { 31 group := value.(map[string]interface{})["_additional"].(map[string]interface{})["group"].(map[string]interface{}) 32 return group 33 } 34 getGroupHits := func(group map[string]interface{}) (string, []string) { 35 result := []string{} 36 hits := group["hits"].([]interface{}) 37 for _, hit := range hits { 38 additional := hit.(map[string]interface{})["_additional"].(map[string]interface{}) 39 result = append(result, additional["id"].(string)) 40 } 41 groupedBy := group["groupedBy"].(map[string]interface{}) 42 groupedByValue := groupedBy["value"].(string) 43 return groupedByValue, result 44 } 45 // test methods 46 create := func(t *testing.T, multishard bool) { 47 for _, class := range documents.ClassesContextionaryVectorizer(multishard) { 48 helper.CreateClass(t, class) 49 } 50 for _, obj := range documents.Objects() { 51 helper.CreateObject(t, obj) 52 helper.AssertGetObjectEventually(t, obj.Class, obj.ID) 53 } 54 } 55 groupBy := func(t *testing.T, groupsCount, objectsPerGroup int) { 56 query := ` 57 { 58 Get{ 59 Passage( 60 nearObject:{ 61 id: "00000000-0000-0000-0000-000000000001" 62 } 63 groupBy:{ 64 path:["ofDocument"] 65 groups:%v 66 objectsPerGroup:%v 67 } 68 ){ 69 _additional{ 70 id 71 group{ 72 groupedBy{value} 73 count 74 maxDistance 75 minDistance 76 hits { 77 _additional{ 78 id 79 distance 80 } 81 } 82 } 83 } 84 } 85 } 86 } 87 ` 88 result := graphqlhelper.AssertGraphQL(t, helper.RootAuth, fmt.Sprintf(query, groupsCount, objectsPerGroup)) 89 groups := result.Get("Get", "Passage").AsSlice() 90 91 require.Len(t, groups, groupsCount) 92 93 expectedResults := map[string][]string{} 94 95 groupedBy1 := `weaviate://localhost/Document/00000000-0000-0000-0000-000000000011` 96 expectedGroup1 := []string{ 97 documents.PassageIDs[0].String(), 98 documents.PassageIDs[5].String(), 99 documents.PassageIDs[4].String(), 100 documents.PassageIDs[3].String(), 101 documents.PassageIDs[2].String(), 102 documents.PassageIDs[1].String(), 103 } 104 expectedResults[groupedBy1] = expectedGroup1 105 106 groupedBy2 := `weaviate://localhost/Document/00000000-0000-0000-0000-000000000012` 107 expectedGroup2 := []string{ 108 documents.PassageIDs[6].String(), 109 documents.PassageIDs[7].String(), 110 } 111 expectedResults[groupedBy2] = expectedGroup2 112 113 groupsOrder := []string{groupedBy1, groupedBy2} 114 115 for i, current := range groups { 116 group := getGroup(current) 117 groupedBy, ids := getGroupHits(group) 118 assert.Equal(t, groupsOrder[i], groupedBy) 119 for j := range ids { 120 assert.Equal(t, expectedResults[groupedBy][j], ids[j]) 121 } 122 } 123 } 124 delete := func(t *testing.T) { 125 helper.DeleteClass(t, documents.Passage) 126 helper.DeleteClass(t, documents.Document) 127 } 128 // tests 129 tests := []struct { 130 name string 131 multishard bool 132 groups, objectsPerGroup int 133 }{ 134 { 135 name: "single shard - 2 groups 10 objects per group", 136 multishard: false, 137 groups: 2, 138 objectsPerGroup: 10, 139 }, 140 { 141 name: "multi shard - 2 groups 10 objects per group", 142 multishard: true, 143 groups: 2, 144 objectsPerGroup: 10, 145 }, 146 { 147 name: "single shard - 1 groups 1 objects per group", 148 multishard: false, 149 groups: 1, 150 objectsPerGroup: 1, 151 }, 152 { 153 name: "multi shard - 1 groups 1 objects per group", 154 multishard: true, 155 groups: 1, 156 objectsPerGroup: 1, 157 }, 158 } 159 for _, tt := range tests { 160 t.Run(tt.name, func(t *testing.T) { 161 t.Run("create", func(t *testing.T) { 162 create(t, tt.multishard) 163 }) 164 t.Run("group by", func(t *testing.T) { 165 groupBy(t, tt.groups, tt.objectsPerGroup) 166 }) 167 t.Run("delete", func(t *testing.T) { 168 delete(t) 169 }) 170 }) 171 } 172 }