go-micro.dev/v5@v5.12.0/client/rpc_client.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/google/uuid"
    11  	"github.com/pkg/errors"
    12  
    13  	"go-micro.dev/v5/broker"
    14  	"go-micro.dev/v5/codec"
    15  	raw "go-micro.dev/v5/codec/bytes"
    16  	merrors "go-micro.dev/v5/errors"
    17  	log "go-micro.dev/v5/logger"
    18  	"go-micro.dev/v5/metadata"
    19  	"go-micro.dev/v5/registry"
    20  	"go-micro.dev/v5/selector"
    21  	"go-micro.dev/v5/transport"
    22  	"go-micro.dev/v5/transport/headers"
    23  	"go-micro.dev/v5/util/buf"
    24  	"go-micro.dev/v5/util/net"
    25  	"go-micro.dev/v5/util/pool"
    26  )
    27  
    28  const (
    29  	packageID = "go.micro.client"
    30  )
    31  
// rpcClient is the Client implementation built by newRPCClient.
//
// Fields:
//   - seq is incremented atomically to number requests and streams. Keep it
//     as the first field: 64-bit atomic operations need 64-bit alignment on
//     32-bit platforms, which Go only guarantees for the first word of an
//     allocated struct.
//   - opts holds the client options; Init mutates them under mu.
//   - once stores a bool that Publish uses to remember whether the broker has
//     already been connected.
//   - pool is the connection pool used by call; Init replaces it when the
//     pool size, TTL or transport changes.
//   - mu is write-locked by Init and read-locked by Options, Call and Stream.
type rpcClient struct {
	seq  uint64
	opts Options
	once atomic.Value
	pool pool.Pool
	mu   sync.RWMutex
}
    39  
    40  func newRPCClient(opt ...Option) Client {
    41  	opts := NewOptions(opt...)
    42  
    43  	p := pool.NewPool(
    44  		pool.Size(opts.PoolSize),
    45  		pool.TTL(opts.PoolTTL),
    46  		pool.Transport(opts.Transport),
    47  		pool.CloseTimeout(opts.PoolCloseTimeout),
    48  	)
    49  
    50  	rc := &rpcClient{
    51  		opts: opts,
    52  		pool: p,
    53  		seq:  0,
    54  	}
    55  	rc.once.Store(false)
    56  
    57  	c := Client(rc)
    58  
    59  	// wrap in reverse
    60  	for i := len(opts.Wrappers); i > 0; i-- {
    61  		c = opts.Wrappers[i-1](c)
    62  	}
    63  
    64  	return c
    65  }
    66  
    67  func (r *rpcClient) newCodec(contentType string) (codec.NewCodec, error) {
    68  	if c, ok := r.opts.Codecs[contentType]; ok {
    69  		return c, nil
    70  	}
    71  
    72  	if cf, ok := DefaultCodecs[contentType]; ok {
    73  		return cf, nil
    74  	}
    75  
    76  	return nil, fmt.Errorf("unsupported Content-Type: %s", contentType)
    77  }
    78  
    79  func (r *rpcClient) call(
    80  	ctx context.Context,
    81  	node *registry.Node,
    82  	req Request,
    83  	resp interface{},
    84  	opts CallOptions,
    85  ) error {
    86  	address := node.Address
    87  	logger := r.Options().Logger
    88  
    89  	msg := &transport.Message{
    90  		Header: make(map[string]string),
    91  	}
    92  
    93  	md, ok := metadata.FromContext(ctx)
    94  	if ok {
    95  		for k, v := range md {
    96  			// Don't copy Micro-Topic header, that is used for pub/sub
    97  			// this is fixes the case when the client uses the same context that
    98  			// is received in the subscriber.
    99  			if k == headers.Message {
   100  				continue
   101  			}
   102  
   103  			msg.Header[k] = v
   104  		}
   105  	}
   106  
   107  	// Set connection timeout for single requests to the server. Should be > 0
   108  	// as otherwise requests can't be made.
   109  	cTimeout := opts.ConnectionTimeout
   110  	if cTimeout == 0 {
   111  		logger.Log(log.DebugLevel, "connection timeout was set to 0, overridng to default connection timeout")
   112  
   113  		cTimeout = DefaultConnectionTimeout
   114  	}
   115  
   116  	// set timeout in nanoseconds
   117  	msg.Header["Timeout"] = fmt.Sprintf("%d", cTimeout)
   118  	// set the content type for the request
   119  	msg.Header["Content-Type"] = req.ContentType()
   120  	// set the accept header
   121  	msg.Header["Accept"] = req.ContentType()
   122  
   123  	// setup old protocol
   124  	reqCodec := setupProtocol(msg, node)
   125  
   126  	// no codec specified
   127  	if reqCodec == nil {
   128  		var err error
   129  		reqCodec, err = r.newCodec(req.ContentType())
   130  
   131  		if err != nil {
   132  			return merrors.InternalServerError("go.micro.client", err.Error())
   133  		}
   134  	}
   135  
   136  	dOpts := []transport.DialOption{
   137  		transport.WithStream(),
   138  	}
   139  
   140  	if opts.DialTimeout >= 0 {
   141  		dOpts = append(dOpts, transport.WithTimeout(opts.DialTimeout))
   142  	}
   143  
   144  	if opts.ConnClose {
   145  		dOpts = append(dOpts, transport.WithConnClose())
   146  	}
   147  
   148  	c, err := r.pool.Get(address, dOpts...)
   149  	if err != nil {
   150  		if c == nil {
   151  			return merrors.InternalServerError("go.micro.client", "connection error: %v", err)
   152  		}
   153  		logger.Log(log.ErrorLevel, "failed to close pool", err)
   154  	}
   155  
   156  	seq := atomic.AddUint64(&r.seq, 1) - 1
   157  	codec := newRPCCodec(msg, c, reqCodec, "")
   158  
   159  	rsp := &rpcResponse{
   160  		socket: c,
   161  		codec:  codec,
   162  	}
   163  
   164  	releaseFunc := func(err error) {
   165  		if err = r.pool.Release(c, err); err != nil {
   166  			logger.Log(log.ErrorLevel, "failed to release pool", err)
   167  		}
   168  	}
   169  
   170  	stream := &rpcStream{
   171  		id:       fmt.Sprintf("%v", seq),
   172  		context:  ctx,
   173  		request:  req,
   174  		response: rsp,
   175  		codec:    codec,
   176  		closed:   make(chan bool),
   177  		close:    opts.ConnClose,
   178  		release:  releaseFunc,
   179  		sendEOS:  false,
   180  	}
   181  
   182  	// close the stream on exiting this function
   183  	defer func() {
   184  		if err := stream.Close(); err != nil {
   185  			logger.Log(log.ErrorLevel, "failed to close stream", err)
   186  		}
   187  	}()
   188  
   189  	// wait for error response
   190  	ch := make(chan error, 1)
   191  
   192  	go func() {
   193  		defer func() {
   194  			if r := recover(); r != nil {
   195  				ch <- merrors.InternalServerError("go.micro.client", "panic recovered: %v", r)
   196  			}
   197  		}()
   198  
   199  		// send request
   200  		if err := stream.Send(req.Body()); err != nil {
   201  			ch <- err
   202  			return
   203  		}
   204  
   205  		// recv response
   206  		if err := stream.Recv(resp); err != nil {
   207  			ch <- err
   208  			return
   209  		}
   210  
   211  		// success
   212  		ch <- nil
   213  	}()
   214  
   215  	var grr error
   216  
   217  	select {
   218  	case err := <-ch:
   219  		return err
   220  	case <-time.After(cTimeout):
   221  		grr = merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err()))
   222  	}
   223  
   224  	// set the stream error
   225  	if grr != nil {
   226  		stream.Lock()
   227  		stream.err = grr
   228  		stream.Unlock()
   229  
   230  		return grr
   231  	}
   232  
   233  	return nil
   234  }
   235  
   236  func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request, opts CallOptions) (Stream, error) {
   237  	address := node.Address
   238  	logger := r.Options().Logger
   239  
   240  	msg := &transport.Message{
   241  		Header: make(map[string]string),
   242  	}
   243  
   244  	md, ok := metadata.FromContext(ctx)
   245  	if ok {
   246  		for k, v := range md {
   247  			msg.Header[k] = v
   248  		}
   249  	}
   250  
   251  	// set timeout in nanoseconds
   252  	if opts.StreamTimeout > time.Duration(0) {
   253  		msg.Header["Timeout"] = fmt.Sprintf("%d", opts.StreamTimeout)
   254  	}
   255  	// set the content type for the request
   256  	msg.Header["Content-Type"] = req.ContentType()
   257  	// set the accept header
   258  	msg.Header["Accept"] = req.ContentType()
   259  
   260  	// set old codecs
   261  	nCodec := setupProtocol(msg, node)
   262  
   263  	// no codec specified
   264  	if nCodec == nil {
   265  		var err error
   266  
   267  		nCodec, err = r.newCodec(req.ContentType())
   268  		if err != nil {
   269  			return nil, merrors.InternalServerError("go.micro.client", err.Error())
   270  		}
   271  	}
   272  
   273  	dOpts := []transport.DialOption{
   274  		transport.WithStream(),
   275  	}
   276  
   277  	if opts.DialTimeout >= 0 {
   278  		dOpts = append(dOpts, transport.WithTimeout(opts.DialTimeout))
   279  	}
   280  
   281  	c, err := r.opts.Transport.Dial(address, dOpts...)
   282  	if err != nil {
   283  		return nil, merrors.InternalServerError("go.micro.client", "connection error: %v", err)
   284  	}
   285  
   286  	// increment the sequence number
   287  	seq := atomic.AddUint64(&r.seq, 1) - 1
   288  	id := fmt.Sprintf("%v", seq)
   289  
   290  	// create codec with stream id
   291  	codec := newRPCCodec(msg, c, nCodec, id)
   292  
   293  	rsp := &rpcResponse{
   294  		socket: c,
   295  		codec:  codec,
   296  	}
   297  
   298  	// set request codec
   299  	if r, ok := req.(*rpcRequest); ok {
   300  		r.codec = codec
   301  	}
   302  
   303  	stream := &rpcStream{
   304  		id:       id,
   305  		context:  ctx,
   306  		request:  req,
   307  		response: rsp,
   308  		codec:    codec,
   309  		// used to close the stream
   310  		closed: make(chan bool),
   311  		// signal the end of stream,
   312  		sendEOS: true,
   313  		release: func(_ error) {},
   314  	}
   315  
   316  	// wait for error response
   317  	ch := make(chan error, 1)
   318  
   319  	go func() {
   320  		// send the first message
   321  		ch <- stream.Send(req.Body())
   322  	}()
   323  
   324  	var grr error
   325  
   326  	select {
   327  	case err := <-ch:
   328  		grr = err
   329  	case <-ctx.Done():
   330  		grr = merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err()))
   331  	}
   332  
   333  	if grr != nil {
   334  		// set the error
   335  		stream.Lock()
   336  		stream.err = grr
   337  		stream.Unlock()
   338  
   339  		// close the stream
   340  		if err := stream.Close(); err != nil {
   341  			logger.Logf(log.ErrorLevel, "failed to close stream: %v", err)
   342  		}
   343  
   344  		return nil, grr
   345  	}
   346  
   347  	return stream, nil
   348  }
   349  
   350  func (r *rpcClient) Init(opts ...Option) error {
   351  	r.mu.Lock()
   352  	defer r.mu.Unlock()
   353  
   354  	size := r.opts.PoolSize
   355  	ttl := r.opts.PoolTTL
   356  	tr := r.opts.Transport
   357  
   358  	for _, o := range opts {
   359  		o(&r.opts)
   360  	}
   361  
   362  	// update pool configuration if the options changed
   363  	if size != r.opts.PoolSize || ttl != r.opts.PoolTTL || tr != r.opts.Transport {
   364  		// close existing pool
   365  		if err := r.pool.Close(); err != nil {
   366  			return errors.Wrap(err, "failed to close pool")
   367  		}
   368  
   369  		// create new pool
   370  		r.pool = pool.NewPool(
   371  			pool.Size(r.opts.PoolSize),
   372  			pool.TTL(r.opts.PoolTTL),
   373  			pool.Transport(r.opts.Transport),
   374  			pool.CloseTimeout(r.opts.PoolCloseTimeout),
   375  		)
   376  	}
   377  
   378  	return nil
   379  }
   380  
   381  // Options retrives the options.
   382  func (r *rpcClient) Options() Options {
   383  	r.mu.RLock()
   384  	defer r.mu.RUnlock()
   385  
   386  	return r.opts
   387  }
   388  
   389  // next returns an iterator for the next nodes to call.
   390  func (r *rpcClient) next(request Request, opts CallOptions) (selector.Next, error) {
   391  	// try get the proxy
   392  	service, address, _ := net.Proxy(request.Service(), opts.Address)
   393  
   394  	// return remote address
   395  	if len(address) > 0 {
   396  		nodes := make([]*registry.Node, len(address))
   397  
   398  		for i, addr := range address {
   399  			nodes[i] = &registry.Node{
   400  				Address: addr,
   401  				// Set the protocol
   402  				Metadata: map[string]string{
   403  					"protocol": "mucp",
   404  				},
   405  			}
   406  		}
   407  
   408  		// crude return method
   409  		return func() (*registry.Node, error) {
   410  			return nodes[time.Now().Unix()%int64(len(nodes))], nil
   411  		}, nil
   412  	}
   413  
   414  	// get next nodes from the selector
   415  	next, err := r.opts.Selector.Select(service, opts.SelectOptions...)
   416  	if err != nil {
   417  		if errors.Is(err, selector.ErrNotFound) {
   418  			return nil, merrors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error())
   419  		}
   420  
   421  		return nil, merrors.InternalServerError("go.micro.client", "error selecting %s node: %s", service, err.Error())
   422  	}
   423  
   424  	return next, nil
   425  }
   426  
   427  func (r *rpcClient) Call(ctx context.Context, request Request, response interface{}, opts ...CallOption) error {
   428  	// TODO: further validate these mutex locks. full lock would prevent
   429  	// parallel calls. Maybe we can set individual locks for secctions.
   430  	r.mu.RLock()
   431  	defer r.mu.RUnlock()
   432  
   433  	// make a copy of call opts
   434  	callOpts := r.opts.CallOptions
   435  	for _, opt := range opts {
   436  		opt(&callOpts)
   437  	}
   438  
   439  	next, err := r.next(request, callOpts)
   440  	if err != nil {
   441  		return err
   442  	}
   443  
   444  	// check if we already have a deadline
   445  	d, ok := ctx.Deadline()
   446  	if !ok {
   447  		// no deadline so we create a new one
   448  		var cancel context.CancelFunc
   449  		ctx, cancel = context.WithTimeout(ctx, callOpts.RequestTimeout)
   450  
   451  		defer cancel()
   452  	} else {
   453  		// got a deadline so no need to setup context
   454  		// but we need to set the timeout we pass along
   455  		opt := WithRequestTimeout(time.Until(d))
   456  		opt(&callOpts)
   457  	}
   458  
   459  	// should we noop right here?
   460  	select {
   461  	case <-ctx.Done():
   462  		return merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err()))
   463  	default:
   464  	}
   465  
   466  	// make copy of call method
   467  	rcall := r.call
   468  
   469  	// wrap the call in reverse
   470  	for i := len(callOpts.CallWrappers); i > 0; i-- {
   471  		rcall = callOpts.CallWrappers[i-1](rcall)
   472  	}
   473  
   474  	// return errors.New("go.micro.client", "request timeout", 408)
   475  	call := func(i int) error {
   476  		// call backoff first. Someone may want an initial start delay
   477  		t, err := callOpts.Backoff(ctx, request, i)
   478  		if err != nil {
   479  			return merrors.InternalServerError("go.micro.client", "backoff error: %v", err.Error())
   480  		}
   481  
   482  		// only sleep if greater than 0
   483  		if t.Seconds() > 0 {
   484  			time.Sleep(t)
   485  		}
   486  
   487  		// select next node
   488  		node, err := next()
   489  		service := request.Service()
   490  
   491  		if err != nil {
   492  			if errors.Is(err, selector.ErrNotFound) {
   493  				return merrors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error())
   494  			}
   495  
   496  			return merrors.InternalServerError("go.micro.client",
   497  				"error getting next %s node: %s",
   498  				service,
   499  				err.Error())
   500  		}
   501  
   502  		// make the call
   503  		err = rcall(ctx, node, request, response, callOpts)
   504  		r.opts.Selector.Mark(service, node, err)
   505  
   506  		return err
   507  	}
   508  
   509  	// get the retries
   510  	retries := callOpts.Retries
   511  
   512  	// disable retries when using a proxy
   513  	// Note: I don't see why we should disable retries for proxies, so commenting out.
   514  	// if _, _, ok := net.Proxy(request.Service(), callOpts.Address); ok {
   515  	// 	retries = 0
   516  	// }
   517  
   518  	ch := make(chan error, retries+1)
   519  
   520  	var gerr error
   521  
   522  	for i := 0; i <= retries; i++ {
   523  		go func(i int) {
   524  			ch <- call(i)
   525  		}(i)
   526  
   527  		select {
   528  		case <-ctx.Done():
   529  			return merrors.Timeout("go.micro.client", fmt.Sprintf("call timeout: %v", ctx.Err()))
   530  		case err := <-ch:
   531  			// if the call succeeded lets bail early
   532  			if err == nil {
   533  				return nil
   534  			}
   535  
   536  			retry, rerr := callOpts.Retry(ctx, request, i, err)
   537  			if rerr != nil {
   538  				return rerr
   539  			}
   540  
   541  			if !retry {
   542  				return err
   543  			}
   544  
   545  			r.opts.Logger.Logf(log.DebugLevel, "Retrying request. Previous attempt failed with: %v", err)
   546  
   547  			gerr = err
   548  		}
   549  	}
   550  
   551  	return gerr
   552  }
   553  
   554  func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOption) (Stream, error) {
   555  	r.mu.RLock()
   556  	defer r.mu.RUnlock()
   557  
   558  	// make a copy of call opts
   559  	callOpts := r.opts.CallOptions
   560  	for _, opt := range opts {
   561  		opt(&callOpts)
   562  	}
   563  
   564  	next, err := r.next(request, callOpts)
   565  	if err != nil {
   566  		return nil, err
   567  	}
   568  
   569  	select {
   570  	case <-ctx.Done():
   571  		return nil, merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err()))
   572  	default:
   573  	}
   574  
   575  	call := func(i int) (Stream, error) {
   576  		// call backoff first. Someone may want an initial start delay
   577  		t, err := callOpts.Backoff(ctx, request, i)
   578  		if err != nil {
   579  			return nil, merrors.InternalServerError("go.micro.client", "backoff error: %v", err.Error())
   580  		}
   581  
   582  		// only sleep if greater than 0
   583  		if t.Seconds() > 0 {
   584  			time.Sleep(t)
   585  		}
   586  
   587  		node, err := next()
   588  		service := request.Service()
   589  
   590  		if err != nil {
   591  			if errors.Is(err, selector.ErrNotFound) {
   592  				return nil, merrors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error())
   593  			}
   594  
   595  			return nil, merrors.InternalServerError("go.micro.client",
   596  				"error getting next %s node: %s",
   597  				service,
   598  				err.Error())
   599  		}
   600  
   601  		stream, err := r.stream(ctx, node, request, callOpts)
   602  		r.opts.Selector.Mark(service, node, err)
   603  
   604  		return stream, err
   605  	}
   606  
   607  	type response struct {
   608  		stream Stream
   609  		err    error
   610  	}
   611  
   612  	// get the retries
   613  	retries := callOpts.Retries
   614  
   615  	// disable retries when using a proxy
   616  	if _, _, ok := net.Proxy(request.Service(), callOpts.Address); ok {
   617  		retries = 0
   618  	}
   619  
   620  	ch := make(chan response, retries+1)
   621  
   622  	var grr error
   623  
   624  	for i := 0; i <= retries; i++ {
   625  		go func(i int) {
   626  			s, err := call(i)
   627  			ch <- response{s, err}
   628  		}(i)
   629  
   630  		select {
   631  		case <-ctx.Done():
   632  			return nil, merrors.Timeout("go.micro.client", fmt.Sprintf("call timeout: %v", ctx.Err()))
   633  		case rsp := <-ch:
   634  			// if the call succeeded lets bail early
   635  			if rsp.err == nil {
   636  				return rsp.stream, nil
   637  			}
   638  
   639  			retry, rerr := callOpts.Retry(ctx, request, i, rsp.err)
   640  			if rerr != nil {
   641  				return nil, rerr
   642  			}
   643  
   644  			if !retry {
   645  				return nil, rsp.err
   646  			}
   647  
   648  			grr = rsp.err
   649  		}
   650  	}
   651  
   652  	return nil, grr
   653  }
   654  
   655  func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOption) error {
   656  	options := PublishOptions{
   657  		Context: context.Background(),
   658  	}
   659  	for _, o := range opts {
   660  		o(&options)
   661  	}
   662  
   663  	metadata, ok := metadata.FromContext(ctx)
   664  	if !ok {
   665  		metadata = make(map[string]string)
   666  	}
   667  
   668  	id := uuid.New().String()
   669  	metadata["Content-Type"] = msg.ContentType()
   670  	metadata[headers.Message] = msg.Topic()
   671  	metadata[headers.ID] = id
   672  
   673  	// set the topic
   674  	topic := msg.Topic()
   675  
   676  	// get the exchange
   677  	if len(options.Exchange) > 0 {
   678  		topic = options.Exchange
   679  	}
   680  
   681  	// encode message body
   682  	cf, err := r.newCodec(msg.ContentType())
   683  	if err != nil {
   684  		return merrors.InternalServerError(packageID, err.Error())
   685  	}
   686  
   687  	var body []byte
   688  
   689  	// passed in raw data
   690  	if d, ok := msg.Payload().(*raw.Frame); ok {
   691  		body = d.Data
   692  	} else {
   693  		b := buf.New(nil)
   694  
   695  		if err = cf(b).Write(&codec.Message{
   696  			Target: topic,
   697  			Type:   codec.Event,
   698  			Header: map[string]string{
   699  				headers.ID:      id,
   700  				headers.Message: msg.Topic(),
   701  			},
   702  		}, msg.Payload()); err != nil {
   703  			return merrors.InternalServerError(packageID, err.Error())
   704  		}
   705  
   706  		// set the body
   707  		body = b.Bytes()
   708  	}
   709  
   710  	l, ok := r.once.Load().(bool)
   711  	if !ok {
   712  		return fmt.Errorf("failed to cast to bool")
   713  	}
   714  
   715  	if !l {
   716  		if err = r.opts.Broker.Connect(); err != nil {
   717  			return merrors.InternalServerError(packageID, err.Error())
   718  		}
   719  
   720  		r.once.Store(true)
   721  	}
   722  
   723  	return r.opts.Broker.Publish(topic, &broker.Message{
   724  		Header: metadata,
   725  		Body:   body,
   726  	}, broker.PublishContext(options.Context))
   727  }
   728  
   729  func (r *rpcClient) NewMessage(topic string, message interface{}, opts ...MessageOption) Message {
   730  	return newMessage(topic, message, r.opts.ContentType, opts...)
   731  }
   732  
   733  func (r *rpcClient) NewRequest(service, method string, request interface{}, reqOpts ...RequestOption) Request {
   734  	return newRequest(service, method, request, r.opts.ContentType, reqOpts...)
   735  }
   736  
   737  func (r *rpcClient) String() string {
   738  	return "mucp"
   739  }