github.com/weaviate/weaviate@v1.24.6/test/acceptance/multi_tenancy/gql_aggregate_tenant_objects_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package test
    13  
    14  import (
    15  	"encoding/json"
    16  	"fmt"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/go-openapi/strfmt"
    21  	"github.com/google/uuid"
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  	"github.com/weaviate/weaviate/client/nodes"
    25  	"github.com/weaviate/weaviate/entities/models"
    26  	"github.com/weaviate/weaviate/entities/schema"
    27  	"github.com/weaviate/weaviate/test/helper"
    28  	graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql"
    29  )
    30  
    31  func TestGQLAggregateTenantObjects(t *testing.T) {
    32  	testClass := models.Class{
    33  		Class: "MultiTenantClass",
    34  		MultiTenancyConfig: &models.MultiTenancyConfig{
    35  			Enabled: true,
    36  		},
    37  		Properties: []*models.Property{
    38  			{
    39  				Name:     "name",
    40  				DataType: schema.DataTypeText.PropString(),
    41  			},
    42  		},
    43  	}
    44  	tenantName1 := "Tenant1"
    45  	tenantName2 := "Tenant2"
    46  	numTenantObjs1 := 5
    47  	numTenantObjs2 := 3
    48  
    49  	defer func() {
    50  		helper.DeleteClass(t, testClass.Class)
    51  	}()
    52  	helper.CreateClass(t, &testClass)
    53  
    54  	tenants := []*models.Tenant{
    55  		{Name: tenantName1},
    56  		{Name: tenantName2},
    57  	}
    58  	helper.CreateTenants(t, testClass.Class, tenants)
    59  
    60  	batch1 := makeTenantBatch(batchParams{
    61  		className:  testClass.Class,
    62  		tenantName: tenantName1,
    63  		batchSize:  numTenantObjs1,
    64  	})
    65  	batch2 := makeTenantBatch(batchParams{
    66  		className:  testClass.Class,
    67  		tenantName: tenantName2,
    68  		batchSize:  numTenantObjs2,
    69  	})
    70  
    71  	helper.CreateObjectsBatch(t, batch1)
    72  	helper.CreateObjectsBatch(t, batch2)
    73  
    74  	t.Run("GQL Aggregate tenant objects", func(t *testing.T) {
    75  		testAggregateTenantSuccess(t, testClass.Class, tenantName1, numTenantObjs1, "")
    76  		testAggregateTenantSuccess(t, testClass.Class, tenantName2, numTenantObjs2, "")
    77  	})
    78  
    79  	t.Run("GQL Aggregate tenant objects near object", func(t *testing.T) {
    80  		testAggregateTenantSuccess(t, testClass.Class, tenantName1, numTenantObjs1, string(batch1[0].ID))
    81  		testAggregateTenantSuccess(t, testClass.Class, tenantName2, numTenantObjs2, string(batch2[0].ID))
    82  	})
    83  
    84  	t.Run("Get global tenant objects count", func(t *testing.T) {
    85  		assert.Eventually(t, func() bool {
    86  			params := nodes.NewNodesGetClassParams().WithClassName(testClass.Class).WithOutput(&verbose)
    87  			resp, err := helper.Client(t).Nodes.NodesGetClass(params, nil)
    88  			require.Nil(t, err)
    89  
    90  			payload := resp.GetPayload()
    91  			require.NotNil(t, payload)
    92  			require.NotNil(t, payload.Nodes)
    93  			require.Len(t, payload.Nodes, 1)
    94  
    95  			node := payload.Nodes[0]
    96  			require.NotNil(t, node)
    97  			assert.Equal(t, models.NodeStatusStatusHEALTHY, *node.Status)
    98  			assert.True(t, len(node.Name) > 0)
    99  			assert.True(t, node.GitHash != "" && node.GitHash != "unknown")
   100  			assert.Len(t, node.Shards, 2)
   101  
   102  			shardCount := map[string]int64{
   103  				tenantName1: int64(numTenantObjs1),
   104  				tenantName2: int64(numTenantObjs2),
   105  			}
   106  
   107  			for _, shard := range node.Shards {
   108  				count, ok := shardCount[shard.Name]
   109  				require.True(t, ok, "expected shard %q to be in %+v",
   110  					shard.Name, []string{tenantName1, tenantName2})
   111  
   112  				assert.Equal(t, testClass.Class, shard.Class)
   113  				if count != shard.ObjectCount {
   114  					return false
   115  				}
   116  			}
   117  
   118  			require.NotNil(t, node.Stats)
   119  			assert.Equal(t, int64(2), node.Stats.ShardCount)
   120  			return int64(numTenantObjs1+numTenantObjs2) == node.Stats.ObjectCount
   121  		}, 15*time.Second, 500*time.Millisecond)
   122  	})
   123  }
   124  
   125  func TestGQLAggregateTenantObjects_InvalidTenant(t *testing.T) {
   126  	testClass := models.Class{
   127  		Class: "MultiTenantClass",
   128  		MultiTenancyConfig: &models.MultiTenancyConfig{
   129  			Enabled: true,
   130  		},
   131  		Properties: []*models.Property{
   132  			{
   133  				Name:     "name",
   134  				DataType: schema.DataTypeText.PropString(),
   135  			},
   136  		},
   137  	}
   138  	tenantName := "Tenant1"
   139  	numTenantObjs := 5
   140  
   141  	defer func() {
   142  		helper.DeleteClass(t, testClass.Class)
   143  	}()
   144  
   145  	t.Run("setup test data", func(t *testing.T) {
   146  		t.Run("create class with multi-tenancy enabled", func(t *testing.T) {
   147  			helper.CreateClass(t, &testClass)
   148  		})
   149  
   150  		t.Run("create tenants", func(t *testing.T) {
   151  			tenants := []*models.Tenant{
   152  				{Name: tenantName},
   153  			}
   154  			helper.CreateTenants(t, testClass.Class, tenants)
   155  		})
   156  
   157  		t.Run("add tenant objects", func(t *testing.T) {
   158  			batch := makeTenantBatch(batchParams{
   159  				className:  testClass.Class,
   160  				tenantName: tenantName,
   161  				batchSize:  numTenantObjs,
   162  			})
   163  			helper.CreateObjectsBatch(t, batch)
   164  		})
   165  	})
   166  
   167  	t.Run("non-existent tenant key", func(t *testing.T) {
   168  		query := fmt.Sprintf(`{Aggregate{%s(tenant:"DNE"){meta{count}}}}`, testClass.Class)
   169  		expected := `"DNE"`
   170  		resp, err := graphqlhelper.QueryGraphQL(t, helper.RootAuth, "", query, nil)
   171  		require.Nil(t, err)
   172  		assert.Nil(t, resp.Data["Aggregate"].(map[string]interface{})[testClass.Class])
   173  		assert.Len(t, resp.Errors, 1)
   174  		assert.Contains(t, resp.Errors[0].Message, expected)
   175  	})
   176  }
   177  
   178  type batchParams struct {
   179  	className  string
   180  	tenantName string
   181  	batchSize  int
   182  }
   183  
   184  func makeTenantBatch(params batchParams) []*models.Object {
   185  	batch := make([]*models.Object, params.batchSize)
   186  	for i := range batch {
   187  		batch[i] = &models.Object{
   188  			ID:    strfmt.UUID(uuid.New().String()),
   189  			Class: params.className,
   190  			Properties: map[string]interface{}{
   191  				"name": params.tenantName,
   192  			},
   193  			Tenant: params.tenantName,
   194  		}
   195  	}
   196  	return batch
   197  }
   198  
   199  func testAggregateTenantSuccess(t *testing.T, className, tenantName string, expectedCount int, nearObjectId string) {
   200  	nearObject := ""
   201  	if nearObjectId != "" {
   202  		nearObject = fmt.Sprintf(`nearObject: {id: "%s", certainty: 0.4},`, nearObjectId)
   203  	}
   204  
   205  	query := fmt.Sprintf(`{Aggregate{%s(%s,tenant:%q){meta{count}}}}`, className, nearObject, tenantName)
   206  	resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query)
   207  	result := resp.Get("Aggregate", className).AsSlice()
   208  	require.Len(t, result, 1)
   209  	count := result[0].(map[string]any)["meta"].(map[string]any)["count"].(json.Number)
   210  	assert.Equal(t, json.Number(fmt.Sprint(expectedCount)), count)
   211  }