github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/collection_test.go (about) 1 package client 2 3 import ( 4 "context" 5 "fmt" 6 "math/rand" 7 "testing" 8 9 "github.com/cockroachdb/errors" 10 11 "github.com/golang/protobuf/proto" 12 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" 13 "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" 14 "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" 15 "github.com/milvus-io/milvus-sdk-go/v2/entity" 16 "github.com/stretchr/testify/assert" 17 "github.com/stretchr/testify/mock" 18 "github.com/stretchr/testify/suite" 19 ) 20 21 type CollectionSuite struct { 22 MockSuiteBase 23 } 24 25 func (s *CollectionSuite) TestListCollections() { 26 c := s.client 27 ctx, cancel := context.WithCancel(context.Background()) 28 defer cancel() 29 30 type testCase struct { 31 ids []int64 32 names []string 33 collNum int 34 inMem []int64 35 } 36 caseLen := 5 37 cases := make([]testCase, 0, caseLen) 38 for i := 0; i < caseLen; i++ { 39 collNum := rand.Intn(5) + 2 40 tc := testCase{ 41 ids: make([]int64, 0, collNum), 42 names: make([]string, 0, collNum), 43 collNum: collNum, 44 } 45 base := rand.Intn(1000) 46 for j := 0; j < collNum; j++ { 47 base += rand.Intn(1000) 48 tc.ids = append(tc.ids, int64(base)) 49 base += rand.Intn(500) 50 tc.names = append(tc.names, fmt.Sprintf("coll_%d", base)) 51 inMem := rand.Intn(100) 52 if inMem%2 == 0 { 53 54 tc.inMem = append(tc.inMem, 100) 55 } else { 56 tc.inMem = append(tc.inMem, 0) 57 } 58 } 59 cases = append(cases, tc) 60 } 61 62 for i, tc := range cases { 63 s.Run(fmt.Sprintf("run_%d", i), func() { 64 s.resetMock() 65 s.mock.EXPECT().ShowCollections(mock.Anything, mock.AnythingOfType("*milvuspb.ShowCollectionsRequest")). 66 Return(&milvuspb.ShowCollectionsResponse{ 67 Status: &commonpb.Status{}, 68 CollectionIds: tc.ids, 69 CollectionNames: tc.names, 70 InMemoryPercentages: tc.inMem, 71 }, nil) 72 73 collections, err := c.ListCollections(ctx) 74 75 s.Require().Equal(tc.collNum, len(collections)) 76 s.Require().NoError(err) 77 78 // assert element match 79 rids := make([]int64, 0, len(collections)) 80 rnames := make([]string, 0, len(collections)) 81 for _, collection := range collections { 82 rids = append(rids, collection.ID) 83 rnames = append(rnames, collection.Name) 84 } 85 86 s.ElementsMatch(tc.ids, rids) 87 s.ElementsMatch(tc.names, rnames) 88 // assert id & name match 89 for idx, rid := range rids { 90 for jdx, id := range tc.ids { 91 if rid == id { 92 s.Equal(tc.names[idx], rnames[idx]) 93 s.Equal(tc.inMem[jdx] == 100, collections[idx].Loaded) 94 } 95 } 96 } 97 }) 98 } 99 } 100 101 func (s *CollectionSuite) TestCreateCollection() { 102 c := s.client 103 ctx, cancel := context.WithCancel(context.Background()) 104 defer cancel() 105 106 s.Run("normal_creation", func() { 107 ds := defaultSchema() 108 shardsNum := int32(1) 109 110 defer s.resetMock() 111 s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")). 112 Run(func(ctx context.Context, req *milvuspb.CreateCollectionRequest) { 113 s.Equal(testCollectionName, req.GetCollectionName()) 114 sschema := &schemapb.CollectionSchema{} 115 s.Require().NoError(proto.Unmarshal(req.GetSchema(), sschema)) 116 s.Require().Equal(len(ds.Fields), len(sschema.Fields)) 117 for idx, fieldSchema := range ds.Fields { 118 s.Equal(fieldSchema.Name, sschema.GetFields()[idx].GetName()) 119 s.Equal(fieldSchema.PrimaryKey, sschema.GetFields()[idx].GetIsPrimaryKey()) 120 s.Equal(fieldSchema.AutoID, sschema.GetFields()[idx].GetAutoID()) 121 s.EqualValues(fieldSchema.DataType, sschema.GetFields()[idx].GetDataType()) 122 } 123 s.Equal(shardsNum, req.GetShardsNum()) 124 s.Equal(commonpb.ConsistencyLevel_Bounded, req.GetConsistencyLevel()) 125 }). 126 Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 127 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil) 128 129 err := c.CreateCollection(ctx, ds, shardsNum, WithCreateCollectionMsgBase(&commonpb.MsgBase{})) 130 s.NoError(err) 131 }) 132 133 s.Run("create_with_consistency_level", func() { 134 ds := defaultSchema() 135 shardsNum := int32(1) 136 defer s.resetMock() 137 s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")). 138 Run(func(ctx context.Context, req *milvuspb.CreateCollectionRequest) { 139 s.Equal(testCollectionName, req.GetCollectionName()) 140 sschema := &schemapb.CollectionSchema{} 141 s.Require().NoError(proto.Unmarshal(req.GetSchema(), sschema)) 142 s.Require().Equal(len(ds.Fields), len(sschema.Fields)) 143 for idx, fieldSchema := range ds.Fields { 144 s.Equal(fieldSchema.Name, sschema.GetFields()[idx].GetName()) 145 s.Equal(fieldSchema.PrimaryKey, sschema.GetFields()[idx].GetIsPrimaryKey()) 146 s.Equal(fieldSchema.AutoID, sschema.GetFields()[idx].GetAutoID()) 147 s.EqualValues(fieldSchema.DataType, sschema.GetFields()[idx].GetDataType()) 148 } 149 s.Equal(shardsNum, req.GetShardsNum()) 150 s.Equal(commonpb.ConsistencyLevel_Eventually, req.GetConsistencyLevel()) 151 152 }). 153 Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 154 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil) 155 156 err := c.CreateCollection(ctx, ds, shardsNum, WithConsistencyLevel(entity.ClEventually)) 157 s.NoError(err) 158 }) 159 160 s.Run("invalid_schemas", func() { 161 162 type testCase struct { 163 name string 164 schema *entity.Schema 165 } 166 cases := []testCase{ 167 { 168 name: "empty_fields", 169 schema: entity.NewSchema().WithName(testCollectionName), 170 }, 171 { 172 name: "empty_collection_name", 173 schema: entity.NewSchema(). 174 WithField(entity.NewField().WithName("int64").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). 175 WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)), 176 }, 177 { 178 name: "multiple primary key", 179 schema: entity.NewSchema().WithName(testCollectionName). 180 WithField(entity.NewField().WithName("int64").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). 181 WithField(entity.NewField().WithName("int64_2").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). 182 WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)), 183 }, 184 { 185 name: "multiple auto id", 186 schema: entity.NewSchema().WithName(testCollectionName). 187 WithField(entity.NewField().WithName("int64").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)). 188 WithField(entity.NewField().WithName("int64_2").WithDataType(entity.FieldTypeInt64).WithIsAutoID(true)). 189 WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)), 190 }, 191 { 192 name: "bad_pk_type", 193 schema: entity.NewSchema(). 194 WithField(entity.NewField().WithName("int64").WithDataType(entity.FieldTypeDouble).WithIsPrimaryKey(true)). 195 WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)), 196 }, 197 } 198 199 for _, tc := range cases { 200 s.Run(tc.name, func() { 201 err := c.CreateCollection(ctx, tc.schema, 1) 202 s.Error(err) 203 }) 204 } 205 }) 206 207 s.Run("server_returns_error", func() { 208 s.Run("create_collection_error", func() { 209 defer s.resetMock() 210 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil) 211 s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")). 212 Return(nil, errors.New("mocked grpc error")) 213 214 err := c.CreateCollection(ctx, defaultSchema(), 1) 215 s.Error(err) 216 }) 217 218 s.Run("create_collection_fail", func() { 219 defer s.resetMock() 220 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil) 221 s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")). 222 Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil) 223 224 err := c.CreateCollection(ctx, defaultSchema(), 1) 225 s.Error(err) 226 }) 227 }) 228 229 s.Run("feature_not_support", func() { 230 cases := []struct { 231 tag string 232 flag uint64 233 }{ 234 {tag: "json", flag: disableJSON}, 235 {tag: "partition_key", flag: disableParitionKey}, 236 {tag: "dyanmic_schema", flag: disableDynamicSchema}, 237 } 238 sch := entity.NewSchema().WithName("all_feature").WithDynamicFieldEnabled(true). 239 WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). 240 WithField(entity.NewField().WithName("embedding").WithDataType(entity.FieldTypeFloatVector).WithDim(128)). 241 WithField(entity.NewField().WithName("partition").WithDataType(entity.FieldTypeInt64).WithIsPartitionKey(true)). 242 WithField(entity.NewField().WithName("dynamic").WithDataType(entity.FieldTypeJSON).WithIsDynamic(true)) 243 for _, tc := range cases { 244 s.Run(tc.tag, func() { 245 grpcClient, ok := c.(*GrpcClient) 246 s.Require().True(ok) 247 grpcClient.config.addFlags(tc.flag) 248 defer grpcClient.config.resetFlags(tc.flag) 249 250 err := c.CreateCollection(ctx, sch, 1) 251 s.Error(err) 252 s.ErrorIs(err, ErrFeatureNotSupported) 253 }) 254 } 255 }) 256 } 257 258 func (s *CollectionSuite) TestNewCollection() { 259 c := s.client 260 ctx, cancel := context.WithCancel(context.Background()) 261 262 defer cancel() 263 s.resetMock() 264 265 s.Run("all_default", func() { 266 defer s.resetMock() 267 268 created := false 269 s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")). 270 Run(func(ctx context.Context, req *milvuspb.CreateCollectionRequest) { 271 s.Equal(testCollectionName, req.GetCollectionName()) 272 sschema := &schemapb.CollectionSchema{} 273 s.Require().NoError(proto.Unmarshal(req.GetSchema(), sschema)) 274 s.Require().Equal(2, len(sschema.Fields)) 275 for _, field := range sschema.Fields { 276 if field.GetName() == "id" { 277 s.Equal(schemapb.DataType_Int64, field.GetDataType()) 278 } 279 if field.GetName() == "vector" { 280 s.Equal(schemapb.DataType_FloatVector, field.GetDataType()) 281 } 282 } 283 284 s.Equal(entity.DefaultShardNumber, req.GetShardsNum()) 285 s.Equal(entity.DefaultConsistencyLevel.CommonConsistencyLevel(), req.GetConsistencyLevel()) 286 created = true 287 }). 288 Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 289 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Call.Return(func(_ context.Context, _ *milvuspb.HasCollectionRequest) *milvuspb.BoolResponse { 290 return &milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: created} 291 }, nil) 292 s.mock.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 293 s.mock.EXPECT().Flush(mock.Anything, mock.Anything).Return(&milvuspb.FlushResponse{ 294 Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, 295 CollSegIDs: map[string]*schemapb.LongArray{}, 296 }, nil) 297 s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&milvuspb.DescribeIndexResponse{ 298 Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, 299 IndexDescriptions: []*milvuspb.IndexDescription{ 300 {FieldName: "vector", State: commonpb.IndexState_Finished}, 301 }, 302 }, nil) 303 s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 304 s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Return(&milvuspb.GetLoadingProgressResponse{ 305 Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, 306 Progress: 100, 307 }, nil) 308 s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ 309 Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, 310 Schema: &schemapb.CollectionSchema{ 311 Fields: []*schemapb.FieldSchema{ 312 {Name: "id", DataType: schemapb.DataType_VarChar}, 313 {Name: "vector", DataType: schemapb.DataType_FloatVector}, 314 }, 315 }, 316 }, nil) 317 318 err := c.NewCollection(ctx, testCollectionName, testVectorDim) 319 s.NoError(err) 320 }) 321 322 s.Run("with_custom_set", func() { 323 defer s.resetMock() 324 created := false 325 s.mock.EXPECT().CreateCollection(mock.Anything, mock.AnythingOfType("*milvuspb.CreateCollectionRequest")). 326 Run(func(ctx context.Context, req *milvuspb.CreateCollectionRequest) { 327 s.Equal(testCollectionName, req.GetCollectionName()) 328 sschema := &schemapb.CollectionSchema{} 329 s.Require().NoError(proto.Unmarshal(req.GetSchema(), sschema)) 330 s.Require().Equal(2, len(sschema.Fields)) 331 for _, field := range sschema.Fields { 332 if field.GetName() == "my_pk" { 333 s.Equal(schemapb.DataType_VarChar, field.GetDataType()) 334 } 335 if field.GetName() == "embedding" { 336 s.Equal(schemapb.DataType_FloatVector, field.GetDataType()) 337 } 338 } 339 340 s.Equal(entity.DefaultShardNumber, req.GetShardsNum()) 341 s.Equal(entity.ClEventually.CommonConsistencyLevel(), req.GetConsistencyLevel()) 342 created = true 343 }). 344 Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 345 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Call.Return(func(_ context.Context, _ *milvuspb.HasCollectionRequest) *milvuspb.BoolResponse { 346 return &milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: created} 347 }, nil) 348 s.mock.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 349 s.mock.EXPECT().Flush(mock.Anything, mock.Anything).Return(&milvuspb.FlushResponse{ 350 Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, 351 CollSegIDs: map[string]*schemapb.LongArray{}, 352 }, nil) 353 s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&milvuspb.DescribeIndexResponse{ 354 Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, 355 IndexDescriptions: []*milvuspb.IndexDescription{ 356 {FieldName: "embedding", State: commonpb.IndexState_Finished}, 357 }, 358 }, nil) 359 s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 360 s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Return(&milvuspb.GetLoadingProgressResponse{ 361 Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, 362 Progress: 100, 363 }, nil) 364 s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ 365 Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, 366 Schema: &schemapb.CollectionSchema{ 367 Fields: []*schemapb.FieldSchema{ 368 {Name: "my_pk", DataType: schemapb.DataType_VarChar}, 369 {Name: "embedding", DataType: schemapb.DataType_FloatVector}, 370 }, 371 }, 372 }, nil) 373 374 err := c.NewCollection(ctx, testCollectionName, testVectorDim, WithPKFieldName("my_pk"), WithPKFieldType(entity.FieldTypeVarChar), WithVectorFieldName("embedding"), WithConsistencyLevel(entity.ClEventually)) 375 s.NoError(err) 376 }) 377 } 378 379 func (s *CollectionSuite) TestRenameCollection() { 380 c := s.client 381 ctx, cancel := context.WithCancel(context.Background()) 382 defer cancel() 383 s.Run("normal_run", func() { 384 defer s.resetMock() 385 386 newCollName := fmt.Sprintf("new_%s", randStr(6)) 387 388 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 389 s.mock.EXPECT().RenameCollection(mock.Anything, &milvuspb.RenameCollectionRequest{OldName: testCollectionName, NewName: newCollName}).Return(&commonpb.Status{}, nil) 390 391 err := c.RenameCollection(ctx, testCollectionName, newCollName) 392 s.NoError(err) 393 }) 394 395 s.Run("coll_not_exist", func() { 396 defer s.resetMock() 397 398 newCollName := fmt.Sprintf("new_%s", randStr(6)) 399 400 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil) 401 402 err := c.RenameCollection(ctx, testCollectionName, newCollName) 403 s.Error(err) 404 }) 405 406 s.Run("rename_failed", func() { 407 defer s.resetMock() 408 409 newCollName := fmt.Sprintf("new_%s", randStr(6)) 410 411 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 412 s.mock.EXPECT().RenameCollection(mock.Anything, &milvuspb.RenameCollectionRequest{OldName: testCollectionName, NewName: newCollName}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mocked failure"}, nil) 413 414 err := c.RenameCollection(ctx, testCollectionName, newCollName) 415 s.Error(err) 416 }) 417 418 s.Run("rename_error", func() { 419 defer s.resetMock() 420 421 newCollName := fmt.Sprintf("new_%s", randStr(6)) 422 423 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 424 s.mock.EXPECT().RenameCollection(mock.Anything, &milvuspb.RenameCollectionRequest{OldName: testCollectionName, NewName: newCollName}).Return(nil, errors.New("mocked error")) 425 426 err := c.RenameCollection(ctx, testCollectionName, newCollName) 427 s.Error(err) 428 }) 429 } 430 431 func (s *CollectionSuite) TestAlterCollection() { 432 c := s.client 433 ctx, cancel := context.WithCancel(context.Background()) 434 defer cancel() 435 436 s.Run("normal_run", func() { 437 defer s.resetMock() 438 439 s.setupHasCollection(testCollectionName) 440 s.mock.EXPECT().AlterCollection(mock.Anything, mock.AnythingOfType("*milvuspb.AlterCollectionRequest")). 441 Return(&commonpb.Status{}, nil) 442 443 err := c.AlterCollection(ctx, testCollectionName, entity.CollectionTTL(100000)) 444 s.NoError(err) 445 }) 446 447 s.Run("collection_not_exist", func() { 448 defer s.resetMock() 449 450 s.mock.EXPECT().HasCollection(mock.Anything, mock.AnythingOfType("*milvuspb.HasCollectionRequest")). 451 Return(&milvuspb.BoolResponse{ 452 Status: &commonpb.Status{}, 453 Value: false, 454 }, nil) 455 456 err := c.AlterCollection(ctx, testCollectionName, entity.CollectionTTL(100000)) 457 s.Error(err) 458 }) 459 460 s.Run("no_attributes", func() { 461 defer s.resetMock() 462 463 s.setupHasCollection(testCollectionName) 464 err := c.AlterCollection(ctx, testCollectionName) 465 s.Error(err) 466 }) 467 468 s.Run("request_fails", func() { 469 defer s.resetMock() 470 471 s.setupHasCollection(testCollectionName) 472 s.mock.EXPECT().AlterCollection(mock.Anything, mock.AnythingOfType("*milvuspb.AlterCollectionRequest")). 473 Return(nil, errors.New("mocked")) 474 475 err := c.AlterCollection(ctx, testCollectionName, entity.CollectionTTL(100000)) 476 s.Error(err) 477 }) 478 479 s.Run("server_return_error", func() { 480 defer s.resetMock() 481 482 s.setupHasCollection(testCollectionName) 483 s.mock.EXPECT().AlterCollection(mock.Anything, mock.AnythingOfType("*milvuspb.AlterCollectionRequest")). 484 Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil) 485 486 err := c.AlterCollection(ctx, testCollectionName, entity.CollectionTTL(100000)) 487 s.Error(err) 488 }) 489 490 s.Run("service_not_ready", func() { 491 c := &GrpcClient{} 492 err := c.AlterCollection(ctx, testCollectionName, entity.CollectionTTL(100000)) 493 s.ErrorIs(err, ErrClientNotReady) 494 }) 495 } 496 497 func (s *CollectionSuite) TestLoadCollection() { 498 ctx, cancel := context.WithCancel(context.Background()) 499 defer cancel() 500 501 c := s.client 502 503 s.Run("normal_run_async", func() { 504 defer s.resetMock() 505 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 506 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 507 508 s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 509 510 err := c.LoadCollection(ctx, testCollectionName, true, WithLoadCollectionMsgBase(&commonpb.MsgBase{})) 511 s.NoError(err) 512 }) 513 514 s.Run("normal_run_sync", func() { 515 defer s.resetMock() 516 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 517 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 518 519 s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 520 s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything). 521 Return(&milvuspb.GetLoadingProgressResponse{ 522 Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, 523 Progress: 100, 524 }, nil) 525 526 err := c.LoadCollection(ctx, testCollectionName, true) 527 s.NoError(err) 528 }) 529 530 s.Run("load_default_replica", func() { 531 defer s.resetMock() 532 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 533 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 534 535 s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadCollectionRequest) { 536 s.Equal(testDefaultReplicaNumber, req.GetReplicaNumber()) 537 s.Equal(testCollectionName, req.GetCollectionName()) 538 }).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 539 540 err := c.LoadCollection(ctx, testCollectionName, true) 541 s.NoError(err) 542 }) 543 544 s.Run("load_multiple_replica", func() { 545 defer s.resetMock() 546 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 547 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 548 549 s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadCollectionRequest) { 550 s.Equal(testMultiReplicaNumber, req.GetReplicaNumber()) 551 s.Equal(testCollectionName, req.GetCollectionName()) 552 }).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 553 554 err := c.LoadCollection(ctx, testCollectionName, true, WithReplicaNumber(testMultiReplicaNumber)) 555 s.NoError(err) 556 }) 557 558 s.Run("has_collection_failure", func() { 559 s.Run("return_false", func() { 560 defer s.resetMock() 561 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 562 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil) 563 564 err := c.LoadCollection(ctx, testCollectionName, true) 565 s.Error(err) 566 }) 567 568 s.Run("return_error", func() { 569 defer s.resetMock() 570 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 571 Return(nil, errors.New("mock error")) 572 573 err := c.LoadCollection(ctx, testCollectionName, true) 574 s.Error(err) 575 }) 576 }) 577 578 s.Run("load_collection_failure", func() { 579 s.Run("failure_status", func() { 580 defer s.resetMock() 581 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 582 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 583 584 s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything). 585 Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil) 586 587 err := c.LoadCollection(ctx, testCollectionName, true) 588 s.Error(err) 589 }) 590 591 s.Run("return_error", func() { 592 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 593 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 594 595 s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything). 596 Return(nil, errors.New("mock error")) 597 598 err := c.LoadCollection(ctx, testCollectionName, true) 599 s.Error(err) 600 }) 601 }) 602 603 s.Run("get_loading_progress_failure", func() { 604 defer s.resetMock() 605 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 606 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 607 608 s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 609 s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything). 610 Return(nil, errors.New("mock error")) 611 612 err := c.LoadCollection(ctx, testCollectionName, false) 613 s.Error(err) 614 }) 615 616 s.Run("service_not_ready", func() { 617 c := &GrpcClient{} 618 err := c.LoadCollection(ctx, testCollectionName, false) 619 s.ErrorIs(err, ErrClientNotReady) 620 }) 621 } 622 623 func TestCollectionSuite(t *testing.T) { 624 suite.Run(t, new(CollectionSuite)) 625 } 626 627 // default HasCollection injection, returns true only when collection name is `testCollectionName` 628 var hasCollectionDefault = func(_ context.Context, raw proto.Message) (proto.Message, error) { 629 req, ok := raw.(*milvuspb.HasCollectionRequest) 630 resp := &milvuspb.BoolResponse{} 631 if !ok { 632 s, err := BadRequestStatus() 633 resp.Status = s 634 return s, err 635 } 636 resp.Value = req.GetCollectionName() == testCollectionName 637 s, err := SuccessStatus() 638 resp.Status = s 639 return resp, err 640 } 641 642 func TestGrpcClientDropCollection(t *testing.T) { 643 ctx := context.Background() 644 c := testClient(ctx, t) 645 646 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 647 mockServer.SetInjection(MDropCollection, func(_ context.Context, raw proto.Message) (proto.Message, error) { 648 req, ok := (raw).(*milvuspb.DropCollectionRequest) 649 if !ok { 650 return BadRequestStatus() 651 } 652 if req.GetCollectionName() != testCollectionName { // in mockServer.server, assume testCollection exists only 653 return BadRequestStatus() 654 } 655 return SuccessStatus() 656 }) 657 658 t.Run("Test Normal drop", func(t *testing.T) { 659 assert.Nil(t, c.DropCollection(ctx, testCollectionName, WithDropCollectionMsgBase(&commonpb.MsgBase{}))) 660 }) 661 662 t.Run("Test drop non-existing collection", func(t *testing.T) { 663 assert.NotNil(t, c.DropCollection(ctx, "AAAAAAAAAANonExists")) 664 }) 665 } 666 667 func TestReleaseCollection(t *testing.T) { 668 ctx := context.Background() 669 670 c := testClient(ctx, t) 671 672 mockServer.SetInjection(MReleaseCollection, func(_ context.Context, raw proto.Message) (proto.Message, error) { 673 req, ok := raw.(*milvuspb.ReleaseCollectionRequest) 674 if !ok { 675 return BadRequestStatus() 676 } 677 assert.Equal(t, testCollectionName, req.GetCollectionName()) 678 return SuccessStatus() 679 }) 680 681 c.ReleaseCollection(ctx, testCollectionName, WithReleaseCollectionMsgBase(&commonpb.MsgBase{})) 682 } 683 684 func TestGrpcClientHasCollection(t *testing.T) { 685 ctx := context.Background() 686 687 c := testClient(ctx, t) 688 689 mockServer.SetInjection(MHasCollection, func(_ context.Context, raw proto.Message) (proto.Message, error) { 690 req, ok := raw.(*milvuspb.HasCollectionRequest) 691 resp := &milvuspb.BoolResponse{} 692 if !ok { 693 s, err := BadRequestStatus() 694 assert.Fail(t, err.Error()) 695 resp.Status = s 696 return resp, err 697 } 698 assert.Equal(t, req.CollectionName, testCollectionName) 699 700 s, err := SuccessStatus() 701 resp.Status, resp.Value = s, true 702 return resp, err 703 }) 704 705 has, err := c.HasCollection(ctx, testCollectionName) 706 assert.Nil(t, err) 707 assert.True(t, has) 708 } 709 710 // return injection asserts collection name matchs 711 // partition name request in partitionNames if flag is true 712 func hasCollectionInjection(t *testing.T, mustIn bool, collNames ...string) func(context.Context, proto.Message) (proto.Message, error) { 713 return func(_ context.Context, raw proto.Message) (proto.Message, error) { 714 req, ok := raw.(*milvuspb.HasCollectionRequest) 715 resp := &milvuspb.BoolResponse{} 716 if !ok { 717 s, err := BadRequestStatus() 718 resp.Status = s 719 return resp, err 720 } 721 if mustIn { 722 resp.Value = assert.Contains(t, collNames, req.GetCollectionName()) 723 } else { 724 for _, pn := range collNames { 725 if pn == req.GetCollectionName() { 726 resp.Value = true 727 } 728 } 729 } 730 s, err := SuccessStatus() 731 resp.Status = s 732 return resp, err 733 } 734 } 735 736 func describeCollectionInjection(t *testing.T, collID int64, collName string, sch *entity.Schema) func(_ context.Context, raw proto.Message) (proto.Message, error) { 737 return func(_ context.Context, raw proto.Message) (proto.Message, error) { 738 req, ok := raw.(*milvuspb.DescribeCollectionRequest) 739 resp := &milvuspb.DescribeCollectionResponse{} 740 if !ok { 741 s, err := BadRequestStatus() 742 resp.Status = s 743 return resp, err 744 } 745 746 assert.Equal(t, collName, req.GetCollectionName()) 747 748 sch := sch 749 resp.Schema = sch.ProtoMessage() 750 resp.CollectionID = collID 751 752 s, err := SuccessStatus() 753 resp.Status = s 754 755 return resp, err 756 } 757 } 758 759 func TestGrpcClientDescribeCollection(t *testing.T) { 760 ctx := context.Background() 761 762 c := testClient(ctx, t) 763 764 collectionID := rand.Int63() 765 766 mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, collectionID, testCollectionName, defaultSchema())) 767 768 collection, err := c.DescribeCollection(ctx, testCollectionName) 769 assert.Nil(t, err) 770 if assert.NotNil(t, collection) { 771 assert.Equal(t, collectionID, collection.ID) 772 } 773 } 774 775 func TestGrpcClientGetCollectionStatistics(t *testing.T) { 776 ctx := context.Background() 777 778 c := testClient(ctx, t) 779 780 stat := make(map[string]string) 781 stat["row_count"] = "0" 782 783 mockServer.SetInjection(MGetCollectionStatistics, func(_ context.Context, raw proto.Message) (proto.Message, error) { 784 req, ok := raw.(*milvuspb.GetCollectionStatisticsRequest) 785 resp := &milvuspb.GetCollectionStatisticsResponse{} 786 if !ok { 787 s, err := BadRequestStatus() 788 resp.Status = s 789 return resp, err 790 } 791 assert.Equal(t, testCollectionName, req.GetCollectionName()) 792 s, err := SuccessStatus() 793 resp.Status, resp.Stats = s, entity.MapKvPairs(stat) 794 return resp, err 795 }) 796 797 rStat, err := c.GetCollectionStatistics(ctx, testCollectionName) 798 assert.Nil(t, err) 799 if assert.NotNil(t, rStat) { 800 for k, v := range stat { 801 rv, has := rStat[k] 802 assert.True(t, has) 803 assert.Equal(t, v, rv) 804 } 805 } 806 } 807 808 func TestGrpcClientGetReplicas(t *testing.T) { 809 ctx := context.Background() 810 c := testClient(ctx, t) 811 812 replicaID := rand.Int63() 813 nodeIds := []int64{1, 2, 3, 4} 814 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 815 defer mockServer.DelInjection(MHasCollection) 816 817 mockServer.SetInjection(MShowCollections, func(_ context.Context, raw proto.Message) (proto.Message, error) { 818 s, err := SuccessStatus() 819 resp := &milvuspb.ShowCollectionsResponse{ 820 Status: s, 821 CollectionIds: []int64{testCollectionID}, 822 CollectionNames: []string{testCollectionName}, 823 InMemoryPercentages: []int64{100}, 824 } 825 return resp, err 826 }) 827 defer mockServer.DelInjection(MShowCollections) 828 829 mockServer.SetInjection(MGetReplicas, func(ctx context.Context, raw proto.Message) (proto.Message, error) { 830 req, ok := raw.(*milvuspb.GetReplicasRequest) 831 resp := &milvuspb.GetReplicasResponse{} 832 if !ok { 833 s, err := BadRequestStatus() 834 resp.Status = s 835 return resp, err 836 } 837 838 assert.Equal(t, testCollectionID, req.CollectionID) 839 840 s, err := SuccessStatus() 841 resp.Status = s 842 resp.Replicas = []*milvuspb.ReplicaInfo{{ 843 ReplicaID: replicaID, 844 ShardReplicas: []*milvuspb.ShardReplica{ 845 { 846 LeaderID: 1, 847 DmChannelName: "DML_channel_v1", 848 }, 849 { 850 LeaderID: 2, 851 LeaderAddr: "DML_channel_v2", 852 }, 853 }, 854 NodeIds: nodeIds, 855 }} 856 return resp, err 857 }) 858 859 t.Run("get replicas normal", func(t *testing.T) { 860 groups, err := c.GetReplicas(ctx, testCollectionName) 861 assert.Nil(t, err) 862 assert.NotNil(t, groups) 863 assert.Equal(t, 1, len(groups)) 864 865 assert.Equal(t, replicaID, groups[0].ReplicaID) 866 assert.Equal(t, nodeIds, groups[0].NodeIDs) 867 assert.Equal(t, 2, len(groups[0].ShardReplicas)) 868 }) 869 870 t.Run("get replicas invalid name", func(t *testing.T) { 871 _, err := c.GetReplicas(ctx, "invalid name") 872 assert.Error(t, err) 873 }) 874 875 t.Run("get replicas grpc error", func(t *testing.T) { 876 mockServer.SetInjection(MGetReplicas, func(ctx context.Context, raw proto.Message) (proto.Message, error) { 877 return &milvuspb.GetReplicasResponse{}, errors.New("mockServer.d grpc error") 878 }) 879 _, err := c.GetReplicas(ctx, testCollectionName) 880 assert.Error(t, err) 881 }) 882 883 t.Run("get replicas server error", func(t *testing.T) { 884 mockServer.SetInjection(MGetReplicas, func(ctx context.Context, raw proto.Message) (proto.Message, error) { 885 return &milvuspb.GetReplicasResponse{ 886 Status: &commonpb.Status{ 887 ErrorCode: commonpb.ErrorCode_UnexpectedError, 888 Reason: "Service is not healthy", 889 }, 890 Replicas: nil, 891 }, nil 892 }) 893 _, err := c.GetReplicas(ctx, testCollectionName) 894 assert.Error(t, err) 895 }) 896 897 mockServer.DelInjection(MGetReplicas) 898 } 899 900 func TestGrpcClientGetLoadingProgress(t *testing.T) { 901 ctx := context.Background() 902 c := testClient(ctx, t) 903 904 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 905 906 mockServer.SetInjection(MGetLoadingProgress, func(_ context.Context, raw proto.Message) (proto.Message, error) { 907 req, ok := raw.(*milvuspb.GetLoadingProgressRequest) 908 if !ok { 909 return BadRequestStatus() 910 } 911 resp := &milvuspb.GetLoadingProgressResponse{} 912 if !ok { 913 s, err := BadRequestStatus() 914 resp.Status = s 915 return resp, err 916 } 917 assert.Equal(t, testCollectionName, req.GetCollectionName()) 918 s, err := SuccessStatus() 919 resp.Status, resp.Progress = s, 100 920 return resp, err 921 }) 922 923 progress, err := c.GetLoadingProgress(ctx, testCollectionName, []string{}) 924 assert.NoError(t, err) 925 assert.Equal(t, int64(100), progress) 926 } 927 928 func TestGrpcClientGetLoadState(t *testing.T) { 929 ctx := context.Background() 930 c := testClient(ctx, t) 931 932 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 933 934 mockServer.SetInjection(MGetLoadState, func(_ context.Context, raw proto.Message) (proto.Message, error) { 935 req, ok := raw.(*milvuspb.GetLoadStateRequest) 936 if !ok { 937 return BadRequestStatus() 938 } 939 resp := &milvuspb.GetLoadStateResponse{} 940 if !ok { 941 s, err := BadRequestStatus() 942 resp.Status = s 943 return resp, err 944 } 945 assert.Equal(t, testCollectionName, req.GetCollectionName()) 946 s, err := SuccessStatus() 947 resp.Status, resp.State = s, commonpb.LoadState_LoadStateLoaded 948 return resp, err 949 }) 950 951 state, err := c.GetLoadState(ctx, testCollectionName, []string{}) 952 assert.NoError(t, err) 953 assert.Equal(t, entity.LoadStateLoaded, state) 954 }