github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/index_test.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"math/rand"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/cockroachdb/errors"
    10  
    11  	"github.com/golang/protobuf/proto"
    12  	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    13  	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    14  	"github.com/milvus-io/milvus-sdk-go/v2/entity"
    15  	"github.com/stretchr/testify/assert"
    16  )
    17  
    18  func TestGrpcClientCreateIndex(t *testing.T) {
    19  	ctx := context.Background()
    20  	c := testClient(ctx, t)
    21  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
    22  	mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))
    23  
    24  	fieldName := `vector`
    25  	idx, err := entity.NewIndexFlat(entity.IP)
    26  	assert.Nil(t, err)
    27  	if !assert.NotNil(t, idx) {
    28  		t.FailNow()
    29  	}
    30  	mockServer.SetInjection(MCreateIndex, func(_ context.Context, raw proto.Message) (proto.Message, error) {
    31  		req, ok := raw.(*milvuspb.CreateIndexRequest)
    32  		if !ok {
    33  			return BadRequestStatus()
    34  		}
    35  		assert.Equal(t, testCollectionName, req.GetCollectionName())
    36  		assert.Equal(t, fieldName, req.GetFieldName())
    37  		return SuccessStatus()
    38  	})
    39  
    40  	t.Run("test async create index", func(t *testing.T) {
    41  		assert.Nil(t, c.CreateIndex(ctx, testCollectionName, fieldName, idx, true, WithIndexMsgBase(&commonpb.MsgBase{})))
    42  	})
    43  
    44  	t.Run("test sync create index", func(t *testing.T) {
    45  		buildTime := rand.Intn(900) + 100
    46  		start := time.Now()
    47  		flag := false
    48  		mockServer.SetInjection(MDescribeIndex, func(_ context.Context, raw proto.Message) (proto.Message, error) {
    49  			req, ok := raw.(*milvuspb.DescribeIndexRequest)
    50  			resp := &milvuspb.DescribeIndexResponse{}
    51  			if !ok {
    52  				s, err := BadRequestStatus()
    53  				resp.Status = s
    54  				return resp, err
    55  			}
    56  			assert.Equal(t, testCollectionName, req.CollectionName)
    57  			assert.Equal(t, "test-index", req.IndexName)
    58  
    59  			resp.IndexDescriptions = []*milvuspb.IndexDescription{
    60  				{
    61  					IndexName: req.GetIndexName(),
    62  					FieldName: req.GetIndexName(),
    63  					State:     commonpb.IndexState_InProgress,
    64  				},
    65  			}
    66  			if time.Since(start) > time.Duration(buildTime)*time.Millisecond {
    67  				resp.IndexDescriptions[0].State = commonpb.IndexState_Finished
    68  				flag = true
    69  			}
    70  
    71  			s, err := SuccessStatus()
    72  			resp.Status = s
    73  			return resp, err
    74  		})
    75  
    76  		assert.Nil(t, c.CreateIndex(ctx, testCollectionName, fieldName, idx, false, WithIndexName("test-index")))
    77  		assert.True(t, flag)
    78  
    79  		mockServer.DelInjection(MDescribeIndex)
    80  	})
    81  }
    82  
    83  func TestGrpcClientDropIndex(t *testing.T) {
    84  	ctx := context.Background()
    85  	c := testClient(ctx, t)
    86  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
    87  	mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))
    88  	assert.Nil(t, c.DropIndex(ctx, testCollectionName, "vector", WithIndexMsgBase(&commonpb.MsgBase{})))
    89  }
    90  
    91  func TestGrpcClientDescribeIndex(t *testing.T) {
    92  	ctx := context.Background()
    93  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
    94  	mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))
    95  
    96  	c := testClient(ctx, t)
    97  
    98  	fieldName := "vector"
    99  
   100  	t.Run("normal describe index", func(t *testing.T) {
   101  		mockServer.SetInjection(MDescribeIndex, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   102  			req, ok := raw.(*milvuspb.DescribeIndexRequest)
   103  			resp := &milvuspb.DescribeIndexResponse{}
   104  			if !ok {
   105  				s, err := BadRequestStatus()
   106  				resp.Status = s
   107  				return resp, err
   108  			}
   109  			assert.Equal(t, fieldName, req.GetFieldName())
   110  			assert.Equal(t, testCollectionName, req.GetCollectionName())
   111  			resp.IndexDescriptions = []*milvuspb.IndexDescription{
   112  				{
   113  					IndexName: "_default",
   114  					IndexID:   1,
   115  					FieldName: req.GetFieldName(),
   116  					Params: entity.MapKvPairs(map[string]string{
   117  						"nlist":       "1024",
   118  						"metric_type": "IP",
   119  					}),
   120  				},
   121  			}
   122  			s, err := SuccessStatus()
   123  			resp.Status = s
   124  			return resp, err
   125  		})
   126  
   127  		idxes, err := c.DescribeIndex(ctx, testCollectionName, fieldName)
   128  		assert.Nil(t, err)
   129  		assert.NotNil(t, idxes)
   130  	})
   131  
   132  	t.Run("Service return errors", func(t *testing.T) {
   133  		defer mockServer.DelInjection(MDescribeIndex)
   134  		mockServer.SetInjection(MDescribeIndex, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   135  			resp := &milvuspb.DescribeIndexResponse{}
   136  
   137  			return resp, errors.New("mockServer.d error")
   138  		})
   139  
   140  		_, err := c.DescribeIndex(ctx, testCollectionName, fieldName)
   141  		assert.Error(t, err)
   142  
   143  		mockServer.SetInjection(MDescribeIndex, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   144  			resp := &milvuspb.DescribeIndexResponse{}
   145  			resp.Status = &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}
   146  			return resp, nil
   147  		})
   148  
   149  		_, err = c.DescribeIndex(ctx, testCollectionName, fieldName)
   150  		assert.Error(t, err)
   151  	})
   152  }
   153  
   154  func TestGrpcGetIndexBuildProgress(t *testing.T) {
   155  	ctx := context.Background()
   156  	mockServer.SetInjection(MHasCollection, hasCollectionDefault)
   157  	mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))
   158  
   159  	tc := testClient(ctx, t)
   160  	c := tc.(*GrpcClient) // since GetIndexBuildProgress is not exposed
   161  
   162  	t.Run("normal get index build progress", func(t *testing.T) {
   163  		var total, built int64
   164  
   165  		mockServer.SetInjection(MGetIndexBuildProgress, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   166  			req, ok := raw.(*milvuspb.GetIndexBuildProgressRequest)
   167  			if !ok {
   168  				t.FailNow()
   169  			}
   170  			assert.Equal(t, testCollectionName, req.GetCollectionName())
   171  			resp := &milvuspb.GetIndexBuildProgressResponse{
   172  				TotalRows:   total,
   173  				IndexedRows: built,
   174  			}
   175  			s, err := SuccessStatus()
   176  			resp.Status = s
   177  			return resp, err
   178  		})
   179  
   180  		total = rand.Int63n(1000)
   181  		built = rand.Int63n(total)
   182  		rt, rb, err := c.GetIndexBuildProgress(ctx, testCollectionName, testVectorField)
   183  		assert.NoError(t, err)
   184  		assert.Equal(t, total, rt)
   185  		assert.Equal(t, built, rb)
   186  	})
   187  
   188  	t.Run("Service return errors", func(t *testing.T) {
   189  		defer mockServer.DelInjection(MGetIndexBuildProgress)
   190  		mockServer.SetInjection(MGetIndexBuildProgress, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   191  			_, ok := raw.(*milvuspb.GetIndexBuildProgressRequest)
   192  			if !ok {
   193  				t.FailNow()
   194  			}
   195  			resp := &milvuspb.GetIndexBuildProgressResponse{}
   196  			return resp, errors.New("mockServer.d error")
   197  		})
   198  
   199  		_, _, err := c.GetIndexBuildProgress(ctx, testCollectionName, testVectorField)
   200  		assert.Error(t, err)
   201  
   202  		mockServer.SetInjection(MGetIndexBuildProgress, func(_ context.Context, raw proto.Message) (proto.Message, error) {
   203  			_, ok := raw.(*milvuspb.GetIndexBuildProgressRequest)
   204  			if !ok {
   205  				t.FailNow()
   206  			}
   207  			resp := &milvuspb.GetIndexBuildProgressResponse{}
   208  			resp.Status = &commonpb.Status{
   209  				ErrorCode: commonpb.ErrorCode_UnexpectedError,
   210  			}
   211  			return resp, nil
   212  		})
   213  		_, _, err = c.GetIndexBuildProgress(ctx, testCollectionName, testVectorField)
   214  		assert.Error(t, err)
   215  	})
   216  
   217  }