github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/index_test.go (about) 1 package client 2 3 import ( 4 "context" 5 "math/rand" 6 "testing" 7 "time" 8 9 "github.com/cockroachdb/errors" 10 11 "github.com/golang/protobuf/proto" 12 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" 13 "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" 14 "github.com/milvus-io/milvus-sdk-go/v2/entity" 15 "github.com/stretchr/testify/assert" 16 ) 17 18 func TestGrpcClientCreateIndex(t *testing.T) { 19 ctx := context.Background() 20 c := testClient(ctx, t) 21 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 22 mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema())) 23 24 fieldName := `vector` 25 idx, err := entity.NewIndexFlat(entity.IP) 26 assert.Nil(t, err) 27 if !assert.NotNil(t, idx) { 28 t.FailNow() 29 } 30 mockServer.SetInjection(MCreateIndex, func(_ context.Context, raw proto.Message) (proto.Message, error) { 31 req, ok := raw.(*milvuspb.CreateIndexRequest) 32 if !ok { 33 return BadRequestStatus() 34 } 35 assert.Equal(t, testCollectionName, req.GetCollectionName()) 36 assert.Equal(t, fieldName, req.GetFieldName()) 37 return SuccessStatus() 38 }) 39 40 t.Run("test async create index", func(t *testing.T) { 41 assert.Nil(t, c.CreateIndex(ctx, testCollectionName, fieldName, idx, true, WithIndexMsgBase(&commonpb.MsgBase{}))) 42 }) 43 44 t.Run("test sync create index", func(t *testing.T) { 45 buildTime := rand.Intn(900) + 100 46 start := time.Now() 47 flag := false 48 mockServer.SetInjection(MDescribeIndex, func(_ context.Context, raw proto.Message) (proto.Message, error) { 49 req, ok := raw.(*milvuspb.DescribeIndexRequest) 50 resp := &milvuspb.DescribeIndexResponse{} 51 if !ok { 52 s, err := BadRequestStatus() 53 resp.Status = s 54 return resp, err 55 } 56 assert.Equal(t, testCollectionName, req.CollectionName) 57 assert.Equal(t, "test-index", req.IndexName) 58 59 resp.IndexDescriptions = []*milvuspb.IndexDescription{ 60 { 61 IndexName: req.GetIndexName(), 62 FieldName: req.GetIndexName(), 63 State: commonpb.IndexState_InProgress, 64 }, 65 } 66 if time.Since(start) > time.Duration(buildTime)*time.Millisecond { 67 resp.IndexDescriptions[0].State = commonpb.IndexState_Finished 68 flag = true 69 } 70 71 s, err := SuccessStatus() 72 resp.Status = s 73 return resp, err 74 }) 75 76 assert.Nil(t, c.CreateIndex(ctx, testCollectionName, fieldName, idx, false, WithIndexName("test-index"))) 77 assert.True(t, flag) 78 79 mockServer.DelInjection(MDescribeIndex) 80 }) 81 } 82 83 func TestGrpcClientDropIndex(t *testing.T) { 84 ctx := context.Background() 85 c := testClient(ctx, t) 86 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 87 mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema())) 88 assert.Nil(t, c.DropIndex(ctx, testCollectionName, "vector", WithIndexMsgBase(&commonpb.MsgBase{}))) 89 } 90 91 func TestGrpcClientDescribeIndex(t *testing.T) { 92 ctx := context.Background() 93 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 94 mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema())) 95 96 c := testClient(ctx, t) 97 98 fieldName := "vector" 99 100 t.Run("normal describe index", func(t *testing.T) { 101 mockServer.SetInjection(MDescribeIndex, func(_ context.Context, raw proto.Message) (proto.Message, error) { 102 req, ok := raw.(*milvuspb.DescribeIndexRequest) 103 resp := &milvuspb.DescribeIndexResponse{} 104 if !ok { 105 s, err := BadRequestStatus() 106 resp.Status = s 107 return resp, err 108 } 109 assert.Equal(t, fieldName, req.GetFieldName()) 110 assert.Equal(t, testCollectionName, req.GetCollectionName()) 111 resp.IndexDescriptions = []*milvuspb.IndexDescription{ 112 { 113 IndexName: "_default", 114 IndexID: 1, 115 FieldName: req.GetFieldName(), 116 Params: entity.MapKvPairs(map[string]string{ 117 "nlist": "1024", 118 "metric_type": "IP", 119 }), 120 }, 121 } 122 s, err := SuccessStatus() 123 resp.Status = s 124 return resp, err 125 }) 126 127 idxes, err := c.DescribeIndex(ctx, testCollectionName, fieldName) 128 assert.Nil(t, err) 129 assert.NotNil(t, idxes) 130 }) 131 132 t.Run("Service return errors", func(t *testing.T) { 133 defer mockServer.DelInjection(MDescribeIndex) 134 mockServer.SetInjection(MDescribeIndex, func(_ context.Context, raw proto.Message) (proto.Message, error) { 135 resp := &milvuspb.DescribeIndexResponse{} 136 137 return resp, errors.New("mockServer.d error") 138 }) 139 140 _, err := c.DescribeIndex(ctx, testCollectionName, fieldName) 141 assert.Error(t, err) 142 143 mockServer.SetInjection(MDescribeIndex, func(_ context.Context, raw proto.Message) (proto.Message, error) { 144 resp := &milvuspb.DescribeIndexResponse{} 145 resp.Status = &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError} 146 return resp, nil 147 }) 148 149 _, err = c.DescribeIndex(ctx, testCollectionName, fieldName) 150 assert.Error(t, err) 151 }) 152 } 153 154 func TestGrpcGetIndexBuildProgress(t *testing.T) { 155 ctx := context.Background() 156 mockServer.SetInjection(MHasCollection, hasCollectionDefault) 157 mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema())) 158 159 tc := testClient(ctx, t) 160 c := tc.(*GrpcClient) // since GetIndexBuildProgress is not exposed 161 162 t.Run("normal get index build progress", func(t *testing.T) { 163 var total, built int64 164 165 mockServer.SetInjection(MGetIndexBuildProgress, func(_ context.Context, raw proto.Message) (proto.Message, error) { 166 req, ok := raw.(*milvuspb.GetIndexBuildProgressRequest) 167 if !ok { 168 t.FailNow() 169 } 170 assert.Equal(t, testCollectionName, req.GetCollectionName()) 171 resp := &milvuspb.GetIndexBuildProgressResponse{ 172 TotalRows: total, 173 IndexedRows: built, 174 } 175 s, err := SuccessStatus() 176 resp.Status = s 177 return resp, err 178 }) 179 180 total = rand.Int63n(1000) 181 built = rand.Int63n(total) 182 rt, rb, err := c.GetIndexBuildProgress(ctx, testCollectionName, testVectorField) 183 assert.NoError(t, err) 184 assert.Equal(t, total, rt) 185 assert.Equal(t, built, rb) 186 }) 187 188 t.Run("Service return errors", func(t *testing.T) { 189 defer mockServer.DelInjection(MGetIndexBuildProgress) 190 mockServer.SetInjection(MGetIndexBuildProgress, func(_ context.Context, raw proto.Message) (proto.Message, error) { 191 _, ok := raw.(*milvuspb.GetIndexBuildProgressRequest) 192 if !ok { 193 t.FailNow() 194 } 195 resp := &milvuspb.GetIndexBuildProgressResponse{} 196 return resp, errors.New("mockServer.d error") 197 }) 198 199 _, _, err := c.GetIndexBuildProgress(ctx, testCollectionName, testVectorField) 200 assert.Error(t, err) 201 202 mockServer.SetInjection(MGetIndexBuildProgress, func(_ context.Context, raw proto.Message) (proto.Message, error) { 203 _, ok := raw.(*milvuspb.GetIndexBuildProgressRequest) 204 if !ok { 205 t.FailNow() 206 } 207 resp := &milvuspb.GetIndexBuildProgressResponse{} 208 resp.Status = &commonpb.Status{ 209 ErrorCode: commonpb.ErrorCode_UnexpectedError, 210 } 211 return resp, nil 212 }) 213 _, _, err = c.GetIndexBuildProgress(ctx, testCollectionName, testVectorField) 214 assert.Error(t, err) 215 }) 216 217 }