github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/partition_test.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math/rand"
     7  	"testing"
     8  
     9  	"github.com/cockroachdb/errors"
    10  
    11  	"github.com/golang/protobuf/proto"
    12  	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    13  	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    14  	"github.com/milvus-io/milvus-sdk-go/v2/entity"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/mock"
    17  	"github.com/stretchr/testify/suite"
    18  )
    19  
    20  // return injection asserts collection name matchs
    21  // partition name request in partitionNames if flag is true
    22  func hasPartitionInjection(t *testing.T, collName string, mustIn bool, partitionNames ...string) func(context.Context, proto.Message) (proto.Message, error) {
    23  	return func(_ context.Context, raw proto.Message) (proto.Message, error) {
    24  		req, ok := raw.(*milvuspb.HasPartitionRequest)
    25  		resp := &milvuspb.BoolResponse{}
    26  		if !ok {
    27  			s, err := BadRequestStatus()
    28  			resp.Status = s
    29  			return resp, err
    30  		}
    31  		assert.Equal(t, collName, req.GetCollectionName())
    32  		if mustIn {
    33  			resp.Value = assert.Contains(t, partitionNames, req.GetPartitionName())
    34  		} else {
    35  			for _, pn := range partitionNames {
    36  				if pn == req.GetPartitionName() {
    37  					resp.Value = true
    38  				}
    39  			}
    40  		}
    41  		s, err := SuccessStatus()
    42  		resp.Status = s
    43  		return resp, err
    44  	}
    45  }
    46  
    47  func TestGrpcClientCreatePartition(t *testing.T) {
    48  
    49  	ctx := context.Background()
    50  	c := testClient(ctx, t)
    51  
    52  	partitionName := fmt.Sprintf("_part_%d", rand.Int())
    53  
    54  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
    55  	mockServer.SetInjection(MHasPartition, func(_ context.Context, raw proto.Message) (proto.Message, error) {
    56  		req, ok := raw.(*milvuspb.HasPartitionRequest)
    57  		resp := &milvuspb.BoolResponse{}
    58  		if !ok {
    59  			s, err := BadRequestStatus()
    60  			resp.Status = s
    61  			return resp, err
    62  		}
    63  		assert.Equal(t, testCollectionName, req.GetCollectionName())
    64  		assert.Equal(t, partitionName, req.GetPartitionName())
    65  		resp.Value = false
    66  		s, err := SuccessStatus()
    67  		resp.Status = s
    68  		return resp, err
    69  	})
    70  
    71  	assert.Nil(t, c.CreatePartition(ctx, testCollectionName, partitionName, WithCreatePartitionMsgBase(&commonpb.MsgBase{})))
    72  }
    73  
    74  func TestGrpcClientDropPartition(t *testing.T) {
    75  	partitionName := fmt.Sprintf("_part_%d", rand.Int())
    76  	ctx := context.Background()
    77  	c := testClient(ctx, t)
    78  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
    79  	mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, true, partitionName)) // injection has assertion of collName & parition name
    80  	assert.Nil(t, c.DropPartition(ctx, testCollectionName, partitionName, WithDropPartitionMsgBase(&commonpb.MsgBase{})))
    81  }
    82  
    83  func TestGrpcClientHasPartition(t *testing.T) {
    84  	partitionName := fmt.Sprintf("_part_%d", rand.Int())
    85  	ctx := context.Background()
    86  	c := testClient(ctx, t)
    87  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
    88  	mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, false, partitionName)) // injection has assertion of collName & parition name
    89  
    90  	r, err := c.HasPartition(ctx, testCollectionName, "_default_part")
    91  	assert.Nil(t, err)
    92  	assert.False(t, r)
    93  
    94  	r, err = c.HasPartition(ctx, testCollectionName, partitionName)
    95  	assert.Nil(t, err)
    96  	assert.True(t, r)
    97  }
    98  
    99  // default partition interception for ShowPartitions, generates testCollection related paritition data
   100  func getPartitionsInterception(t *testing.T, collName string, partitions ...*entity.Partition) func(ctx context.Context, raw proto.Message) (proto.Message, error) {
   101  	return func(ctx context.Context, raw proto.Message) (proto.Message, error) {
   102  		req, ok := raw.(*milvuspb.ShowPartitionsRequest)
   103  		resp := &milvuspb.ShowPartitionsResponse{}
   104  		if !ok {
   105  			s, err := BadRequestStatus()
   106  			resp.Status = s
   107  			return resp, err
   108  		}
   109  		assert.Equal(t, collName, req.GetCollectionName())
   110  		resp.PartitionIDs = make([]int64, 0, len(partitions))
   111  		resp.PartitionNames = make([]string, 0, len(partitions))
   112  		for _, part := range partitions {
   113  			resp.PartitionIDs = append(resp.PartitionIDs, part.ID)
   114  			resp.PartitionNames = append(resp.PartitionNames, part.Name)
   115  			resp.InMemoryPercentages = append(resp.InMemoryPercentages, 100)
   116  		}
   117  		s, err := SuccessStatus()
   118  		resp.Status = s
   119  		return resp, err
   120  	}
   121  }
   122  
   123  func TestGrpcClientShowPartitions(t *testing.T) {
   124  
   125  	ctx := context.Background()
   126  	c := testClient(ctx, t)
   127  
   128  	type testCase struct {
   129  		collName      string
   130  		partitions    []*entity.Partition
   131  		shouldSuccess bool
   132  	}
   133  	cases := []testCase{
   134  		{
   135  			collName: testCollectionName,
   136  			partitions: []*entity.Partition{
   137  				{
   138  					ID:   1,
   139  					Name: "_part1",
   140  				},
   141  				{
   142  					ID:   2,
   143  					Name: "_part2",
   144  				},
   145  				{
   146  					ID:   3,
   147  					Name: "_part3",
   148  				},
   149  			},
   150  			shouldSuccess: true,
   151  		},
   152  	}
   153  	for _, tc := range cases {
   154  		mockServer.SetInjection(MShowPartitions, getPartitionsInterception(t, tc.collName, tc.partitions...))
   155  		r, err := c.ShowPartitions(ctx, tc.collName)
   156  		if tc.shouldSuccess {
   157  			assert.Nil(t, err)
   158  			assert.NotNil(t, r)
   159  			if assert.Equal(t, len(tc.partitions), len(r)) {
   160  				for idx, part := range tc.partitions {
   161  					assert.Equal(t, part.ID, r[idx].ID)
   162  					assert.Equal(t, part.Name, r[idx].Name)
   163  				}
   164  			}
   165  		} else {
   166  			assert.NotNil(t, err)
   167  		}
   168  	}
   169  }
   170  
   171  func TestGrpcClientReleasePartitions(t *testing.T) {
   172  	ctx := context.Background()
   173  
   174  	c := testClient(ctx, t)
   175  
   176  	parts := []string{"_part1", "_part2"}
   177  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
   178  	mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, true, "_part1", "_part2", "_part3", "_part4"))
   179  	mockServer.SetInjection(MReleasePartitions, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   180  		req, ok := raw.(*milvuspb.ReleasePartitionsRequest)
   181  		if !ok {
   182  			return BadRequestStatus()
   183  		}
   184  		assert.Equal(t, testCollectionName, req.GetCollectionName())
   185  		assert.ElementsMatch(t, parts, req.GetPartitionNames())
   186  
   187  		return SuccessStatus()
   188  	})
   189  	defer mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, false, "testPart"))
   190  
   191  	assert.Nil(t, c.ReleasePartitions(ctx, testCollectionName, parts, WithReleasePartitionsMsgBase(&commonpb.MsgBase{})))
   192  }
   193  
   194  func TestGrpcShowPartitions(t *testing.T) {
   195  	ctx := context.Background()
   196  	c := testClient(ctx, t)
   197  
   198  	partitions := []*entity.Partition{
   199  		{
   200  			ID:   1,
   201  			Name: "_part1",
   202  		},
   203  		{
   204  			ID:   2,
   205  			Name: "_part2",
   206  		},
   207  		{
   208  			ID:   3,
   209  			Name: "_part3",
   210  		},
   211  	}
   212  
   213  	t.Run("normal show partitions", func(t *testing.T) {
   214  		mockServer.SetInjection(MShowPartitions, getPartitionsInterception(t, testCollectionName, partitions...))
   215  		parts, err := c.ShowPartitions(ctx, testCollectionName)
   216  		assert.NoError(t, err)
   217  		assert.NotNil(t, parts)
   218  	})
   219  
   220  	t.Run("bad response", func(t *testing.T) {
   221  		mockServer.SetInjection(MShowPartitions, func(ctx context.Context, raw proto.Message) (proto.Message, error) {
   222  			resp := &milvuspb.ShowPartitionsResponse{}
   223  			resp.PartitionIDs = make([]int64, 0, len(partitions))
   224  			for _, part := range partitions {
   225  				resp.PartitionIDs = append(resp.PartitionIDs, part.ID)
   226  			}
   227  			s, err := SuccessStatus()
   228  			resp.Status = s
   229  			return resp, err
   230  		})
   231  		_, err := c.ShowPartitions(ctx, testCollectionName)
   232  		assert.Error(t, err)
   233  	})
   234  
   235  	t.Run("Service error", func(t *testing.T) {
   236  		mockServer.SetInjection(MShowPartitions, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   237  			return &milvuspb.ShowPartitionsResponse{}, errors.New("always fail")
   238  		})
   239  		defer mockServer.DelInjection(MShowPartitions)
   240  
   241  		_, err := c.ShowPartitions(ctx, testCollectionName)
   242  		assert.Error(t, err)
   243  
   244  		mockServer.SetInjection(MShowPartitions, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   245  			return &milvuspb.ShowPartitionsResponse{
   246  				Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
   247  			}, nil
   248  		})
   249  		_, err = c.ShowPartitions(ctx, testCollectionName)
   250  		assert.Error(t, err)
   251  	})
   252  }
   253  
   254  type PartitionSuite struct {
   255  	MockSuiteBase
   256  }
   257  
   258  func (s *PartitionSuite) TestLoadPartitions() {
   259  	c := s.client
   260  	ctx, cancel := context.WithCancel(context.Background())
   261  	defer cancel()
   262  
   263  	partNames := []string{"part_1", "part_2"}
   264  	mPartNames := map[string]struct{}{"part_1": {}, "part_2": {}}
   265  
   266  	s.Run("normal_run_async", func() {
   267  		defer s.resetMock()
   268  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   269  			Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   270  		s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.HasPartitionRequest) {
   271  			s.Equal(testCollectionName, req.GetCollectionName())
   272  			_, ok := mPartNames[req.GetPartitionName()]
   273  			s.True(ok)
   274  		}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   275  		s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadPartitionsRequest) {
   276  			s.Equal(testCollectionName, req.GetCollectionName())
   277  			s.ElementsMatch(partNames, req.GetPartitionNames())
   278  		}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   279  
   280  		err := c.LoadPartitions(ctx, testCollectionName, partNames, true, WithLoadPartitionsMsgBase(&commonpb.MsgBase{}))
   281  		s.NoError(err)
   282  	})
   283  
   284  	s.Run("normal_run_sync", func() {
   285  		defer s.resetMock()
   286  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   287  			Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   288  		s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.HasPartitionRequest) {
   289  			s.Equal(testCollectionName, req.GetCollectionName())
   290  			_, ok := mPartNames[req.GetPartitionName()]
   291  			s.True(ok)
   292  		}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   293  		s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadPartitionsRequest) {
   294  			s.Equal(testCollectionName, req.GetCollectionName())
   295  			s.ElementsMatch(partNames, req.GetPartitionNames())
   296  		}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   297  		s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).
   298  			Return(&milvuspb.GetLoadingProgressResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, Progress: 100}, nil)
   299  
   300  		err := c.LoadPartitions(ctx, testCollectionName, partNames, false)
   301  		s.NoError(err)
   302  	})
   303  
   304  	s.Run("has_collection_failure", func() {
   305  		s.Run("return_false", func() {
   306  			defer s.resetMock()
   307  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   308  				Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil)
   309  
   310  			err := c.LoadPartitions(ctx, testCollectionName, partNames, false)
   311  			s.Error(err)
   312  		})
   313  
   314  		s.Run("return_error", func() {
   315  			defer s.resetMock()
   316  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   317  				Return(nil, errors.New("mock error"))
   318  
   319  			err := c.LoadPartitions(ctx, testCollectionName, partNames, false)
   320  			s.Error(err)
   321  		})
   322  	})
   323  
   324  	s.Run("has_partition_failure", func() {
   325  		s.Run("return_false", func() {
   326  			defer s.resetMock()
   327  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   328  				Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   329  			s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).
   330  				Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil)
   331  
   332  			err := c.LoadPartitions(ctx, testCollectionName, partNames, false)
   333  			s.Error(err)
   334  		})
   335  
   336  		s.Run("return_error", func() {
   337  			defer s.resetMock()
   338  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   339  				Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   340  			s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).
   341  				Return(nil, errors.New("mock"))
   342  
   343  			err := c.LoadPartitions(ctx, testCollectionName, partNames, false)
   344  			s.Error(err)
   345  		})
   346  	})
   347  
   348  	s.Run("load_partitions_failure", func() {
   349  		s.Run("fail_status_code", func() {
   350  			defer s.resetMock()
   351  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   352  				Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   353  			s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.HasPartitionRequest) {
   354  				s.Equal(testCollectionName, req.GetCollectionName())
   355  				_, ok := mPartNames[req.GetPartitionName()]
   356  				s.True(ok)
   357  			}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   358  			s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadPartitionsRequest) {
   359  				s.Equal(testCollectionName, req.GetCollectionName())
   360  				s.ElementsMatch(partNames, req.GetPartitionNames())
   361  			}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil)
   362  
   363  			err := c.LoadPartitions(ctx, testCollectionName, partNames, true)
   364  			s.Error(err)
   365  		})
   366  
   367  		s.Run("return_error", func() {
   368  			defer s.resetMock()
   369  			s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   370  				Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   371  			s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.HasPartitionRequest) {
   372  				s.Equal(testCollectionName, req.GetCollectionName())
   373  				_, ok := mPartNames[req.GetPartitionName()]
   374  				s.True(ok)
   375  			}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   376  			s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadPartitionsRequest) {
   377  				s.Equal(testCollectionName, req.GetCollectionName())
   378  				s.ElementsMatch(partNames, req.GetPartitionNames())
   379  			}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil)
   380  
   381  			err := c.LoadPartitions(ctx, testCollectionName, partNames, true)
   382  			s.Error(err)
   383  		})
   384  	})
   385  
   386  	s.Run("get_loading_progress_failure", func() {
   387  		defer s.resetMock()
   388  		s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).
   389  			Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   390  		s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.HasPartitionRequest) {
   391  			s.Equal(testCollectionName, req.GetCollectionName())
   392  			_, ok := mPartNames[req.GetPartitionName()]
   393  			s.True(ok)
   394  		}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: true}, nil)
   395  		s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Run(func(_ context.Context, req *milvuspb.LoadPartitionsRequest) {
   396  			s.Equal(testCollectionName, req.GetCollectionName())
   397  			s.ElementsMatch(partNames, req.GetPartitionNames())
   398  		}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
   399  		s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).
   400  			Return(nil, errors.New("mock error"))
   401  
   402  		err := c.LoadPartitions(ctx, testCollectionName, partNames, false)
   403  		s.Error(err)
   404  	})
   405  
   406  	s.Run("service_not_ready", func() {
   407  		c := &GrpcClient{}
   408  		err := c.LoadPartitions(ctx, testCollectionName, partNames, false)
   409  		s.ErrorIs(err, ErrClientNotReady)
   410  	})
   411  }
   412  
   413  func TestPartitionSuite(t *testing.T) {
   414  	suite.Run(t, new(PartitionSuite))
   415  }