github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/client/rpc_client.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/volts-dev/volts/codec"
     9  	"github.com/volts-dev/volts/internal/body"
    10  	"github.com/volts-dev/volts/internal/errors"
    11  	"github.com/volts-dev/volts/internal/metadata"
    12  	"github.com/volts-dev/volts/internal/net"
    13  	"github.com/volts-dev/volts/internal/pool"
    14  	"github.com/volts-dev/volts/registry"
    15  	"github.com/volts-dev/volts/selector"
    16  	"github.com/volts-dev/volts/transport"
    17  )
    18  
    19  type (
    20  	RpcClient struct {
    21  		config   *Config
    22  		pool     pool.Pool // connect pool
    23  		closing  bool      // user has called Close
    24  		shutdown bool      // server has told us to stop
    25  	}
    26  )
    27  
    28  func NewRpcClient(opts ...Option) *RpcClient {
    29  	cfg := newConfig(
    30  		transport.NewTCPTransport(),
    31  		opts...,
    32  	)
    33  
    34  	// 默认编码
    35  	if cfg.SerializeType == "" {
    36  		cfg.Serialize = codec.JSON
    37  	}
    38  
    39  	p := pool.NewPool(
    40  		pool.Size(cfg.PoolSize),
    41  		pool.TTL(cfg.PoolTtl),
    42  		pool.Transport(cfg.Transport),
    43  	)
    44  
    45  	return &RpcClient{
    46  		config: cfg,
    47  		pool:   p,
    48  	}
    49  }
    50  
    51  func (self *RpcClient) Init(opts ...Option) error {
    52  	self.config.Init(opts...)
    53  	return nil
    54  }
    55  
    56  func (self *RpcClient) Config() *Config {
    57  	return self.config
    58  }
    59  
    60  // 新建请求
    61  func (self *RpcClient) NewRequest(service, method string, request interface{}, optinos ...RequestOption) (*rpcRequest, error) {
    62  	optinos = append(optinos,
    63  		WithCodec(self.config.Serialize),
    64  	)
    65  	return newRpcRequest(service, method, request, optinos...)
    66  }
    67  
    68  func (self *RpcClient) call(ctx context.Context, node *registry.Node, req IRequest, opts CallOptions) (IResponse, error) {
    69  	// 验证解码器
    70  	msgCodece := codec.IdentifyCodec(self.config.Serialize)
    71  	if msgCodece == nil { // no codec specified
    72  		//call.Error = rpc.ErrUnsupportedCodec
    73  		//client.mutex.Unlock()
    74  		//call.done()
    75  		return nil, errors.UnsupportedCodec("volts.client", self.config.SerializeType)
    76  	}
    77  
    78  	// 获取空闲链接
    79  	dOpts := []transport.DialOption{
    80  		transport.WithStream(),
    81  	}
    82  
    83  	if opts.DialTimeout >= 0 {
    84  		dOpts = append(dOpts, transport.WithTimeout(opts.DialTimeout, opts.RequestTimeout, 0))
    85  	}
    86  
    87  	conn, err := self.pool.Get(node.Address, dOpts...)
    88  	if err != nil {
    89  		return nil, errors.InternalServerError("volts.client", "connection error: %v", err)
    90  	}
    91  	defer self.pool.Release(conn, nil)
    92  
    93  	// 获取消息载体
    94  	msg := transport.GetMessageFromPool()
    95  	msg.SetMessageType(transport.MT_REQUEST)
    96  	msg.SetSerializeType(self.config.Serialize)
    97  
    98  	// init header
    99  	for k, v := range req.Header() {
   100  		msg.Header[k] = v[0]
   101  	}
   102  	md, ok := metadata.FromContext(ctx)
   103  	if ok {
   104  		for k, v := range md {
   105  			msg.Header[k] = v
   106  		}
   107  	}
   108  
   109  	// set timeout in nanoseconds
   110  	msg.Header["Timeout"] = fmt.Sprintf("%d", opts.RequestTimeout)
   111  	// set the content type for the request
   112  	msg.Header["Content-Type"] = req.ContentType()
   113  	// set the accept header
   114  	msg.Header["Accept"] = req.ContentType()
   115  
   116  	msg.Path = req.Method() // TODO msg 添加server action
   117  	data := req.Body().Data.Bytes()
   118  	if len(data) > 1024 && self.config.CompressType == transport.Gzip {
   119  		data, err = transport.Zip(data)
   120  		if err != nil {
   121  			return nil, err
   122  		}
   123  
   124  		msg.SetCompressType(self.config.CompressType)
   125  	}
   126  
   127  	msg.Payload = data
   128  	//seq := atomic.AddUint64(&self.seq, 1) - 1
   129  	//codec := newRpcCodec(msg, c, cf, "")
   130  
   131  	// 开始发送消息
   132  	// wait for error response
   133  	ch := make(chan error, 1)
   134  	resp := &rpcResponse{}
   135  	go func(resp *rpcResponse) {
   136  		defer func() {
   137  			if r := recover(); r != nil {
   138  				ch <- errors.InternalServerError("volts.client", "panic recovered: %v", r)
   139  			}
   140  		}()
   141  
   142  		// send request
   143  		// 返回编译过的数据
   144  		err := conn.Send(msg)
   145  		if err != nil {
   146  			ch <- err
   147  			return
   148  		}
   149  
   150  		// recv request
   151  		msg = transport.GetMessageFromPool()
   152  		err = conn.Recv(msg)
   153  		if err != nil {
   154  			ch <- err
   155  			return
   156  		}
   157  
   158  		// 状态码处理
   159  		switch msg.MessageStatusType() {
   160  		case transport.StatusOK:
   161  			break
   162  		case transport.StatusError:
   163  			ch <- errors.New("StatusError", int32(transport.StatusError), string(msg.Payload))
   164  			return
   165  		default:
   166  			ch <- errors.New("", int32(msg.MessageStatusType()), string(msg.Payload))
   167  			return
   168  		}
   169  
   170  		bd := body.New(codec.IdentifyCodec(msg.SerializeType()))
   171  		bd.Data.Write(msg.Payload)
   172  		// 解码消息内容
   173  		resp.contentType = msg.SerializeType()
   174  		resp.body = bd // msg.Payload
   175  
   176  		// success
   177  		ch <- nil
   178  	}(resp)
   179  
   180  	err = nil
   181  	select {
   182  	case err := <-ch:
   183  		return resp, err
   184  	case <-ctx.Done():
   185  		err = errors.Timeout("volts.client", fmt.Sprintf("%v", ctx.Err()))
   186  		break
   187  	}
   188  
   189  	// set the stream error
   190  	if err != nil {
   191  		//stream.Lock()
   192  		//stream.err = grr
   193  		//stream.Unlock()
   194  		return nil, err
   195  	}
   196  
   197  	return resp, nil
   198  }
   199  
   200  // 阻塞请求
   201  func (self *RpcClient) Call(request IRequest, opts ...CallOption) (IResponse, error) {
   202  	// make a copy of call opts
   203  	callOpts := self.config.CallOptions
   204  	callOpts.SelectOptions = append(callOpts.SelectOptions, selector.WithFilter(selector.FilterTrasport(self.config.Transport)))
   205  	for _, opt := range opts {
   206  		opt(&callOpts)
   207  	}
   208  
   209  	next, err := self.next(request, callOpts)
   210  	if err != nil {
   211  		return nil, err
   212  	}
   213  
   214  	ctx := callOpts.Context
   215  	if ctx == nil {
   216  		ctx = context.Background()
   217  	}
   218  	// check if we already have a deadline
   219  	d, ok := ctx.Deadline()
   220  	if !ok {
   221  		// no deadline so we create a new one
   222  		var cancel context.CancelFunc
   223  		ctx, cancel = context.WithTimeout(ctx, callOpts.RequestTimeout)
   224  		defer cancel()
   225  	} else {
   226  		// got a deadline so no need to setup context
   227  		// but we need to set the timeout we pass along
   228  		opt := WithRequestTimeout(time.Until(d))
   229  		opt(&callOpts)
   230  	}
   231  
   232  	// should we noop right here?
   233  	select {
   234  	case <-ctx.Done():
   235  		return nil, errors.Timeout("volts.client", fmt.Sprintf("%v", ctx.Err()))
   236  	default:
   237  	}
   238  
   239  	// return errors.New("volts.client", "request timeout", 408)
   240  	call := func(i int, response *IResponse) error {
   241  		// select next node
   242  		// selector 可能因为过滤后得不到合适服务器
   243  		node, err := next()
   244  		if err != nil {
   245  			return err
   246  		}
   247  
   248  		// make the call
   249  		*response, err = self.call(ctx, node, request, callOpts)
   250  		//r.opts.Selector.Mark(service, node, err)
   251  		return err
   252  	}
   253  	var response IResponse
   254  	// get the retries
   255  	retries := callOpts.Retries
   256  	ch := make(chan error, retries+1)
   257  	var gerr error
   258  	for i := 0; i <= retries; i++ {
   259  		go func(i int, response *IResponse) {
   260  			ch <- call(i, response)
   261  		}(i, &response)
   262  
   263  		select {
   264  		case <-ctx.Done():
   265  			return nil, errors.Timeout("volts.client", fmt.Sprintf("call timeout: %v", ctx.Err()))
   266  		case err := <-ch:
   267  			// if the call succeeded lets bail early
   268  			if err == nil {
   269  				return response, nil
   270  			}
   271  
   272  			retry, rerr := callOpts.Retry(ctx, request, i, err)
   273  			if rerr != nil {
   274  				return nil, rerr
   275  			}
   276  
   277  			if !retry {
   278  				return nil, err
   279  			}
   280  
   281  			gerr = err
   282  		}
   283  	}
   284  
   285  	return response, gerr
   286  }
   287  
   288  // next returns an iterator for the next nodes to call
   289  func (r *RpcClient) next(request IRequest, opts CallOptions) (selector.Next, error) {
   290  	// try get the proxy
   291  	service, address, _ := net.Proxy(request.Service(), opts.Address)
   292  
   293  	// return remote address
   294  	if len(address) > 0 {
   295  		nodes := make([]*registry.Node, len(address))
   296  
   297  		for i, addr := range address {
   298  			nodes[i] = &registry.Node{
   299  				Address: addr,
   300  				// Set the protocol
   301  				Metadata: map[string]string{
   302  					"protocol": "mucp",
   303  				},
   304  			}
   305  		}
   306  
   307  		// crude return method
   308  		return func() (*registry.Node, error) {
   309  			return nodes[time.Now().Unix()%int64(len(nodes))], nil
   310  		}, nil
   311  	}
   312  	// only get the things that are of http protocol
   313  	selectOptions := append(opts.SelectOptions, selector.WithFilter(
   314  		selector.FilterLabel("protocol", r.config.Transport.Protocol()),
   315  	))
   316  
   317  	// get next nodes from the selector
   318  	next, err := r.config.Selector.Select(service, selectOptions...)
   319  	if err != nil {
   320  		if err == selector.ErrNotFound {
   321  			return nil, errors.InternalServerError("volts.client", "service %s: %s", service, err.Error())
   322  		}
   323  		return nil, errors.InternalServerError("volts.client", "error selecting %s node: %s", service, err.Error())
   324  	}
   325  
   326  	return next, nil
   327  }
   328  
   329  func (self *RpcClient) String() string {
   330  	return "RpcClient"
   331  }