trpc.group/trpc-go/trpc-go@v1.0.3/client/client.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  // Package client is tRPC-Go clientside implementation,
    15  // including network transportation, resolving, routing etc.
    16  package client
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"net"
    22  	"time"
    23  
    24  	"trpc.group/trpc-go/trpc-go/codec"
    25  	"trpc.group/trpc-go/trpc-go/errs"
    26  	"trpc.group/trpc-go/trpc-go/filter"
    27  	"trpc.group/trpc-go/trpc-go/internal/attachment"
    28  	icodec "trpc.group/trpc-go/trpc-go/internal/codec"
    29  	"trpc.group/trpc-go/trpc-go/internal/report"
    30  	"trpc.group/trpc-go/trpc-go/naming/registry"
    31  	"trpc.group/trpc-go/trpc-go/naming/selector"
    32  	"trpc.group/trpc-go/trpc-go/rpcz"
    33  	"trpc.group/trpc-go/trpc-go/transport"
    34  )
    35  
    36  // Client is the interface that initiates RPCs and sends request messages to a server.
    37  type Client interface {
    38  	// Invoke performs a unary RPC.
    39  	Invoke(ctx context.Context, reqBody interface{}, rspBody interface{}, opt ...Option) error
    40  }
    41  
    42  // DefaultClient is the default global client.
    43  // It's thread-safe.
    44  var DefaultClient = New()
    45  
    46  // New creates a client that uses default client transport.
    47  var New = func() Client {
    48  	return &client{}
    49  }
    50  
    51  // client is the default implementation of Client with
    52  // pluggable codec, transport, filter etc.
    53  type client struct{}
    54  
    55  // Invoke invokes a backend call by passing in custom request/response message
    56  // and running selector filter, codec, transport etc.
    57  func (c *client) Invoke(ctx context.Context, reqBody interface{}, rspBody interface{}, opt ...Option) (err error) {
    58  	// The generic message structure data of the current request is retrieved from the context,
    59  	// and each backend call uses a new msg generated by the client stub code.
    60  	ctx, msg := codec.EnsureMessage(ctx)
    61  
    62  	span, end, ctx := rpcz.NewSpanContext(ctx, "client")
    63  
    64  	// Get client options.
    65  	opts, err := c.getOptions(msg, opt...)
    66  	defer func() {
    67  		span.SetAttribute(rpcz.TRPCAttributeRPCName, msg.ClientRPCName())
    68  		if err == nil {
    69  			span.SetAttribute(rpcz.TRPCAttributeError, msg.ClientRspErr())
    70  		} else {
    71  			span.SetAttribute(rpcz.TRPCAttributeError, err)
    72  		}
    73  		end.End()
    74  	}()
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	// Update Msg by options.
    80  	c.updateMsg(msg, opts)
    81  
    82  	fullLinkDeadline, ok := ctx.Deadline()
    83  	if opts.Timeout > 0 {
    84  		var cancel context.CancelFunc
    85  		ctx, cancel = context.WithTimeout(ctx, opts.Timeout)
    86  		defer cancel()
    87  	}
    88  	if deadline, ok := ctx.Deadline(); ok {
    89  		msg.WithRequestTimeout(deadline.Sub(time.Now()))
    90  	}
    91  	if ok && (opts.Timeout <= 0 || time.Until(fullLinkDeadline) < opts.Timeout) {
    92  		opts.fixTimeout = mayConvert2FullLinkTimeout
    93  	}
    94  
    95  	// Start filter chain processing.
    96  	filters := c.fixFilters(opts)
    97  	span.SetAttribute(rpcz.TRPCAttributeFilterNames, opts.FilterNames)
    98  	return filters.Filter(contextWithOptions(ctx, opts), reqBody, rspBody, callFunc)
    99  }
   100  
   101  // getOptions returns Options needed by each RPC.
   102  func (c *client) getOptions(msg codec.Msg, opt ...Option) (*Options, error) {
   103  	opts := getOptionsByCalleeAndUserOptions(msg.CalleeServiceName(), opt...).clone()
   104  
   105  	// Set service info options.
   106  	opts.SelectOptions = append(opts.SelectOptions, c.getServiceInfoOptions(msg)...)
   107  
   108  	// The given input options have the highest priority
   109  	// and they will override the original ones.
   110  	for _, o := range opt {
   111  		o(opts)
   112  	}
   113  
   114  	if err := opts.parseTarget(); err != nil {
   115  		return nil, errs.NewFrameError(errs.RetClientRouteErr, err.Error())
   116  	}
   117  	return opts, nil
   118  }
   119  
   120  // getServiceInfoOptions returns service info options.
   121  func (c *client) getServiceInfoOptions(msg codec.Msg) []selector.Option {
   122  	if msg.Namespace() != "" {
   123  		return []selector.Option{
   124  			selector.WithSourceNamespace(msg.Namespace()),
   125  			selector.WithSourceServiceName(msg.CallerServiceName()),
   126  			selector.WithSourceEnvName(msg.EnvName()),
   127  			selector.WithEnvTransfer(msg.EnvTransfer()),
   128  			selector.WithSourceSetName(msg.SetName()),
   129  		}
   130  	}
   131  	return nil
   132  }
   133  
   134  // updateMsg updates msg.
   135  func (c *client) updateMsg(msg codec.Msg, opts *Options) {
   136  	// Set callee service name.
   137  	// Generally, service name is the same as the package.service defined in proto file,
   138  	// but it can be customized by options.
   139  	if opts.ServiceName != "" {
   140  		// From client's perspective, caller refers to itself, callee refers to the backend service.
   141  		msg.WithCalleeServiceName(opts.ServiceName)
   142  	}
   143  	if opts.endpoint == "" {
   144  		// If endpoint is not configured, DefaultSelector (generally polaris)
   145  		// will be used to address callee service name.
   146  		opts.endpoint = msg.CalleeServiceName()
   147  	}
   148  	if opts.CalleeMethod != "" {
   149  		msg.WithCalleeMethod(opts.CalleeMethod)
   150  	}
   151  
   152  	// Set metadata.
   153  	if len(opts.MetaData) > 0 {
   154  		msg.WithClientMetaData(c.getMetaData(msg, opts))
   155  	}
   156  
   157  	// Set caller service name if needed.
   158  	if opts.CallerServiceName != "" {
   159  		msg.WithCallerServiceName(opts.CallerServiceName)
   160  	}
   161  	if icodec.IsValidSerializationType(opts.SerializationType) {
   162  		msg.WithSerializationType(opts.SerializationType)
   163  	}
   164  	if icodec.IsValidCompressType(opts.CompressType) && opts.CompressType != codec.CompressTypeNoop {
   165  		msg.WithCompressType(opts.CompressType)
   166  	}
   167  
   168  	// Set client req head if needed.
   169  	if opts.ReqHead != nil {
   170  		msg.WithClientReqHead(opts.ReqHead)
   171  	}
   172  	// Set client rsp head if needed.
   173  	if opts.RspHead != nil {
   174  		msg.WithClientRspHead(opts.RspHead)
   175  	}
   176  
   177  	msg.WithCallType(opts.CallType)
   178  
   179  	if opts.attachment != nil {
   180  		setAttachment(msg, opts.attachment)
   181  	}
   182  }
   183  
   184  // SetAttachment sets attachment to msg.
   185  func setAttachment(msg codec.Msg, attm *attachment.Attachment) {
   186  	cm := msg.CommonMeta()
   187  	if cm == nil {
   188  		cm = make(codec.CommonMeta)
   189  		msg.WithCommonMeta(cm)
   190  	}
   191  	cm[attachment.ClientAttachmentKey{}] = attm
   192  }
   193  
   194  // getMetaData returns metadata that will be transparently transmitted to the backend service.
   195  func (c *client) getMetaData(msg codec.Msg, opts *Options) codec.MetaData {
   196  	md := msg.ClientMetaData()
   197  	if md == nil {
   198  		md = codec.MetaData{}
   199  	}
   200  	for k, v := range opts.MetaData {
   201  		md[k] = v
   202  	}
   203  	return md
   204  }
   205  
   206  func (c *client) fixFilters(opts *Options) filter.ClientChain {
   207  	if opts.DisableFilter || len(opts.Filters) == 0 {
   208  		// All filters but selector filter are disabled.
   209  		opts.FilterNames = append(opts.FilterNames, DefaultSelectorFilterName)
   210  		return filter.ClientChain{selectorFilter}
   211  	}
   212  	if !opts.selectorFilterPosFixed {
   213  		// Selector filter pos is not fixed, append it to the filter chain.
   214  		opts.Filters = append(opts.Filters, selectorFilter)
   215  		opts.FilterNames = append(opts.FilterNames, DefaultSelectorFilterName)
   216  	}
   217  	return opts.Filters
   218  }
   219  
   220  // callFunc is the function that calls the backend service with
   221  // codec encoding/decoding and network transportation.
   222  // Filters executed before this function are called prev filters. Filters executed after
   223  // this function are called post filters.
   224  func callFunc(ctx context.Context, reqBody interface{}, rspBody interface{}) (err error) {
   225  	msg := codec.Message(ctx)
   226  	opts := OptionsFromContext(ctx)
   227  
   228  	defer func() { err = opts.fixTimeout(err) }()
   229  
   230  	// Check if codec is empty, after updating msg.
   231  	if opts.Codec == nil {
   232  		report.ClientCodecEmpty.Incr()
   233  		return errs.NewFrameError(errs.RetClientEncodeFail, "client: codec empty")
   234  	}
   235  
   236  	reqBuf, err := prepareRequestBuf(ctx, msg, reqBody, opts)
   237  	if err != nil {
   238  		return err
   239  	}
   240  
   241  	// Call backend service.
   242  	if opts.EnableMultiplexed {
   243  		opts.CallOptions = append(opts.CallOptions, transport.WithMsg(msg), transport.WithMultiplexed(true))
   244  	}
   245  	rspBuf, err := opts.Transport.RoundTrip(ctx, reqBuf, opts.CallOptions...)
   246  	if err != nil {
   247  		if err == errs.ErrClientNoResponse { // Sendonly mode, no response, just return nil.
   248  			return nil
   249  		}
   250  		return err
   251  	}
   252  
   253  	span := rpcz.SpanFromContext(ctx)
   254  	span.SetAttribute(rpcz.TRPCAttributeResponseSize, len(rspBuf))
   255  	_, end := span.NewChild("DecodeProtocolHead")
   256  	rspBodyBuf, err := opts.Codec.Decode(msg, rspBuf)
   257  	end.End()
   258  	if err != nil {
   259  		return errs.NewFrameError(errs.RetClientDecodeFail, "client codec Decode: "+err.Error())
   260  	}
   261  
   262  	return processResponseBuf(ctx, msg, rspBody, rspBodyBuf, opts)
   263  }
   264  
   265  func prepareRequestBuf(
   266  	ctx context.Context,
   267  	msg codec.Msg,
   268  	reqBody interface{},
   269  	opts *Options,
   270  ) ([]byte, error) {
   271  	reqBodyBuf, err := serializeAndCompress(ctx, msg, reqBody, opts)
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  
   276  	// Encode the whole reqBodyBuf.
   277  	span := rpcz.SpanFromContext(ctx)
   278  	_, end := span.NewChild("EncodeProtocolHead")
   279  	reqBuf, err := opts.Codec.Encode(msg, reqBodyBuf)
   280  	end.End()
   281  	span.SetAttribute(rpcz.TRPCAttributeRequestSize, len(reqBuf))
   282  	if err != nil {
   283  		return nil, errs.NewFrameError(errs.RetClientEncodeFail, "client codec Encode: "+err.Error())
   284  	}
   285  
   286  	return reqBuf, nil
   287  }
   288  
   289  func processResponseBuf(
   290  	ctx context.Context,
   291  	msg codec.Msg,
   292  	rspBody interface{},
   293  	rspBodyBuf []byte,
   294  	opts *Options,
   295  ) error {
   296  	// Error from response.
   297  	if msg.ClientRspErr() != nil {
   298  		return msg.ClientRspErr()
   299  	}
   300  
   301  	if len(rspBodyBuf) == 0 {
   302  		return nil
   303  	}
   304  
   305  	// Decompress.
   306  	span := rpcz.SpanFromContext(ctx)
   307  	_, end := span.NewChild("Decompress")
   308  	compressType := msg.CompressType()
   309  	if icodec.IsValidCompressType(opts.CurrentCompressType) {
   310  		compressType = opts.CurrentCompressType
   311  	}
   312  	var err error
   313  	if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop {
   314  		rspBodyBuf, err = codec.Decompress(compressType, rspBodyBuf)
   315  	}
   316  	end.End()
   317  	if err != nil {
   318  		return errs.NewFrameError(errs.RetClientDecodeFail, "client codec Decompress: "+err.Error())
   319  	}
   320  
   321  	// unmarshal rspBodyBuf to rspBody.
   322  	_, end = span.NewChild("Unmarshal")
   323  	serializationType := msg.SerializationType()
   324  	if icodec.IsValidSerializationType(opts.CurrentSerializationType) {
   325  		serializationType = opts.CurrentSerializationType
   326  	}
   327  	if icodec.IsValidSerializationType(serializationType) {
   328  		err = codec.Unmarshal(serializationType, rspBodyBuf, rspBody)
   329  	}
   330  
   331  	end.End()
   332  	if err != nil {
   333  		return errs.NewFrameError(errs.RetClientDecodeFail, "client codec Unmarshal: "+err.Error())
   334  	}
   335  
   336  	return nil
   337  }
   338  
   339  // serializeAndCompress serializes and compresses reqBody.
   340  func serializeAndCompress(ctx context.Context, msg codec.Msg, reqBody interface{}, opts *Options) ([]byte, error) {
   341  	// Marshal reqBody into binary body.
   342  	span := rpcz.SpanFromContext(ctx)
   343  	_, end := span.NewChild("Marshal")
   344  	serializationType := msg.SerializationType()
   345  	if icodec.IsValidSerializationType(opts.CurrentSerializationType) {
   346  		serializationType = opts.CurrentSerializationType
   347  	}
   348  	var (
   349  		reqBodyBuf []byte
   350  		err        error
   351  	)
   352  	if icodec.IsValidSerializationType(serializationType) {
   353  		reqBodyBuf, err = codec.Marshal(serializationType, reqBody)
   354  	}
   355  	end.End()
   356  	if err != nil {
   357  		return nil, errs.NewFrameError(errs.RetClientEncodeFail, "client codec Marshal: "+err.Error())
   358  	}
   359  
   360  	// Compress.
   361  	_, end = span.NewChild("Compress")
   362  	compressType := msg.CompressType()
   363  	if icodec.IsValidCompressType(opts.CurrentCompressType) {
   364  		compressType = opts.CurrentCompressType
   365  	}
   366  	if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop {
   367  		reqBodyBuf, err = codec.Compress(compressType, reqBodyBuf)
   368  	}
   369  	end.End()
   370  	if err != nil {
   371  		return nil, errs.NewFrameError(errs.RetClientEncodeFail, "client codec Compress: "+err.Error())
   372  	}
   373  	return reqBodyBuf, nil
   374  }
   375  
   376  // -------------------------------- client selector filter ------------------------------------- //
   377  
   378  // selectorFilter is the client selector filter.
   379  func selectorFilter(ctx context.Context, req interface{}, rsp interface{}, next filter.ClientHandleFunc) error {
   380  	msg := codec.Message(ctx)
   381  	opts := OptionsFromContext(ctx)
   382  	if IsOptionsImmutable(ctx) { // Check if options are immutable.
   383  		// The retry plugin will start multiple goroutines to process this filter concurrently,
   384  		// and will set the options to be immutable. Therefore, the original opts cannot be modified directly,
   385  		// and it is necessary to clone new opts.
   386  		opts = opts.clone()
   387  		opts.rebuildSliceCapacity()
   388  		ctx = contextWithOptions(ctx, opts)
   389  	}
   390  
   391  	// Select a node of the backend service.
   392  	node, err := selectNode(ctx, msg, opts)
   393  	if err != nil {
   394  		return OptionsFromContext(ctx).fixTimeout(err)
   395  	}
   396  	ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address, node.ParseAddr)
   397  
   398  	// Start to process the next filter and report.
   399  	begin := time.Now()
   400  	err = next(ctx, req, rsp)
   401  	cost := time.Since(begin)
   402  	if e, ok := err.(*errs.Error); ok &&
   403  		e.Type == errs.ErrorTypeFramework &&
   404  		(e.Code == errs.RetClientConnectFail ||
   405  			e.Code == errs.RetClientTimeout ||
   406  			e.Code == errs.RetClientNetErr) {
   407  		e.Msg = fmt.Sprintf("%s, cost:%s", e.Msg, cost)
   408  		opts.Selector.Report(node, cost, err)
   409  	} else if opts.shouldErrReportToSelector(err) {
   410  		opts.Selector.Report(node, cost, err)
   411  	} else {
   412  		opts.Selector.Report(node, cost, nil)
   413  	}
   414  
   415  	// Transmits node information back to the user.
   416  	if addr := msg.RemoteAddr(); addr != nil {
   417  		opts.Node.set(node, addr.String(), cost)
   418  	} else {
   419  		opts.Node.set(node, node.Address, cost)
   420  	}
   421  	return err
   422  }
   423  
   424  // selectNode selects a backend node by selector related options and sets the msg.
   425  func selectNode(ctx context.Context, msg codec.Msg, opts *Options) (*registry.Node, error) {
   426  	opts.SelectOptions = append(opts.SelectOptions, selector.WithContext(ctx))
   427  	node, err := getNode(opts)
   428  	if err != nil {
   429  		report.SelectNodeFail.Incr()
   430  		return nil, err
   431  	}
   432  
   433  	// Update msg by node config.
   434  	opts.LoadNodeConfig(node)
   435  	msg.WithCalleeContainerName(node.ContainerName)
   436  	msg.WithCalleeSetName(node.SetName)
   437  
   438  	// Set current env info as environment message for transfer only if
   439  	// env info from upstream service is not set.
   440  	if msg.EnvTransfer() == "" {
   441  		msg.WithEnvTransfer(node.EnvKey)
   442  	}
   443  
   444  	// If service router is disabled, env info should be cleared.
   445  	if opts.DisableServiceRouter {
   446  		msg.WithEnvTransfer("")
   447  	}
   448  
   449  	// Selector might block for a while, need to check if ctx is still available.
   450  	if ctx.Err() == context.Canceled {
   451  		return nil, errs.NewFrameError(errs.RetClientCanceled,
   452  			"selector canceled after Select: "+ctx.Err().Error())
   453  	}
   454  	if ctx.Err() == context.DeadlineExceeded {
   455  		return nil, errs.NewFrameError(errs.RetClientTimeout,
   456  			"selector timeout after Select: "+ctx.Err().Error())
   457  	}
   458  
   459  	return node, nil
   460  }
   461  
   462  func getNode(opts *Options) (*registry.Node, error) {
   463  	// Select node.
   464  	node, err := opts.Selector.Select(opts.endpoint, opts.SelectOptions...)
   465  	if err != nil {
   466  		return nil, errs.NewFrameError(errs.RetClientRouteErr, "client Select: "+err.Error())
   467  	}
   468  	if node.Address == "" {
   469  		return nil, errs.NewFrameError(errs.RetClientRouteErr, fmt.Sprintf("client Select: node address empty:%+v", node))
   470  	}
   471  	return node, nil
   472  }
   473  
   474  func ensureMsgRemoteAddr(
   475  	msg codec.Msg,
   476  	network, address string,
   477  	parseAddr func(network, address string) net.Addr,
   478  ) {
   479  	// If RemoteAddr has already been set, just return.
   480  	if msg.RemoteAddr() != nil {
   481  		return
   482  	}
   483  
   484  	if parseAddr != nil {
   485  		msg.WithRemoteAddr(parseAddr(network, address))
   486  		return
   487  	}
   488  
   489  	switch network {
   490  	case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
   491  		// Check if address can be parsed as an ip.
   492  		host, _, err := net.SplitHostPort(address)
   493  		if err != nil || net.ParseIP(host) == nil {
   494  			return
   495  		}
   496  	}
   497  	var addr net.Addr
   498  	switch network {
   499  	case "tcp", "tcp4", "tcp6":
   500  		addr, _ = net.ResolveTCPAddr(network, address)
   501  	case "udp", "udp4", "udp6":
   502  		addr, _ = net.ResolveUDPAddr(network, address)
   503  	case "unix":
   504  		addr, _ = net.ResolveUnixAddr(network, address)
   505  	default:
   506  		addr, _ = net.ResolveTCPAddr("tcp4", address)
   507  	}
   508  	msg.WithRemoteAddr(addr)
   509  }