github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/partition_test.go (about) 1 package client 2 3 import ( 4 "context" 5 "fmt" 6 "math/rand" 7 "testing" 8 9 "github.com/cockroachdb/errors" 10 11 "github.com/golang/protobuf/proto" 12 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" 13 "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" 14 "github.com/milvus-io/milvus-sdk-go/v2/entity" 15 "github.com/stretchr/testify/assert" 16 "github.com/stretchr/testify/mock" 17 "github.com/stretchr/testify/suite" 18 ) 19 20 // return injection asserts collection name matchs 21 // partition name request in partitionNames if flag is true 22 func hasPartitionInjection(t *testing.T, collName string, mustIn bool, partitionNames ...string) func(context.Context, proto.Message) (proto.Message, error) { 23 return func(_ context.Context, raw proto.Message) (proto.Message, error) { 24 req, ok := raw.(*milvuspb.HasPartitionRequest) 25 resp := &milvuspb.BoolResponse{} 26 if !ok { 27 s, err := BadRequestStatus() 28 resp.Status = s 29 return resp, err 30 } 31 assert.Equal(t, collName, req.GetCollectionName()) 32 if mustIn { 33 resp.Value = assert.Contains(t, partitionNames, req.GetPartitionName()) 34 } else { 35 for _, pn := range partitionNames { 36 if pn == req.GetPartitionName() { 37 resp.Value = true 38 } 39 } 40 } 41 s, err := SuccessStatus() 42 resp.Status = s 43 return resp, err 44 } 45 } 46 47 func TestGrpcClientCreatePartition(t *testing.T) { 48 49 ctx := context.Background() 50 c := testClient(ctx, t) 51 52 partitionName := fmt.Sprintf("_part_%d", rand.Int()) 53 54 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 55 mockServer.SetInjection(MHasPartition, func(_ context.Context, raw proto.Message) (proto.Message, error) { 56 req, ok := raw.(*milvuspb.HasPartitionRequest) 57 resp := &milvuspb.BoolResponse{} 58 if !ok { 59 s, err := BadRequestStatus() 60 resp.Status = s 61 return resp, err 62 } 63 assert.Equal(t, testCollectionName, req.GetCollectionName()) 64 assert.Equal(t, partitionName, req.GetPartitionName()) 65 resp.Value = false 66 s, err := SuccessStatus() 67 resp.Status = s 68 return resp, err 69 }) 70 71 assert.Nil(t, c.CreatePartition(ctx, testCollectionName, partitionName, WithCreatePartitionMsgBase(&commonpb.MsgBase{}))) 72 } 73 74 func TestGrpcClientDropPartition(t *testing.T) { 75 partitionName := fmt.Sprintf("_part_%d", rand.Int()) 76 ctx := context.Background() 77 c := testClient(ctx, t) 78 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 79 mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, true, partitionName)) // injection has assertion of collName & parition name 80 assert.Nil(t, c.DropPartition(ctx, testCollectionName, partitionName, WithDropPartitionMsgBase(&commonpb.MsgBase{}))) 81 } 82 83 func TestGrpcClientHasPartition(t *testing.T) { 84 partitionName := fmt.Sprintf("_part_%d", rand.Int()) 85 ctx := context.Background() 86 c := testClient(ctx, t) 87 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 88 mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, false, partitionName)) // injection has assertion of collName & parition name 89 90 r, err := c.HasPartition(ctx, testCollectionName, "_default_part") 91 assert.Nil(t, err) 92 assert.False(t, r) 93 94 r, err = c.HasPartition(ctx, testCollectionName, partitionName) 95 assert.Nil(t, err) 96 assert.True(t, r) 97 } 98 99 // default partition interception for ShowPartitions, generates testCollection related paritition data 100 func getPartitionsInterception(t *testing.T, collName string, partitions ...*entity.Partition) func(ctx context.Context, raw proto.Message) (proto.Message, error) { 101 return func(ctx context.Context, raw proto.Message) (proto.Message, error) { 102 req, ok := raw.(*milvuspb.ShowPartitionsRequest) 103 resp := &milvuspb.ShowPartitionsResponse{} 104 if !ok { 105 s, err := BadRequestStatus() 106 resp.Status = s 107 return resp, err 108 } 109 assert.Equal(t, collName, req.GetCollectionName()) 110 resp.PartitionIDs = make([]int64, 0, len(partitions)) 111 resp.PartitionNames = make([]string, 0, len(partitions)) 112 for _, part := range partitions { 113 resp.PartitionIDs = append(resp.PartitionIDs, part.ID) 114 resp.PartitionNames = append(resp.PartitionNames, part.Name) 115 resp.InMemoryPercentages = append(resp.InMemoryPercentages, 100) 116 } 117 s, err := SuccessStatus() 118 resp.Status = s 119 return resp, err 120 } 121 } 122 123 func TestGrpcClientShowPartitions(t *testing.T) { 124 125 ctx := context.Background() 126 c := testClient(ctx, t) 127 128 type testCase struct { 129 collName string 130 partitions []*entity.Partition 131 shouldSuccess bool 132 } 133 cases := []testCase{ 134 { 135 collName: testCollectionName, 136 partitions: []*entity.Partition{ 137 { 138 ID: 1, 139 Name: "_part1", 140 }, 141 { 142 ID: 2, 143 Name: "_part2", 144 }, 145 { 146 ID: 3, 147 Name: "_part3", 148 }, 149 }, 150 shouldSuccess: true, 151 }, 152 } 153 for _, tc := range cases { 154 mockServer.SetInjection(MShowPartitions, getPartitionsInterception(t, tc.collName, tc.partitions...)) 155 r, err := c.ShowPartitions(ctx, tc.collName) 156 if tc.shouldSuccess { 157 assert.Nil(t, err) 158 assert.NotNil(t, r) 159 if assert.Equal(t, len(tc.partitions), len(r)) { 160 for idx, part := range tc.partitions { 161 assert.Equal(t, part.ID, r[idx].ID) 162 assert.Equal(t, part.Name, r[idx].Name) 163 } 164 } 165 } else { 166 assert.NotNil(t, err) 167 } 168 } 169 } 170 171 func TestGrpcClientReleasePartitions(t *testing.T) { 172 ctx := context.Background() 173 174 c := testClient(ctx, t) 175 176 parts := []string{"_part1", "_part2"} 177 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 178 mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, true, "_part1", "_part2", "_part3", "_part4")) 179 mockServer.SetInjection(MReleasePartitions, func(_ context.Context, raw proto.Message) (proto.Message, error) { 180 req, ok := raw.(*milvuspb.ReleasePartitionsRequest) 181 if !ok { 182 return BadRequestStatus() 183 } 184 assert.Equal(t, testCollectionName, req.GetCollectionName()) 185 assert.ElementsMatch(t, parts, req.GetPartitionNames()) 186 187 return SuccessStatus() 188 }) 189 defer mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, false, "testPart")) 190 191 assert.Nil(t, c.ReleasePartitions(ctx, testCollectionName, parts, WithReleasePartitionsMsgBase(&commonpb.MsgBase{}))) 192 } 193 194 func TestGrpcShowPartitions(t *testing.T) { 195 ctx := context.Background() 196 c := testClient(ctx, t) 197 198 partitions := []*entity.Partition{ 199 { 200 ID: 1, 201 Name: "_part1", 202 }, 203 { 204 ID: 2, 205 Name: "_part2", 206 }, 207 { 208 ID: 3, 209 Name: "_part3", 210 }, 211 } 212 213 t.Run("normal show partitions", func(t *testing.T) { 214 mockServer.SetInjection(MShowPartitions, getPartitionsInterception(t, testCollectionName, partitions...)) 215 parts, err := c.ShowPartitions(ctx, testCollectionName) 216 assert.NoError(t, err) 217 assert.NotNil(t, parts) 218 }) 219 220 t.Run("bad response", func(t *testing.T) { 221 mockServer.SetInjection(MShowPartitions, func(ctx context.Context, raw proto.Message) (proto.Message, error) { 222 resp := &milvuspb.ShowPartitionsResponse{} 223 resp.PartitionIDs = make([]int64, 0, len(partitions)) 224 for _, part := range partitions { 225 resp.PartitionIDs = append(resp.PartitionIDs, part.ID) 226 } 227 s, err := SuccessStatus() 228 resp.Status = s 229 return resp, err 230 }) 231 _, err := c.ShowPartitions(ctx, testCollectionName) 232 assert.Error(t, err) 233 }) 234 235 t.Run("Service error", func(t *testing.T) { 236 mockServer.SetInjection(MShowPartitions, func(_ context.Context, raw proto.Message) (proto.Message, error) { 237 return &milvuspb.ShowPartitionsResponse{}, errors.New("always fail") 238 }) 239 defer mockServer.DelInjection(MShowPartitions) 240 241 _, err := c.ShowPartitions(ctx, testCollectionName) 242 assert.Error(t, err) 243 244 mockServer.SetInjection(MShowPartitions, func(_ context.Context, raw proto.Message) (proto.Message, error) { 245 return &milvuspb.ShowPartitionsResponse{ 246 Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, 247 }, nil 248 }) 249 _, err = c.ShowPartitions(ctx, testCollectionName) 250 assert.Error(t, err) 251 }) 252 } 253 254 type PartitionSuite struct { 255 MockSuiteBase 256 } 257 258 func (s *PartitionSuite) TestLoadPartitions() { 259 c := s.client 260 ctx, cancel := context.WithCancel(context.Background()) 261 defer cancel() 262 263 partNames := []string{"part_1", "part_2"} 264 mPartNames := map[string]struct{}{"part_1": {}, "part_2": {}} 265 266 s.Run("normal_run_async", func() { 267 defer s.resetMock() 268 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 269 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 270 s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.HasPartitionRequest) { 271 s.Equal(testCollectionName, req.GetCollectionName()) 272 _, ok := mPartNames[req.GetPartitionName()] 273 s.True(ok) 274 }).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 275 s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadPartitionsRequest) { 276 s.Equal(testCollectionName, req.GetCollectionName()) 277 s.ElementsMatch(partNames, req.GetPartitionNames()) 278 }).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 279 280 err := c.LoadPartitions(ctx, testCollectionName, partNames, true, WithLoadPartitionsMsgBase(&commonpb.MsgBase{})) 281 s.NoError(err) 282 }) 283 284 s.Run("normal_run_sync", func() { 285 defer s.resetMock() 286 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 287 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 288 s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.HasPartitionRequest) { 289 s.Equal(testCollectionName, req.GetCollectionName()) 290 _, ok := mPartNames[req.GetPartitionName()] 291 s.True(ok) 292 }).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 293 s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadPartitionsRequest) { 294 s.Equal(testCollectionName, req.GetCollectionName()) 295 s.ElementsMatch(partNames, req.GetPartitionNames()) 296 }).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 297 s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything). 298 Return(&milvuspb.GetLoadingProgressResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, Progress: 100}, nil) 299 300 err := c.LoadPartitions(ctx, testCollectionName, partNames, false) 301 s.NoError(err) 302 }) 303 304 s.Run("has_collection_failure", func() { 305 s.Run("return_false", func() { 306 defer s.resetMock() 307 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 308 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil) 309 310 err := c.LoadPartitions(ctx, testCollectionName, partNames, false) 311 s.Error(err) 312 }) 313 314 s.Run("return_error", func() { 315 defer s.resetMock() 316 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 317 Return(nil, errors.New("mock error")) 318 319 err := c.LoadPartitions(ctx, testCollectionName, partNames, false) 320 s.Error(err) 321 }) 322 }) 323 324 s.Run("has_partition_failure", func() { 325 s.Run("return_false", func() { 326 defer s.resetMock() 327 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 328 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 329 s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything). 330 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil) 331 332 err := c.LoadPartitions(ctx, testCollectionName, partNames, false) 333 s.Error(err) 334 }) 335 336 s.Run("return_error", func() { 337 defer s.resetMock() 338 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 339 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 340 s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything). 341 Return(nil, errors.New("mock")) 342 343 err := c.LoadPartitions(ctx, testCollectionName, partNames, false) 344 s.Error(err) 345 }) 346 }) 347 348 s.Run("load_partitions_failure", func() { 349 s.Run("fail_status_code", func() { 350 defer s.resetMock() 351 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 352 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 353 s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.HasPartitionRequest) { 354 s.Equal(testCollectionName, req.GetCollectionName()) 355 _, ok := mPartNames[req.GetPartitionName()] 356 s.True(ok) 357 }).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 358 s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadPartitionsRequest) { 359 s.Equal(testCollectionName, req.GetCollectionName()) 360 s.ElementsMatch(partNames, req.GetPartitionNames()) 361 }).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil) 362 363 err := c.LoadPartitions(ctx, testCollectionName, partNames, true) 364 s.Error(err) 365 }) 366 367 s.Run("return_error", func() { 368 defer s.resetMock() 369 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 370 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 371 s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.HasPartitionRequest) { 372 s.Equal(testCollectionName, req.GetCollectionName()) 373 _, ok := mPartNames[req.GetPartitionName()] 374 s.True(ok) 375 }).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 376 s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadPartitionsRequest) { 377 s.Equal(testCollectionName, req.GetCollectionName()) 378 s.ElementsMatch(partNames, req.GetPartitionNames()) 379 }).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil) 380 381 err := c.LoadPartitions(ctx, testCollectionName, partNames, true) 382 s.Error(err) 383 }) 384 }) 385 386 s.Run("get_loading_progress_failure", func() { 387 defer s.resetMock() 388 s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}). 389 Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 390 s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.HasPartitionRequest) { 391 s.Equal(testCollectionName, req.GetCollectionName()) 392 _, ok := mPartNames[req.GetPartitionName()] 393 s.True(ok) 394 }).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil) 395 s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadPartitionsRequest) { 396 s.Equal(testCollectionName, req.GetCollectionName()) 397 s.ElementsMatch(partNames, req.GetPartitionNames()) 398 }).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) 399 s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything). 400 Return(nil, errors.New("mock error")) 401 402 err := c.LoadPartitions(ctx, testCollectionName, partNames, false) 403 s.Error(err) 404 }) 405 406 s.Run("service_not_ready", func() { 407 c := &GrpcClient{} 408 err := c.LoadPartitions(ctx, testCollectionName, partNames, false) 409 s.ErrorIs(err, ErrClientNotReady) 410 }) 411 } 412 413 func TestPartitionSuite(t *testing.T) { 414 suite.Run(t, new(PartitionSuite)) 415 }