github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/client_test.go (about) 1 package client 2 3 import ( 4 "context" 5 "fmt" 6 "log" 7 "math/rand" 8 "net" 9 "reflect" 10 "strings" 11 "testing" 12 "time" 13 14 "github.com/stretchr/testify/assert" 15 "google.golang.org/grpc" 16 "google.golang.org/grpc/codes" 17 "google.golang.org/grpc/examples/helloworld/helloworld" 18 "google.golang.org/grpc/keepalive" 19 "google.golang.org/grpc/reflection" 20 "google.golang.org/grpc/status" 21 "google.golang.org/grpc/test/bufconn" 22 23 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" 24 "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" 25 "github.com/milvus-io/milvus-sdk-go/v2/entity" 26 ) 27 28 const ( 29 bufSize = 1024 * 1024 30 ) 31 32 var ( 33 lis *bufconn.Listener 34 mockServer *MockServer 35 ) 36 37 const ( 38 testCollectionName = `test_go_sdk` 39 testCollectionID = int64(789) 40 testPrimaryField = `int64` 41 testVectorField = `vector` 42 testVectorDim = 128 43 testDefaultReplicaNumber = int32(1) 44 testMultiReplicaNumber = int32(2) 45 testUsername = "user" 46 testPassword = "pwd" 47 ) 48 49 func defaultSchema() *entity.Schema { 50 return entity.NewSchema().WithName(testCollectionName).WithAutoID(false). 51 WithField(entity.NewField().WithName(testPrimaryField).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)). 52 WithField(entity.NewField().WithName(testVectorField).WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim)) 53 } 54 55 func varCharSchema() *entity.Schema { 56 return entity.NewSchema().WithName(testCollectionName).WithAutoID(false). 57 WithField(entity.NewField().WithName("varchar").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(false).WithMaxLength(100)). 58 WithField(entity.NewField().WithName(testVectorField).WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim)) 59 } 60 61 var _ entity.Row = &defaultRow{} 62 63 type defaultRow struct { 64 entity.RowBase 65 int64 int64 `milvus:"primary_key"` 66 Vector []float32 `milvus:"dim:128"` 67 } 68 69 func (r defaultRow) Collection() string { 70 return testCollectionName 71 } 72 73 // TestMain establishes mock grpc server to testing client behavior 74 func TestMain(m *testing.M) { 75 rand.Seed(time.Now().Unix()) 76 lis = bufconn.Listen(bufSize) 77 s := grpc.NewServer() 78 mockServer = &MockServer{ 79 Injections: make(map[ServiceMethod]TestInjection), 80 } 81 milvuspb.RegisterMilvusServiceServer(s, mockServer) 82 go func() { 83 if err := s.Serve(lis); err != nil { 84 log.Fatalf("Server exited with error: %v", err) 85 } 86 }() 87 m.Run() 88 // lis.Close() 89 } 90 91 // use bufconn dialer 92 func bufDialer(context.Context, string) (net.Conn, error) { 93 return lis.Dial() 94 } 95 96 func testClient(ctx context.Context, t *testing.T) Client { 97 c, err := NewClient(ctx, 98 Config{ 99 Address: "bufnet", 100 DialOptions: []grpc.DialOption{ 101 grpc.WithBlock(), 102 grpc.WithInsecure(), 103 grpc.WithContextDialer(bufDialer), 104 }, 105 }) 106 107 if !assert.Nil(t, err) || !assert.NotNil(t, c) { 108 t.FailNow() 109 } 110 return c 111 } 112 113 func TestHandleRespStatus(t *testing.T) { 114 assert.NotNil(t, handleRespStatus(nil)) 115 assert.Nil(t, handleRespStatus(&commonpb.Status{ 116 ErrorCode: commonpb.ErrorCode_Success, 117 })) 118 assert.NotNil(t, handleRespStatus(&commonpb.Status{ 119 ErrorCode: commonpb.ErrorCode_UnexpectedError, 120 })) 121 } 122 123 type ValidStruct struct { 124 entity.RowBase 125 ID int64 `milvus:"primary_key"` 126 Attr1 int8 127 Attr2 int16 128 Attr3 int32 129 Attr4 float32 130 Attr5 float64 131 Attr6 string 132 Vector []float32 `milvus:"dim:128"` 133 } 134 135 func TestGrpcClientNil(t *testing.T) { 136 c := &GrpcClient{} 137 tp := reflect.TypeOf(c) 138 v := reflect.ValueOf(c) 139 ctx := context.Background() 140 c2 := testClient(ctx, t) 141 v2 := reflect.ValueOf(c2) 142 143 ctxDone, cancel := context.WithCancel(context.Background()) 144 cancel() // cancel here, so the ctx is done already 145 146 for i := 0; i < tp.NumMethod(); i++ { 147 m := tp.Method(i) 148 t.Run(fmt.Sprintf("TestGrpcClientNil_%s", m.Name), func(t *testing.T) { 149 mt := m.Type // type of function 150 if m.Name == "Close" || m.Name == "Connect" || // skip connect & close 151 m.Name == "UsingDatabase" || // skip use database 152 m.Name == "Search" || // type alias MetricType treated as string 153 m.Name == "QueryIterator" || 154 m.Name == "HybridSearch" || // type alias MetricType treated as string 155 m.Name == "CalcDistance" || 156 m.Name == "ManualCompaction" || // time.Duration hard to detect in reflect 157 m.Name == "Insert" || m.Name == "Upsert" { // complex methods with ... 158 t.Skip("method", m.Name, "skipped") 159 } 160 ins := make([]reflect.Value, 0, mt.NumIn()) 161 for j := 1; j < mt.NumIn(); j++ { // idx == 0, is the receiver v 162 if j == 1 { 163 // non-general solution, hard code context! 164 ins = append(ins, reflect.ValueOf(ctx)) 165 continue 166 } 167 if mt.IsVariadic() { 168 // Variadic function, skip last parameter 169 // func m (arg1 interface, opts ... options) 170 if j == mt.NumIn()-1 { 171 continue 172 } 173 } 174 inT := mt.In(j) 175 176 switch inT.Kind() { 177 case reflect.String: // pass empty 178 ins = append(ins, reflect.ValueOf("")) 179 case reflect.Int: 180 ins = append(ins, reflect.ValueOf(0)) 181 case reflect.Int64: 182 ins = append(ins, reflect.ValueOf(int64(0))) 183 case reflect.Bool: 184 ins = append(ins, reflect.ValueOf(false)) 185 case reflect.Interface: 186 idxType := reflect.TypeOf((*entity.Index)(nil)).Elem() 187 rowType := reflect.TypeOf((*entity.Row)(nil)).Elem() 188 colType := reflect.TypeOf((*entity.Column)(nil)).Elem() 189 switch { 190 case inT.Implements(idxType): 191 idx, _ := entity.NewIndexFlat(entity.L2) 192 ins = append(ins, reflect.ValueOf(idx)) 193 case inT.Implements(rowType): 194 ins = append(ins, reflect.ValueOf(&ValidStruct{})) 195 case inT.Implements(colType): 196 ins = append(ins, reflect.ValueOf(entity.NewColumnInt64("id", []int64{}))) 197 } 198 default: 199 ins = append(ins, reflect.Zero(inT)) 200 } 201 } 202 outs := v.MethodByName(m.Name).Call(ins) 203 assert.True(t, len(outs) > 0) 204 assert.EqualValues(t, ErrClientNotReady, outs[len(outs)-1].Interface()) 205 206 // ctx done 207 208 if len(ins) > 0 { // with context param 209 ins[0] = reflect.ValueOf(ctxDone) 210 outs := v2.MethodByName(m.Name).Call(ins) 211 assert.True(t, len(outs) > 0) 212 assert.False(t, outs[len(outs)-1].IsNil()) 213 } 214 }) 215 } 216 } 217 218 func TestGrpcClientConnect(t *testing.T) { 219 ctx := context.Background() 220 221 t.Run("Use bufconn dailer, testing case", func(t *testing.T) { 222 c, err := NewClient(ctx, 223 Config{ 224 Address: "bufnet", 225 DialOptions: []grpc.DialOption{ 226 grpc.WithBlock(), grpc.WithInsecure(), grpc.WithContextDialer(bufDialer), 227 }, 228 }) 229 assert.Nil(t, err) 230 assert.NotNil(t, c) 231 }) 232 233 t.Run("Test empty addr, using default timeout", func(t *testing.T) { 234 c, err := NewClient(ctx, Config{ 235 Address: "", 236 }) 237 assert.NotNil(t, err) 238 assert.Nil(t, c) 239 }) 240 } 241 242 func TestGrpcClientClose(t *testing.T) { 243 ctx := context.Background() 244 245 t.Run("normal close", func(t *testing.T) { 246 c := testClient(ctx, t) 247 assert.Nil(t, c.Close()) 248 }) 249 250 t.Run("double close", func(t *testing.T) { 251 c := testClient(ctx, t) 252 assert.Nil(t, c.Close()) 253 assert.Nil(t, c.Close()) 254 }) 255 } 256 257 type Tserver struct { 258 helloworld.UnimplementedGreeterServer 259 reqCounter uint 260 SuccessCount uint 261 } 262 263 func (s *Tserver) SayHello(_ context.Context, in *helloworld.HelloRequest) (*helloworld.HelloReply, error) { 264 log.Printf("Received: %s", in.Name) 265 s.reqCounter++ 266 if s.reqCounter%s.SuccessCount == 0 { 267 log.Printf("success %d", s.reqCounter) 268 return &helloworld.HelloReply{Message: strings.ToUpper(in.Name)}, nil 269 } 270 return nil, status.Errorf(codes.Unavailable, "server: fail it") 271 } 272 273 func TestGrpcClientRetryPolicy(t *testing.T) { 274 // server 275 port := ":50051" 276 address := "localhost:50051" 277 lis, err := net.Listen("tcp", port) 278 if err != nil { 279 log.Fatalf("failed to listen: %v", err) 280 } 281 kaep := keepalive.EnforcementPolicy{ 282 MinTime: 5 * time.Second, 283 PermitWithoutStream: true, 284 } 285 kasp := keepalive.ServerParameters{ 286 Time: 60 * time.Second, 287 Timeout: 60 * time.Second, 288 } 289 290 maxAttempts := 5 291 s := grpc.NewServer( 292 grpc.KeepaliveEnforcementPolicy(kaep), 293 grpc.KeepaliveParams(kasp), 294 ) 295 helloworld.RegisterGreeterServer(s, &Tserver{SuccessCount: uint(maxAttempts)}) 296 reflection.Register(s) 297 go func() { 298 if err := s.Serve(lis); err != nil { 299 log.Fatalf("failed to serve: %v", err) 300 } 301 }() 302 defer s.Stop() 303 304 client, err := NewClient(context.TODO(), Config{Address: address, DisableConn: true}) 305 assert.Nil(t, err) 306 defer client.Close() 307 308 greeterClient := helloworld.NewGreeterClient(client.(*GrpcClient).Conn) 309 ctx := context.Background() 310 name := fmt.Sprintf("hello world %d", time.Now().Second()) 311 res, err := greeterClient.SayHello(ctx, &helloworld.HelloRequest{Name: name}) 312 assert.Nil(t, err) 313 assert.Equal(t, res.Message, strings.ToUpper(name)) 314 } 315 316 func TestClient_NewDefaultGrpcClientWithURI(t *testing.T) { 317 username := "u" 318 password := "p" 319 t.Run("create grpc client fail", func(t *testing.T) { 320 uri := "https://localhost:port" 321 ctx := context.Background() 322 client, err := NewDefaultGrpcClientWithURI(ctx, uri, username, password) 323 assert.Nil(t, client) 324 assert.Error(t, err) 325 }) 326 } 327 328 var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") 329 330 func randStr(n int) string { 331 sb := strings.Builder{} 332 sb.Grow(n) 333 334 for i := 0; i < n; i++ { 335 sb.WriteRune(letters[rand.Intn(len(letters))]) 336 } 337 338 return sb.String() 339 }