github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/client_test.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"math/rand"
     8  	"net"
     9  	"reflect"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/stretchr/testify/assert"
    15  	"google.golang.org/grpc"
    16  	"google.golang.org/grpc/codes"
    17  	"google.golang.org/grpc/examples/helloworld/helloworld"
    18  	"google.golang.org/grpc/keepalive"
    19  	"google.golang.org/grpc/reflection"
    20  	"google.golang.org/grpc/status"
    21  	"google.golang.org/grpc/test/bufconn"
    22  
    23  	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    24  	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    25  	"github.com/milvus-io/milvus-sdk-go/v2/entity"
    26  )
    27  
    28  const (
    29  	bufSize = 1024 * 1024
    30  )
    31  
    32  var (
    33  	lis        *bufconn.Listener
    34  	mockServer *MockServer
    35  )
    36  
    37  const (
    38  	testCollectionName       = `test_go_sdk`
    39  	testCollectionID         = int64(789)
    40  	testPrimaryField         = `int64`
    41  	testVectorField          = `vector`
    42  	testVectorDim            = 128
    43  	testDefaultReplicaNumber = int32(1)
    44  	testMultiReplicaNumber   = int32(2)
    45  	testUsername             = "user"
    46  	testPassword             = "pwd"
    47  )
    48  
    49  func defaultSchema() *entity.Schema {
    50  	return entity.NewSchema().WithName(testCollectionName).WithAutoID(false).
    51  		WithField(entity.NewField().WithName(testPrimaryField).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)).
    52  		WithField(entity.NewField().WithName(testVectorField).WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim))
    53  }
    54  
    55  func varCharSchema() *entity.Schema {
    56  	return entity.NewSchema().WithName(testCollectionName).WithAutoID(false).
    57  		WithField(entity.NewField().WithName("varchar").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(false).WithMaxLength(100)).
    58  		WithField(entity.NewField().WithName(testVectorField).WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim))
    59  }
    60  
    61  var _ entity.Row = &defaultRow{}
    62  
    63  type defaultRow struct {
    64  	entity.RowBase
    65  	int64  int64     `milvus:"primary_key"`
    66  	Vector []float32 `milvus:"dim:128"`
    67  }
    68  
    69  func (r defaultRow) Collection() string {
    70  	return testCollectionName
    71  }
    72  
    73  // TestMain establishes mock grpc server to testing client behavior
    74  func TestMain(m *testing.M) {
    75  	rand.Seed(time.Now().Unix())
    76  	lis = bufconn.Listen(bufSize)
    77  	s := grpc.NewServer()
    78  	mockServer = &MockServer{
    79  		Injections: make(map[ServiceMethod]TestInjection),
    80  	}
    81  	milvuspb.RegisterMilvusServiceServer(s, mockServer)
    82  	go func() {
    83  		if err := s.Serve(lis); err != nil {
    84  			log.Fatalf("Server exited with error: %v", err)
    85  		}
    86  	}()
    87  	m.Run()
    88  	//	lis.Close()
    89  }
    90  
    91  // use bufconn dialer
    92  func bufDialer(context.Context, string) (net.Conn, error) {
    93  	return lis.Dial()
    94  }
    95  
    96  func testClient(ctx context.Context, t *testing.T) Client {
    97  	c, err := NewClient(ctx,
    98  		Config{
    99  			Address: "bufnet",
   100  			DialOptions: []grpc.DialOption{
   101  				grpc.WithBlock(),
   102  				grpc.WithInsecure(),
   103  				grpc.WithContextDialer(bufDialer),
   104  			},
   105  		})
   106  
   107  	if !assert.Nil(t, err) || !assert.NotNil(t, c) {
   108  		t.FailNow()
   109  	}
   110  	return c
   111  }
   112  
   113  func TestHandleRespStatus(t *testing.T) {
   114  	assert.NotNil(t, handleRespStatus(nil))
   115  	assert.Nil(t, handleRespStatus(&commonpb.Status{
   116  		ErrorCode: commonpb.ErrorCode_Success,
   117  	}))
   118  	assert.NotNil(t, handleRespStatus(&commonpb.Status{
   119  		ErrorCode: commonpb.ErrorCode_UnexpectedError,
   120  	}))
   121  }
   122  
   123  type ValidStruct struct {
   124  	entity.RowBase
   125  	ID     int64 `milvus:"primary_key"`
   126  	Attr1  int8
   127  	Attr2  int16
   128  	Attr3  int32
   129  	Attr4  float32
   130  	Attr5  float64
   131  	Attr6  string
   132  	Vector []float32 `milvus:"dim:128"`
   133  }
   134  
   135  func TestGrpcClientNil(t *testing.T) {
   136  	c := &GrpcClient{}
   137  	tp := reflect.TypeOf(c)
   138  	v := reflect.ValueOf(c)
   139  	ctx := context.Background()
   140  	c2 := testClient(ctx, t)
   141  	v2 := reflect.ValueOf(c2)
   142  
   143  	ctxDone, cancel := context.WithCancel(context.Background())
   144  	cancel() // cancel here, so the ctx is done already
   145  
   146  	for i := 0; i < tp.NumMethod(); i++ {
   147  		m := tp.Method(i)
   148  		t.Run(fmt.Sprintf("TestGrpcClientNil_%s", m.Name), func(t *testing.T) {
   149  			mt := m.Type                                   // type of function
   150  			if m.Name == "Close" || m.Name == "Connect" || // skip connect & close
   151  				m.Name == "UsingDatabase" || // skip use database
   152  				m.Name == "Search" || // type alias MetricType treated as string
   153  				m.Name == "QueryIterator" ||
   154  				m.Name == "HybridSearch" || // type alias MetricType treated as string
   155  				m.Name == "CalcDistance" ||
   156  				m.Name == "ManualCompaction" || // time.Duration hard to detect in reflect
   157  				m.Name == "Insert" || m.Name == "Upsert" { // complex methods with ...
   158  				t.Skip("method", m.Name, "skipped")
   159  			}
   160  			ins := make([]reflect.Value, 0, mt.NumIn())
   161  			for j := 1; j < mt.NumIn(); j++ { // idx == 0, is the receiver v
   162  				if j == 1 {
   163  					// non-general solution, hard code context!
   164  					ins = append(ins, reflect.ValueOf(ctx))
   165  					continue
   166  				}
   167  				if mt.IsVariadic() {
   168  					// Variadic function, skip last parameter
   169  					// func m (arg1 interface, opts ... options)
   170  					if j == mt.NumIn()-1 {
   171  						continue
   172  					}
   173  				}
   174  				inT := mt.In(j)
   175  
   176  				switch inT.Kind() {
   177  				case reflect.String: // pass empty
   178  					ins = append(ins, reflect.ValueOf(""))
   179  				case reflect.Int:
   180  					ins = append(ins, reflect.ValueOf(0))
   181  				case reflect.Int64:
   182  					ins = append(ins, reflect.ValueOf(int64(0)))
   183  				case reflect.Bool:
   184  					ins = append(ins, reflect.ValueOf(false))
   185  				case reflect.Interface:
   186  					idxType := reflect.TypeOf((*entity.Index)(nil)).Elem()
   187  					rowType := reflect.TypeOf((*entity.Row)(nil)).Elem()
   188  					colType := reflect.TypeOf((*entity.Column)(nil)).Elem()
   189  					switch {
   190  					case inT.Implements(idxType):
   191  						idx, _ := entity.NewIndexFlat(entity.L2)
   192  						ins = append(ins, reflect.ValueOf(idx))
   193  					case inT.Implements(rowType):
   194  						ins = append(ins, reflect.ValueOf(&ValidStruct{}))
   195  					case inT.Implements(colType):
   196  						ins = append(ins, reflect.ValueOf(entity.NewColumnInt64("id", []int64{})))
   197  					}
   198  				default:
   199  					ins = append(ins, reflect.Zero(inT))
   200  				}
   201  			}
   202  			outs := v.MethodByName(m.Name).Call(ins)
   203  			assert.True(t, len(outs) > 0)
   204  			assert.EqualValues(t, ErrClientNotReady, outs[len(outs)-1].Interface())
   205  
   206  			// ctx done
   207  
   208  			if len(ins) > 0 { // with context param
   209  				ins[0] = reflect.ValueOf(ctxDone)
   210  				outs := v2.MethodByName(m.Name).Call(ins)
   211  				assert.True(t, len(outs) > 0)
   212  				assert.False(t, outs[len(outs)-1].IsNil())
   213  			}
   214  		})
   215  	}
   216  }
   217  
   218  func TestGrpcClientConnect(t *testing.T) {
   219  	ctx := context.Background()
   220  
   221  	t.Run("Use bufconn dailer, testing case", func(t *testing.T) {
   222  		c, err := NewClient(ctx,
   223  			Config{
   224  				Address: "bufnet",
   225  				DialOptions: []grpc.DialOption{
   226  					grpc.WithBlock(), grpc.WithInsecure(), grpc.WithContextDialer(bufDialer),
   227  				},
   228  			})
   229  		assert.Nil(t, err)
   230  		assert.NotNil(t, c)
   231  	})
   232  
   233  	t.Run("Test empty addr, using default timeout", func(t *testing.T) {
   234  		c, err := NewClient(ctx, Config{
   235  			Address: "",
   236  		})
   237  		assert.NotNil(t, err)
   238  		assert.Nil(t, c)
   239  	})
   240  }
   241  
   242  func TestGrpcClientClose(t *testing.T) {
   243  	ctx := context.Background()
   244  
   245  	t.Run("normal close", func(t *testing.T) {
   246  		c := testClient(ctx, t)
   247  		assert.Nil(t, c.Close())
   248  	})
   249  
   250  	t.Run("double close", func(t *testing.T) {
   251  		c := testClient(ctx, t)
   252  		assert.Nil(t, c.Close())
   253  		assert.Nil(t, c.Close())
   254  	})
   255  }
   256  
   257  type Tserver struct {
   258  	helloworld.UnimplementedGreeterServer
   259  	reqCounter   uint
   260  	SuccessCount uint
   261  }
   262  
   263  func (s *Tserver) SayHello(_ context.Context, in *helloworld.HelloRequest) (*helloworld.HelloReply, error) {
   264  	log.Printf("Received: %s", in.Name)
   265  	s.reqCounter++
   266  	if s.reqCounter%s.SuccessCount == 0 {
   267  		log.Printf("success %d", s.reqCounter)
   268  		return &helloworld.HelloReply{Message: strings.ToUpper(in.Name)}, nil
   269  	}
   270  	return nil, status.Errorf(codes.Unavailable, "server: fail it")
   271  }
   272  
   273  func TestGrpcClientRetryPolicy(t *testing.T) {
   274  	// server
   275  	port := ":50051"
   276  	address := "localhost:50051"
   277  	lis, err := net.Listen("tcp", port)
   278  	if err != nil {
   279  		log.Fatalf("failed to listen: %v", err)
   280  	}
   281  	kaep := keepalive.EnforcementPolicy{
   282  		MinTime:             5 * time.Second,
   283  		PermitWithoutStream: true,
   284  	}
   285  	kasp := keepalive.ServerParameters{
   286  		Time:    60 * time.Second,
   287  		Timeout: 60 * time.Second,
   288  	}
   289  
   290  	maxAttempts := 5
   291  	s := grpc.NewServer(
   292  		grpc.KeepaliveEnforcementPolicy(kaep),
   293  		grpc.KeepaliveParams(kasp),
   294  	)
   295  	helloworld.RegisterGreeterServer(s, &Tserver{SuccessCount: uint(maxAttempts)})
   296  	reflection.Register(s)
   297  	go func() {
   298  		if err := s.Serve(lis); err != nil {
   299  			log.Fatalf("failed to serve: %v", err)
   300  		}
   301  	}()
   302  	defer s.Stop()
   303  
   304  	client, err := NewClient(context.TODO(), Config{Address: address, DisableConn: true})
   305  	assert.Nil(t, err)
   306  	defer client.Close()
   307  
   308  	greeterClient := helloworld.NewGreeterClient(client.(*GrpcClient).Conn)
   309  	ctx := context.Background()
   310  	name := fmt.Sprintf("hello world %d", time.Now().Second())
   311  	res, err := greeterClient.SayHello(ctx, &helloworld.HelloRequest{Name: name})
   312  	assert.Nil(t, err)
   313  	assert.Equal(t, res.Message, strings.ToUpper(name))
   314  }
   315  
   316  func TestClient_NewDefaultGrpcClientWithURI(t *testing.T) {
   317  	username := "u"
   318  	password := "p"
   319  	t.Run("create grpc client fail", func(t *testing.T) {
   320  		uri := "https://localhost:port"
   321  		ctx := context.Background()
   322  		client, err := NewDefaultGrpcClientWithURI(ctx, uri, username, password)
   323  		assert.Nil(t, client)
   324  		assert.Error(t, err)
   325  	})
   326  }
   327  
   328  var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
   329  
   330  func randStr(n int) string {
   331  	sb := strings.Builder{}
   332  	sb.Grow(n)
   333  
   334  	for i := 0; i < n; i++ {
   335  		sb.WriteRune(letters[rand.Intn(len(letters))])
   336  	}
   337  
   338  	return sb.String()
   339  }