github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/collection_test.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math/rand"
     7  	"testing"
     8  
     9  	"github.com/cockroachdb/errors"
    10  
    11  	"github.com/golang/protobuf/proto"
    12  	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    13  	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    14  	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    15  	"github.com/milvus-io/milvus-sdk-go/v2/entity"
    16  	"github.com/stretchr/testify/assert"
    17  	"github.com/stretchr/testify/mock"
    18  	"github.com/stretchr/testify/suite"
    19  )
    20  
    21  type CollectionSuite struct {
    22  	MockSuiteBase
    23  }
    24  
    25  func (s *CollectionSuite) TestListCollections() {
    26  	c := s.client
    27  	ctx, cancel := context.WithCancel(context.Background())
    28  	defer cancel()
    29  
    30  	type testCase struct {
    31  		ids     []int64
    32  		names   []string
    33  		collNum int
    34  		inMem   []int64
    35  	}
    36  	caseLen := 5
    37  	cases := make([]testCase, 0, caseLen)
    38  	for i := 0; i < caseLen; i++ {
    39  		collNum := rand.Intn(5) + 2
    40  		tc := testCase{
    41  			ids:     make([]int64, 0, collNum),
    42  			names:   make([]string, 0, collNum),
    43  			collNum: collNum,
    44  		}
    45  		base := rand.Intn(1000)
    46  		for j := 0; j < collNum; j++ {
    47  			base += rand.Intn(1000)
    48  			tc.ids = append(tc.ids, int64(base))
    49  			base += rand.Intn(500)
    50  			tc.names = append(tc.names, fmt.Sprintf("coll_%d", base))
    51  			inMem := rand.Intn(100)
    52  			if inMem%2 == 0 {
    53  
    54  				tc.inMem = append(tc.inMem, 100)
    55  			} else {
    56  				tc.inMem = append(tc.inMem, 0)
    57  			}
    58  		}
    59  		cases = append(cases, tc)
    60  	}
    61  
    62  	for i, tc := range cases {
    63  		s.Run(fmt.Sprintf("run_%d", i), func() {
    64  			s.resetMock()
    65  			s.mock.EXPECT().ShowCollections(mock.Anything, mock.AnythingOfType("*milvuspb.ShowCollectionsRequest")).
    66  				Return(&milvuspb.ShowCollectionsResponse{
    67  					Status:              &commonpb.Status{},
    68  					CollectionIds:       tc.ids,
    69  					CollectionNames:     tc.names,
    70  					InMemoryPercentages: tc.inMem,
    71  				}, nil)
    72  
    73  			collections, err := c.ListCollections(ctx)
    74  
    75  			s.Require().Equal(tc.collNum, len(collections))
    76  			s.Require().NoError(err)
    77  
    78  			// assert element match
    79  			rids := make([]int64, 0, len(collections))
    80  			rnames := make([]string, 0, len(collections))
    81  			for _, collection := range collections {
    82  				rids = append(rids, collection.ID)
    83  				rnames = append(rnames, collection.Name)
    84  			}
    85  
    86  			s.ElementsMatch(tc.ids, rids)
    87  			s.ElementsMatch(tc.names, rnames)
    88  			// assert id & name match
    89  			for idx, rid := range rids {
    90  				for jdx, id := range tc.ids {
    91  					if rid == id {
    92  						s.Equal(tc.names[idx], rnames[idx])
    93  						s.Equal(tc.inMem[jdx] == 100, collections[idx].Loaded)
    94  					}
    95  				}
    96  			}
    97  		})
    98  	}
    99  }
   100  
   101  func (s *CollectionSuite) TestCreateCollection() {
   102  	c := s.client
   103  	ctx, cancel := context.WithCancel(context.Background())
   104  	defer cancel()
   105  
   106  	s.Run("normal_creation", func() {
   107  		ds := defaultSchema()
   108  		shardsNum := int32(1)
   109  
   110  		defer s.resetMock()
   111  		s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")).
   112  			Run(func(ctx context.Context, req *milvuspb.CreateCollectionRequest) {
   113  				s.Equal(testCollectionName, req.GetCollectionName())
   114  				sschema := &schemapb.CollectionSchema{}
   115  				s.Require().NoError(proto.Unmarshal(req.GetSchema(), sschema))
   116  				s.Require().Equal(len(ds.Fields), len(sschema.Fields))
   117  				for idx, fieldSchema := range ds.Fields {
   118  					s.Equal(fieldSchema.Name, sschema.GetFields()[idx].GetName())
   119  					s.Equal(fieldSchema.PrimaryKey, sschema.GetFields()[idx].GetIsPrimaryKey())
   120  					s.Equal(fieldSchema.AutoID, sschema.GetFields()[idx].GetAutoID())
   121  					s.EqualValues(fieldSchema.DataType, sschema.GetFields()[idx].GetDataType())
   122  				}
   123  				s.Equal(shardsNum, req.GetShardsNum())
   124  				s.Equal(commonpb.ConsistencyLevel_Bounded, req.GetConsistencyLevel())
   125  			}).
   126  			Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   127  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil)
   128  
   129  		err := c.CreateCollection(ctx, ds, shardsNum, WithCreateCollectionMsgBase(&commonpb.MsgBase{}))
   130  		s.NoError(err)
   131  	})
   132  
   133  	s.Run("create_with_consistency_level", func() {
   134  		ds := defaultSchema()
   135  		shardsNum := int32(1)
   136  		defer s.resetMock()
   137  		s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")).
   138  			Run(func(ctx context.Context, req *milvuspb.CreateCollectionRequest) {
   139  				s.Equal(testCollectionName, req.GetCollectionName())
   140  				sschema := &schemapb.CollectionSchema{}
   141  				s.Require().NoError(proto.Unmarshal(req.GetSchema(), sschema))
   142  				s.Require().Equal(len(ds.Fields), len(sschema.Fields))
   143  				for idx, fieldSchema := range ds.Fields {
   144  					s.Equal(fieldSchema.Name, sschema.GetFields()[idx].GetName())
   145  					s.Equal(fieldSchema.PrimaryKey, sschema.GetFields()[idx].GetIsPrimaryKey())
   146  					s.Equal(fieldSchema.AutoID, sschema.GetFields()[idx].GetAutoID())
   147  					s.EqualValues(fieldSchema.DataType, sschema.GetFields()[idx].GetDataType())
   148  				}
   149  				s.Equal(shardsNum, req.GetShardsNum())
   150  				s.Equal(commonpb.ConsistencyLevel_Eventually, req.GetConsistencyLevel())
   151  
   152  			}).
   153  			Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   154  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil)
   155  
   156  		err := c.CreateCollection(ctx, ds, shardsNum, WithConsistencyLevel(entity.ClEventually))
   157  		s.NoError(err)
   158  	})
   159  
   160  	s.Run("invalid_schemas", func() {
   161  
   162  		type testCase struct {
   163  			name   string
   164  			schema *entity.Schema
   165  		}
   166  		cases := []testCase{
   167  			{
   168  				name:   "empty_fields",
   169  				schema: entity.NewSchema().WithName(testCollectionName),
   170  			},
   171  			{
   172  				name: "empty_collection_name",
   173  				schema: entity.NewSchema().
   174  					WithField(entity.NewField().WithName("int64").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
   175  					WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)),
   176  			},
   177  			{
   178  				name: "multiple primary key",
   179  				schema: entity.NewSchema().WithName(testCollectionName).
   180  					WithField(entity.NewField().WithName("int64").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
   181  					WithField(entity.NewField().WithName("int64_2").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
   182  					WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)),
   183  			},
   184  			{
   185  				name: "multiple auto id",
   186  				schema: entity.NewSchema().WithName(testCollectionName).
   187  					WithField(entity.NewField().WithName("int64").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)).
   188  					WithField(entity.NewField().WithName("int64_2").WithDataType(entity.FieldTypeInt64).WithIsAutoID(true)).
   189  					WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)),
   190  			},
   191  			{
   192  				name: "bad_pk_type",
   193  				schema: entity.NewSchema().
   194  					WithField(entity.NewField().WithName("int64").WithDataType(entity.FieldTypeDouble).WithIsPrimaryKey(true)).
   195  					WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)),
   196  			},
   197  		}
   198  
   199  		for _, tc := range cases {
   200  			s.Run(tc.name, func() {
   201  				err := c.CreateCollection(ctx, tc.schema, 1)
   202  				s.Error(err)
   203  			})
   204  		}
   205  	})
   206  
   207  	s.Run("server_returns_error", func() {
   208  		s.Run("create_collection_error", func() {
   209  			defer s.resetMock()
   210  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil)
   211  			s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")).
   212  				Return(nil, errors.New("mocked grpc error"))
   213  
   214  			err := c.CreateCollection(ctx, defaultSchema(), 1)
   215  			s.Error(err)
   216  		})
   217  
   218  		s.Run("create_collection_fail", func() {
   219  			defer s.resetMock()
   220  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil)
   221  			s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")).
   222  				Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil)
   223  
   224  			err := c.CreateCollection(ctx, defaultSchema(), 1)
   225  			s.Error(err)
   226  		})
   227  	})
   228  
   229  	s.Run("feature_not_support", func() {
   230  		cases := []struct {
   231  			tag  string
   232  			flag uint64
   233  		}{
   234  			{tag: "json", flag: disableJSON},
   235  			{tag: "partition_key", flag: disableParitionKey},
   236  			{tag: "dyanmic_schema", flag: disableDynamicSchema},
   237  		}
   238  		sch := entity.NewSchema().WithName("all_feature").WithDynamicFieldEnabled(true).
   239  			WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
   240  			WithField(entity.NewField().WithName("embedding").WithDataType(entity.FieldTypeFloatVector).WithDim(128)).
   241  			WithField(entity.NewField().WithName("partition").WithDataType(entity.FieldTypeInt64).WithIsPartitionKey(true)).
   242  			WithField(entity.NewField().WithName("dynamic").WithDataType(entity.FieldTypeJSON).WithIsDynamic(true))
   243  		for _, tc := range cases {
   244  			s.Run(tc.tag, func() {
   245  				grpcClient, ok := c.(*GrpcClient)
   246  				s.Require().True(ok)
   247  				grpcClient.config.addFlags(tc.flag)
   248  				defer grpcClient.config.resetFlags(tc.flag)
   249  
   250  				err := c.CreateCollection(ctx, sch, 1)
   251  				s.Error(err)
   252  				s.ErrorIs(err, ErrFeatureNotSupported)
   253  			})
   254  		}
   255  	})
   256  }
   257  
   258  func (s *CollectionSuite) TestNewCollection() {
   259  	c := s.client
   260  	ctx, cancel := context.WithCancel(context.Background())
   261  
   262  	defer cancel()
   263  	s.resetMock()
   264  
   265  	s.Run("all_default", func() {
   266  		defer s.resetMock()
   267  
   268  		created := false
   269  		s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")).
   270  			Run(func(ctx context.Context, req *milvuspb.CreateCollectionRequest) {
   271  				s.Equal(testCollectionName, req.GetCollectionName())
   272  				sschema := &schemapb.CollectionSchema{}
   273  				s.Require().NoError(proto.Unmarshal(req.GetSchema(), sschema))
   274  				s.Require().Equal(2, len(sschema.Fields))
   275  				for _, field := range sschema.Fields {
   276  					if field.GetName() == "id" {
   277  						s.Equal(schemapb.DataType_Int64, field.GetDataType())
   278  					}
   279  					if field.GetName() == "vector" {
   280  						s.Equal(schemapb.DataType_FloatVector, field.GetDataType())
   281  					}
   282  				}
   283  
   284  				s.Equal(entity.DefaultShardNumber, req.GetShardsNum())
   285  				s.Equal(entity.DefaultConsistencyLevel.CommonConsistencyLevel(), req.GetConsistencyLevel())
   286  				created = true
   287  			}).
   288  			Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   289  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Call.Return(func(_ context.Context, _ *milvuspb.HasCollectionRequest) *milvuspb.BoolResponse {
   290  			return &milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: created}
   291  		}, nil)
   292  		s.mock.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   293  		s.mock.EXPECT().Flush(mock.Anything, mock.Anything).Return(&milvuspb.FlushResponse{
   294  			Status:     &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
   295  			CollSegIDs: map[string]*schemapb.LongArray{},
   296  		}, nil)
   297  		s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&milvuspb.DescribeIndexResponse{
   298  			Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
   299  			IndexDescriptions: []*milvuspb.IndexDescription{
   300  				{FieldName: "vector", State: commonpb.IndexState_Finished},
   301  			},
   302  		}, nil)
   303  		s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   304  		s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Return(&milvuspb.GetLoadingProgressResponse{
   305  			Status:   &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
   306  			Progress: 100,
   307  		}, nil)
   308  		s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
   309  			Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
   310  			Schema: &schemapb.CollectionSchema{
   311  				Fields: []*schemapb.FieldSchema{
   312  					{Name: "id", DataType: schemapb.DataType_VarChar},
   313  					{Name: "vector", DataType: schemapb.DataType_FloatVector},
   314  				},
   315  			},
   316  		}, nil)
   317  
   318  		err := c.NewCollection(ctx, testCollectionName, testVectorDim)
   319  		s.NoError(err)
   320  	})
   321  
   322  	s.Run("with_custom_set", func() {
   323  		defer s.resetMock()
   324  		created := false
   325  		s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")).
   326  			Run(func(ctx context.Context, req *milvuspb.CreateCollectionRequest) {
   327  				s.Equal(testCollectionName, req.GetCollectionName())
   328  				sschema := &schemapb.CollectionSchema{}
   329  				s.Require().NoError(proto.Unmarshal(req.GetSchema(), sschema))
   330  				s.Require().Equal(2, len(sschema.Fields))
   331  				for _, field := range sschema.Fields {
   332  					if field.GetName() == "my_pk" {
   333  						s.Equal(schemapb.DataType_VarChar, field.GetDataType())
   334  					}
   335  					if field.GetName() == "embedding" {
   336  						s.Equal(schemapb.DataType_FloatVector, field.GetDataType())
   337  					}
   338  				}
   339  
   340  				s.Equal(entity.DefaultShardNumber, req.GetShardsNum())
   341  				s.Equal(entity.ClEventually.CommonConsistencyLevel(), req.GetConsistencyLevel())
   342  				created = true
   343  			}).
   344  			Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   345  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Call.Return(func(_ context.Context, _ *milvuspb.HasCollectionRequest) *milvuspb.BoolResponse {
   346  			return &milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: created}
   347  		}, nil)
   348  		s.mock.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   349  		s.mock.EXPECT().Flush(mock.Anything, mock.Anything).Return(&milvuspb.FlushResponse{
   350  			Status:     &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
   351  			CollSegIDs: map[string]*schemapb.LongArray{},
   352  		}, nil)
   353  		s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&milvuspb.DescribeIndexResponse{
   354  			Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
   355  			IndexDescriptions: []*milvuspb.IndexDescription{
   356  				{FieldName: "embedding", State: commonpb.IndexState_Finished},
   357  			},
   358  		}, nil)
   359  		s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   360  		s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Return(&milvuspb.GetLoadingProgressResponse{
   361  			Status:   &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
   362  			Progress: 100,
   363  		}, nil)
   364  		s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
   365  			Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
   366  			Schema: &schemapb.CollectionSchema{
   367  				Fields: []*schemapb.FieldSchema{
   368  					{Name: "my_pk", DataType: schemapb.DataType_VarChar},
   369  					{Name: "embedding", DataType: schemapb.DataType_FloatVector},
   370  				},
   371  			},
   372  		}, nil)
   373  
   374  		err := c.NewCollection(ctx, testCollectionName, testVectorDim, WithPKFieldName("my_pk"), WithPKFieldType(entity.FieldTypeVarChar), WithVectorFieldName("embedding"), WithConsistencyLevel(entity.ClEventually))
   375  		s.NoError(err)
   376  	})
   377  }
   378  
   379  func (s *CollectionSuite) TestRenameCollection() {
   380  	c := s.client
   381  	ctx, cancel := context.WithCancel(context.Background())
   382  	defer cancel()
   383  	s.Run("normal_run", func() {
   384  		defer s.resetMock()
   385  
   386  		newCollName := fmt.Sprintf("new_%s", randStr(6))
   387  
   388  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   389  		s.mock.EXPECT().RenameCollection(mock.Anything, &milvuspb.RenameCollectionRequest{OldName: testCollectionName, NewName: newCollName}).Return(&commonpb.Status{}, nil)
   390  
   391  		err := c.RenameCollection(ctx, testCollectionName, newCollName)
   392  		s.NoError(err)
   393  	})
   394  
   395  	s.Run("coll_not_exist", func() {
   396  		defer s.resetMock()
   397  
   398  		newCollName := fmt.Sprintf("new_%s", randStr(6))
   399  
   400  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil)
   401  
   402  		err := c.RenameCollection(ctx, testCollectionName, newCollName)
   403  		s.Error(err)
   404  	})
   405  
   406  	s.Run("rename_failed", func() {
   407  		defer s.resetMock()
   408  
   409  		newCollName := fmt.Sprintf("new_%s", randStr(6))
   410  
   411  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   412  		s.mock.EXPECT().RenameCollection(mock.Anything, &milvuspb.RenameCollectionRequest{OldName: testCollectionName, NewName: newCollName}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mocked failure"}, nil)
   413  
   414  		err := c.RenameCollection(ctx, testCollectionName, newCollName)
   415  		s.Error(err)
   416  	})
   417  
   418  	s.Run("rename_error", func() {
   419  		defer s.resetMock()
   420  
   421  		newCollName := fmt.Sprintf("new_%s", randStr(6))
   422  
   423  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   424  		s.mock.EXPECT().RenameCollection(mock.Anything, &milvuspb.RenameCollectionRequest{OldName: testCollectionName, NewName: newCollName}).Return(nil, errors.New("mocked error"))
   425  
   426  		err := c.RenameCollection(ctx, testCollectionName, newCollName)
   427  		s.Error(err)
   428  	})
   429  }
   430  
   431  func (s *CollectionSuite) TestAlterCollection() {
   432  	c := s.client
   433  	ctx, cancel := context.WithCancel(context.Background())
   434  	defer cancel()
   435  
   436  	s.Run("normal_run", func() {
   437  		defer s.resetMock()
   438  
   439  		s.setupHasCollection(testCollectionName)
   440  		s.mock.EXPECT().AlterCollection(mock.Anything, mock.AnythingOfType("*milvuspb.AlterCollectionRequest")).
   441  			Return(&commonpb.Status{}, nil)
   442  
   443  		err := c.AlterCollection(ctx, testCollectionName, entity.CollectionTTL(100000))
   444  		s.NoError(err)
   445  	})
   446  
   447  	s.Run("collection_not_exist", func() {
   448  		defer s.resetMock()
   449  
   450  		s.mock.EXPECT().HasCollection(mock.Anything, mock.AnythingOfType("*milvuspb.HasCollectionRequest")).
   451  			Return(&milvuspb.BoolResponse{
   452  				Status: &commonpb.Status{},
   453  				Value:  false,
   454  			}, nil)
   455  
   456  		err := c.AlterCollection(ctx, testCollectionName, entity.CollectionTTL(100000))
   457  		s.Error(err)
   458  	})
   459  
   460  	s.Run("no_attributes", func() {
   461  		defer s.resetMock()
   462  
   463  		s.setupHasCollection(testCollectionName)
   464  		err := c.AlterCollection(ctx, testCollectionName)
   465  		s.Error(err)
   466  	})
   467  
   468  	s.Run("request_fails", func() {
   469  		defer s.resetMock()
   470  
   471  		s.setupHasCollection(testCollectionName)
   472  		s.mock.EXPECT().AlterCollection(mock.Anything, mock.AnythingOfType("*milvuspb.AlterCollectionRequest")).
   473  			Return(nil, errors.New("mocked"))
   474  
   475  		err := c.AlterCollection(ctx, testCollectionName, entity.CollectionTTL(100000))
   476  		s.Error(err)
   477  	})
   478  
   479  	s.Run("server_return_error", func() {
   480  		defer s.resetMock()
   481  
   482  		s.setupHasCollection(testCollectionName)
   483  		s.mock.EXPECT().AlterCollection(mock.Anything, mock.AnythingOfType("*milvuspb.AlterCollectionRequest")).
   484  			Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil)
   485  
   486  		err := c.AlterCollection(ctx, testCollectionName, entity.CollectionTTL(100000))
   487  		s.Error(err)
   488  	})
   489  
   490  	s.Run("service_not_ready", func() {
   491  		c := &GrpcClient{}
   492  		err := c.AlterCollection(ctx, testCollectionName, entity.CollectionTTL(100000))
   493  		s.ErrorIs(err, ErrClientNotReady)
   494  	})
   495  }
   496  
   497  func (s *CollectionSuite) TestLoadCollection() {
   498  	ctx, cancel := context.WithCancel(context.Background())
   499  	defer cancel()
   500  
   501  	c := s.client
   502  
   503  	s.Run("normal_run_async", func() {
   504  		defer s.resetMock()
   505  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   506  			Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   507  
   508  		s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   509  
   510  		err := c.LoadCollection(ctx, testCollectionName, true, WithLoadCollectionMsgBase(&commonpb.MsgBase{}))
   511  		s.NoError(err)
   512  	})
   513  
   514  	s.Run("normal_run_sync", func() {
   515  		defer s.resetMock()
   516  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   517  			Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   518  
   519  		s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   520  		s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).
   521  			Return(&milvuspb.GetLoadingProgressResponse{
   522  				Status:   &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
   523  				Progress: 100,
   524  			}, nil)
   525  
   526  		err := c.LoadCollection(ctx, testCollectionName, true)
   527  		s.NoError(err)
   528  	})
   529  
   530  	s.Run("load_default_replica", func() {
   531  		defer s.resetMock()
   532  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   533  			Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   534  
   535  		s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadCollectionRequest) {
   536  			s.Equal(testDefaultReplicaNumber, req.GetReplicaNumber())
   537  			s.Equal(testCollectionName, req.GetCollectionName())
   538  		}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   539  
   540  		err := c.LoadCollection(ctx, testCollectionName, true)
   541  		s.NoError(err)
   542  	})
   543  
   544  	s.Run("load_multiple_replica", func() {
   545  		defer s.resetMock()
   546  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   547  			Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   548  
   549  		s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadCollectionRequest) {
   550  			s.Equal(testMultiReplicaNumber, req.GetReplicaNumber())
   551  			s.Equal(testCollectionName, req.GetCollectionName())
   552  		}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   553  
   554  		err := c.LoadCollection(ctx, testCollectionName, true, WithReplicaNumber(testMultiReplicaNumber))
   555  		s.NoError(err)
   556  	})
   557  
   558  	s.Run("has_collection_failure", func() {
   559  		s.Run("return_false", func() {
   560  			defer s.resetMock()
   561  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   562  				Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil)
   563  
   564  			err := c.LoadCollection(ctx, testCollectionName, true)
   565  			s.Error(err)
   566  		})
   567  
   568  		s.Run("return_error", func() {
   569  			defer s.resetMock()
   570  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   571  				Return(nil, errors.New("mock error"))
   572  
   573  			err := c.LoadCollection(ctx, testCollectionName, true)
   574  			s.Error(err)
   575  		})
   576  	})
   577  
   578  	s.Run("load_collection_failure", func() {
   579  		s.Run("failure_status", func() {
   580  			defer s.resetMock()
   581  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   582  				Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   583  
   584  			s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).
   585  				Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil)
   586  
   587  			err := c.LoadCollection(ctx, testCollectionName, true)
   588  			s.Error(err)
   589  		})
   590  
   591  		s.Run("return_error", func() {
   592  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   593  				Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   594  
   595  			s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).
   596  				Return(nil, errors.New("mock error"))
   597  
   598  			err := c.LoadCollection(ctx, testCollectionName, true)
   599  			s.Error(err)
   600  		})
   601  	})
   602  
   603  	s.Run("get_loading_progress_failure", func() {
   604  		defer s.resetMock()
   605  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   606  			Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   607  
   608  		s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   609  		s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).
   610  			Return(nil, errors.New("mock error"))
   611  
   612  		err := c.LoadCollection(ctx, testCollectionName, false)
   613  		s.Error(err)
   614  	})
   615  
   616  	s.Run("service_not_ready", func() {
   617  		c := &GrpcClient{}
   618  		err := c.LoadCollection(ctx, testCollectionName, false)
   619  		s.ErrorIs(err, ErrClientNotReady)
   620  	})
   621  }
   622  
   623  func TestCollectionSuite(t *testing.T) {
   624  	suite.Run(t, new(CollectionSuite))
   625  }
   626  
   627  // default HasCollection injection, returns true only when collection name is `testCollectionName`
   628  var hasCollectionDefault = func(_ context.Context, raw proto.Message) (proto.Message, error) {
   629  	req, ok := raw.(*milvuspb.HasCollectionRequest)
   630  	resp := &milvuspb.BoolResponse{}
   631  	if !ok {
   632  		s, err := BadRequestStatus()
   633  		resp.Status = s
   634  		return s, err
   635  	}
   636  	resp.Value = req.GetCollectionName() == testCollectionName
   637  	s, err := SuccessStatus()
   638  	resp.Status = s
   639  	return resp, err
   640  }
   641  
   642  func TestGrpcClientDropCollection(t *testing.T) {
   643  	ctx := context.Background()
   644  	c := testClient(ctx, t)
   645  
   646  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
   647  	mockServer.SetInjection(MDropCollection, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   648  		req, ok := (raw).(*milvuspb.DropCollectionRequest)
   649  		if !ok {
   650  			return BadRequestStatus()
   651  		}
   652  		if req.GetCollectionName() != testCollectionName { // in mockServer.server, assume testCollection exists only
   653  			return BadRequestStatus()
   654  		}
   655  		return SuccessStatus()
   656  	})
   657  
   658  	t.Run("Test Normal drop", func(t *testing.T) {
   659  		assert.Nil(t, c.DropCollection(ctx, testCollectionName, WithDropCollectionMsgBase(&commonpb.MsgBase{})))
   660  	})
   661  
   662  	t.Run("Test drop non-existing collection", func(t *testing.T) {
   663  		assert.NotNil(t, c.DropCollection(ctx, "AAAAAAAAAANonExists"))
   664  	})
   665  }
   666  
   667  func TestReleaseCollection(t *testing.T) {
   668  	ctx := context.Background()
   669  
   670  	c := testClient(ctx, t)
   671  
   672  	mockServer.SetInjection(MReleaseCollection, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   673  		req, ok := raw.(*milvuspb.ReleaseCollectionRequest)
   674  		if !ok {
   675  			return BadRequestStatus()
   676  		}
   677  		assert.Equal(t, testCollectionName, req.GetCollectionName())
   678  		return SuccessStatus()
   679  	})
   680  
   681  	c.ReleaseCollection(ctx, testCollectionName, WithReleaseCollectionMsgBase(&commonpb.MsgBase{}))
   682  }
   683  
   684  func TestGrpcClientHasCollection(t *testing.T) {
   685  	ctx := context.Background()
   686  
   687  	c := testClient(ctx, t)
   688  
   689  	mockServer.SetInjection(MHasCollection, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   690  		req, ok := raw.(*milvuspb.HasCollectionRequest)
   691  		resp := &milvuspb.BoolResponse{}
   692  		if !ok {
   693  			s, err := BadRequestStatus()
   694  			assert.Fail(t, err.Error())
   695  			resp.Status = s
   696  			return resp, err
   697  		}
   698  		assert.Equal(t, req.CollectionName, testCollectionName)
   699  
   700  		s, err := SuccessStatus()
   701  		resp.Status, resp.Value = s, true
   702  		return resp, err
   703  	})
   704  
   705  	has, err := c.HasCollection(ctx, testCollectionName)
   706  	assert.Nil(t, err)
   707  	assert.True(t, has)
   708  }
   709  
   710  // return injection asserts collection name matchs
   711  // partition name request in partitionNames if flag is true
   712  func hasCollectionInjection(t *testing.T, mustIn bool, collNames ...string) func(context.Context, proto.Message) (proto.Message, error) {
   713  	return func(_ context.Context, raw proto.Message) (proto.Message, error) {
   714  		req, ok := raw.(*milvuspb.HasCollectionRequest)
   715  		resp := &milvuspb.BoolResponse{}
   716  		if !ok {
   717  			s, err := BadRequestStatus()
   718  			resp.Status = s
   719  			return resp, err
   720  		}
   721  		if mustIn {
   722  			resp.Value = assert.Contains(t, collNames, req.GetCollectionName())
   723  		} else {
   724  			for _, pn := range collNames {
   725  				if pn == req.GetCollectionName() {
   726  					resp.Value = true
   727  				}
   728  			}
   729  		}
   730  		s, err := SuccessStatus()
   731  		resp.Status = s
   732  		return resp, err
   733  	}
   734  }
   735  
   736  func describeCollectionInjection(t *testing.T, collID int64, collName string, sch *entity.Schema) func(_ context.Context, raw proto.Message) (proto.Message, error) {
   737  	return func(_ context.Context, raw proto.Message) (proto.Message, error) {
   738  		req, ok := raw.(*milvuspb.DescribeCollectionRequest)
   739  		resp := &milvuspb.DescribeCollectionResponse{}
   740  		if !ok {
   741  			s, err := BadRequestStatus()
   742  			resp.Status = s
   743  			return resp, err
   744  		}
   745  
   746  		assert.Equal(t, collName, req.GetCollectionName())
   747  
   748  		sch := sch
   749  		resp.Schema = sch.ProtoMessage()
   750  		resp.CollectionID = collID
   751  
   752  		s, err := SuccessStatus()
   753  		resp.Status = s
   754  
   755  		return resp, err
   756  	}
   757  }
   758  
   759  func TestGrpcClientDescribeCollection(t *testing.T) {
   760  	ctx := context.Background()
   761  
   762  	c := testClient(ctx, t)
   763  
   764  	collectionID := rand.Int63()
   765  
   766  	mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, collectionID, testCollectionName, defaultSchema()))
   767  
   768  	collection, err := c.DescribeCollection(ctx, testCollectionName)
   769  	assert.Nil(t, err)
   770  	if assert.NotNil(t, collection) {
   771  		assert.Equal(t, collectionID, collection.ID)
   772  	}
   773  }
   774  
   775  func TestGrpcClientGetCollectionStatistics(t *testing.T) {
   776  	ctx := context.Background()
   777  
   778  	c := testClient(ctx, t)
   779  
   780  	stat := make(map[string]string)
   781  	stat["row_count"] = "0"
   782  
   783  	mockServer.SetInjection(MGetCollectionStatistics, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   784  		req, ok := raw.(*milvuspb.GetCollectionStatisticsRequest)
   785  		resp := &milvuspb.GetCollectionStatisticsResponse{}
   786  		if !ok {
   787  			s, err := BadRequestStatus()
   788  			resp.Status = s
   789  			return resp, err
   790  		}
   791  		assert.Equal(t, testCollectionName, req.GetCollectionName())
   792  		s, err := SuccessStatus()
   793  		resp.Status, resp.Stats = s, entity.MapKvPairs(stat)
   794  		return resp, err
   795  	})
   796  
   797  	rStat, err := c.GetCollectionStatistics(ctx, testCollectionName)
   798  	assert.Nil(t, err)
   799  	if assert.NotNil(t, rStat) {
   800  		for k, v := range stat {
   801  			rv, has := rStat[k]
   802  			assert.True(t, has)
   803  			assert.Equal(t, v, rv)
   804  		}
   805  	}
   806  }
   807  
   808  func TestGrpcClientGetReplicas(t *testing.T) {
   809  	ctx := context.Background()
   810  	c := testClient(ctx, t)
   811  
   812  	replicaID := rand.Int63()
   813  	nodeIds := []int64{1, 2, 3, 4}
   814  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
   815  	defer mockServer.DelInjection(MHasCollection)
   816  
   817  	mockServer.SetInjection(MShowCollections, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   818  		s, err := SuccessStatus()
   819  		resp := &milvuspb.ShowCollectionsResponse{
   820  			Status:              s,
   821  			CollectionIds:       []int64{testCollectionID},
   822  			CollectionNames:     []string{testCollectionName},
   823  			InMemoryPercentages: []int64{100},
   824  		}
   825  		return resp, err
   826  	})
   827  	defer mockServer.DelInjection(MShowCollections)
   828  
   829  	mockServer.SetInjection(MGetReplicas, func(ctx context.Context, raw proto.Message) (proto.Message, error) {
   830  		req, ok := raw.(*milvuspb.GetReplicasRequest)
   831  		resp := &milvuspb.GetReplicasResponse{}
   832  		if !ok {
   833  			s, err := BadRequestStatus()
   834  			resp.Status = s
   835  			return resp, err
   836  		}
   837  
   838  		assert.Equal(t, testCollectionID, req.CollectionID)
   839  
   840  		s, err := SuccessStatus()
   841  		resp.Status = s
   842  		resp.Replicas = []*milvuspb.ReplicaInfo{{
   843  			ReplicaID: replicaID,
   844  			ShardReplicas: []*milvuspb.ShardReplica{
   845  				{
   846  					LeaderID:      1,
   847  					DmChannelName: "DML_channel_v1",
   848  				},
   849  				{
   850  					LeaderID:   2,
   851  					LeaderAddr: "DML_channel_v2",
   852  				},
   853  			},
   854  			NodeIds: nodeIds,
   855  		}}
   856  		return resp, err
   857  	})
   858  
   859  	t.Run("get replicas normal", func(t *testing.T) {
   860  		groups, err := c.GetReplicas(ctx, testCollectionName)
   861  		assert.Nil(t, err)
   862  		assert.NotNil(t, groups)
   863  		assert.Equal(t, 1, len(groups))
   864  
   865  		assert.Equal(t, replicaID, groups[0].ReplicaID)
   866  		assert.Equal(t, nodeIds, groups[0].NodeIDs)
   867  		assert.Equal(t, 2, len(groups[0].ShardReplicas))
   868  	})
   869  
   870  	t.Run("get replicas invalid name", func(t *testing.T) {
   871  		_, err := c.GetReplicas(ctx, "invalid name")
   872  		assert.Error(t, err)
   873  	})
   874  
   875  	t.Run("get replicas grpc error", func(t *testing.T) {
   876  		mockServer.SetInjection(MGetReplicas, func(ctx context.Context, raw proto.Message) (proto.Message, error) {
   877  			return &milvuspb.GetReplicasResponse{}, errors.New("mockServer.d grpc error")
   878  		})
   879  		_, err := c.GetReplicas(ctx, testCollectionName)
   880  		assert.Error(t, err)
   881  	})
   882  
   883  	t.Run("get replicas server error", func(t *testing.T) {
   884  		mockServer.SetInjection(MGetReplicas, func(ctx context.Context, raw proto.Message) (proto.Message, error) {
   885  			return &milvuspb.GetReplicasResponse{
   886  				Status: &commonpb.Status{
   887  					ErrorCode: commonpb.ErrorCode_UnexpectedError,
   888  					Reason:    "Service is not healthy",
   889  				},
   890  				Replicas: nil,
   891  			}, nil
   892  		})
   893  		_, err := c.GetReplicas(ctx, testCollectionName)
   894  		assert.Error(t, err)
   895  	})
   896  
   897  	mockServer.DelInjection(MGetReplicas)
   898  }
   899  
   900  func TestGrpcClientGetLoadingProgress(t *testing.T) {
   901  	ctx := context.Background()
   902  	c := testClient(ctx, t)
   903  
   904  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
   905  
   906  	mockServer.SetInjection(MGetLoadingProgress, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   907  		req, ok := raw.(*milvuspb.GetLoadingProgressRequest)
   908  		if !ok {
   909  			return BadRequestStatus()
   910  		}
   911  		resp := &milvuspb.GetLoadingProgressResponse{}
   912  		if !ok {
   913  			s, err := BadRequestStatus()
   914  			resp.Status = s
   915  			return resp, err
   916  		}
   917  		assert.Equal(t, testCollectionName, req.GetCollectionName())
   918  		s, err := SuccessStatus()
   919  		resp.Status, resp.Progress = s, 100
   920  		return resp, err
   921  	})
   922  
   923  	progress, err := c.GetLoadingProgress(ctx, testCollectionName, []string{})
   924  	assert.NoError(t, err)
   925  	assert.Equal(t, int64(100), progress)
   926  }
   927  
   928  func TestGrpcClientGetLoadState(t *testing.T) {
   929  	ctx := context.Background()
   930  	c := testClient(ctx, t)
   931  
   932  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
   933  
   934  	mockServer.SetInjection(MGetLoadState, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   935  		req, ok := raw.(*milvuspb.GetLoadStateRequest)
   936  		if !ok {
   937  			return BadRequestStatus()
   938  		}
   939  		resp := &milvuspb.GetLoadStateResponse{}
   940  		if !ok {
   941  			s, err := BadRequestStatus()
   942  			resp.Status = s
   943  			return resp, err
   944  		}
   945  		assert.Equal(t, testCollectionName, req.GetCollectionName())
   946  		s, err := SuccessStatus()
   947  		resp.Status, resp.State = s, commonpb.LoadState_LoadStateLoaded
   948  		return resp, err
   949  	})
   950  
   951  	state, err := c.GetLoadState(ctx, testCollectionName, []string{})
   952  	assert.NoError(t, err)
   953  	assert.Equal(t, entity.LoadStateLoaded, state)
   954  }