trpc.group/trpc-go/trpc-go@v1.0.3/pool/connpool/pool.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  // Package connpool provides the connection pool.
    15  package connpool
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"io"
    21  	"net"
    22  	"time"
    23  
    24  	"trpc.group/trpc-go/trpc-go/codec"
    25  	"trpc.group/trpc-go/trpc-go/errs"
    26  	intertls "trpc.group/trpc-go/trpc-go/internal/tls"
    27  )
    28  
    29  // GetOptions is the get conn configuration.
    30  type GetOptions struct {
    31  	FramerBuilder codec.FramerBuilder
    32  	CustomReader  func(io.Reader) io.Reader
    33  	Ctx           context.Context
    34  
    35  	CACertFile    string // ca certificate.
    36  	TLSCertFile   string // client certificate.
    37  	TLSKeyFile    string // client secret key.
    38  	TLSServerName string // The client verifies the server's service name,
    39  	// if not filled in, it defaults to the http hostname.
    40  
    41  	LocalAddr   string        // The local address when establishing a connection, which is randomly selected by default.
    42  	DialTimeout time.Duration // Connection establishment timeout.
    43  	Protocol    string        // protocol type.
    44  }
    45  
    46  func (o *GetOptions) getDialCtx(dialTimeout time.Duration) (context.Context, context.CancelFunc) {
    47  	ctx := o.Ctx
    48  	defer func() {
    49  		// opts.Ctx is only used to pass ctx parameters, ctx is not recommended to be held by data structures.
    50  		o.Ctx = nil
    51  	}()
    52  
    53  	for {
    54  		// If the RPC request does not set ctx, create a new ctx.
    55  		if ctx == nil {
    56  			break
    57  		}
    58  		// If the RPC request does not set the ctx timeout, create a new ctx.
    59  		deadline, ok := ctx.Deadline()
    60  		if !ok {
    61  			break
    62  		}
    63  		// If the RPC request timeout is greater than the set timeout, create a new ctx.
    64  		d := time.Until(deadline)
    65  		if o.DialTimeout > 0 && o.DialTimeout < d {
    66  			break
    67  		}
    68  		return ctx, nil
    69  	}
    70  
    71  	if o.DialTimeout > 0 {
    72  		dialTimeout = o.DialTimeout
    73  	}
    74  	if dialTimeout == 0 {
    75  		dialTimeout = defaultDialTimeout
    76  	}
    77  	return context.WithTimeout(context.Background(), dialTimeout)
    78  }
    79  
    80  // NewGetOptions creates and initializes GetOptions.
    81  func NewGetOptions() GetOptions {
    82  	return GetOptions{
    83  		CustomReader: codec.NewReader,
    84  	}
    85  }
    86  
    87  // WithFramerBuilder returns an Option which sets the FramerBuilder.
    88  func (o *GetOptions) WithFramerBuilder(fb codec.FramerBuilder) {
    89  	o.FramerBuilder = fb
    90  }
    91  
    92  // WithDialTLS returns an Option which sets the client to support TLS.
    93  func (o *GetOptions) WithDialTLS(certFile, keyFile, caFile, serverName string) {
    94  	o.TLSCertFile = certFile
    95  	o.TLSKeyFile = keyFile
    96  	o.CACertFile = caFile
    97  	o.TLSServerName = serverName
    98  }
    99  
   100  // WithContext returns an Option which sets the requested ctx.
   101  func (o *GetOptions) WithContext(ctx context.Context) {
   102  	o.Ctx = ctx
   103  }
   104  
   105  // WithLocalAddr returns an Option which sets the local address when establishing a connection,
   106  // and it is randomly selected by default when there are multiple network cards.
   107  func (o *GetOptions) WithLocalAddr(addr string) {
   108  	o.LocalAddr = addr
   109  }
   110  
   111  // WithDialTimeout returns an Option which sets the connection timeout.
   112  func (o *GetOptions) WithDialTimeout(dur time.Duration) {
   113  	o.DialTimeout = dur
   114  }
   115  
   116  // WithProtocol returns an Option which sets the backend service protocol name.
   117  func (o *GetOptions) WithProtocol(s string) {
   118  	o.Protocol = s
   119  }
   120  
   121  // WithCustomReader returns an option which sets a customReader. Connection pool will uses this customReader
   122  // to create a reader encapsulating the underlying connection, which is usually used to create a buffer.
   123  func (o *GetOptions) WithCustomReader(customReader func(io.Reader) io.Reader) {
   124  	o.CustomReader = customReader
   125  }
   126  
   127  // Pool is the interface that specifies client connection pool options.
   128  // Compared with Pool, Pool directly uses the GetOptions data structure for function input parameters.
   129  // Compared with function option input parameter mode, it can reduce memory escape and improve calling performance.
   130  type Pool interface {
   131  	Get(network string, address string, opt GetOptions) (net.Conn, error)
   132  }
   133  
   134  // DialFunc connects to an endpoint with the information in options.
   135  type DialFunc func(opts *DialOptions) (net.Conn, error)
   136  
   137  // DialOptions request parameters.
   138  type DialOptions struct {
   139  	Network       string
   140  	Address       string
   141  	LocalAddr     string
   142  	Timeout       time.Duration
   143  	CACertFile    string // ca certificate.
   144  	TLSCertFile   string // client certificate.
   145  	TLSKeyFile    string // client secret key.
   146  	TLSServerName string // The client verifies the server's service name,
   147  	// if not filled in, it defaults to the http hostname.
   148  	IdleTimeout time.Duration
   149  }
   150  
   151  // Dial initiates the request.
   152  func Dial(opts *DialOptions) (net.Conn, error) {
   153  	var localAddr net.Addr
   154  	if opts.LocalAddr != "" {
   155  		var err error
   156  		localAddr, err = net.ResolveTCPAddr(opts.Network, opts.LocalAddr)
   157  		if err != nil {
   158  			return nil, err
   159  		}
   160  	}
   161  	dialer := &net.Dialer{
   162  		Timeout:   opts.Timeout,
   163  		LocalAddr: localAddr,
   164  	}
   165  	if opts.CACertFile == "" {
   166  		return dialer.Dial(opts.Network, opts.Address)
   167  	}
   168  
   169  	if opts.TLSServerName == "" {
   170  		opts.TLSServerName = opts.Address
   171  	}
   172  
   173  	tlsConf, err := intertls.GetClientConfig(opts.TLSServerName, opts.CACertFile, opts.TLSCertFile, opts.TLSKeyFile)
   174  	if err != nil {
   175  		return nil, errs.NewFrameError(errs.RetClientDecodeFail, "client dial tls fail: "+err.Error())
   176  	}
   177  	return tls.DialWithDialer(dialer, opts.Network, opts.Address, tlsConf)
   178  }