
     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    14  package transport
    16  import (
    17  	"context"
    18  	"sync"
    20  	""
    21  	""
    22  	""
    23  )
    25  const (
    26  	defaultMaxConcurrentStreams = 1000
    27  	defaultMaxIdleConnsPerHost  = 2
    28  )
    30  // DefaultClientStreamTransport is the default client stream transport.
    31  var DefaultClientStreamTransport = NewClientStreamTransport()
    33  // NewClientStreamTransport creates a new ClientStreamTransport.
    34  func NewClientStreamTransport(opts ...ClientStreamTransportOption) ClientStreamTransport {
    35  	options := &cstOptions{
    36  		maxConcurrentStreams: defaultMaxConcurrentStreams,
    37  		maxIdleConnsPerHost:  defaultMaxIdleConnsPerHost,
    38  	}
    39  	for _, opt := range opts {
    40  		opt(options)
    41  	}
    42  	t := &clientStreamTransport{
    43  		// Map streamID to connection. On the client side, ensure that the streamID is
    44  		// incremented and unique, otherwise the map of addr must be added.
    45  		streamIDToConn: make(map[uint32]multiplexed.MuxConn),
    46  		m:              &sync.RWMutex{},
    47  		multiplexedPool: multiplexed.New(
    48  			multiplexed.WithMaxVirConnsPerConn(options.maxConcurrentStreams),
    49  			multiplexed.WithMaxIdleConnsPerHost(options.maxIdleConnsPerHost),
    50  		),
    51  	}
    52  	return t
    53  }
    55  // cstOptions is the client stream transport options.
    56  type cstOptions struct {
    57  	maxConcurrentStreams int
    58  	maxIdleConnsPerHost  int
    59  }
    61  // ClientStreamTransportOption sets properties of ClientStreamTransport.
    62  type ClientStreamTransportOption func(*cstOptions)
    64  // WithMaxConcurrentStreams sets the maximum concurrent streams in each TCP connection.
    65  func WithMaxConcurrentStreams(n int) ClientStreamTransportOption {
    66  	return func(opts *cstOptions) {
    67  		opts.maxConcurrentStreams = n
    68  	}
    69  }
    71  // WithMaxIdleConnsPerHost sets the maximum idle connections per host.
    72  func WithMaxIdleConnsPerHost(n int) ClientStreamTransportOption {
    73  	return func(opts *cstOptions) {
    74  		opts.maxIdleConnsPerHost = n
    75  	}
    76  }
// clientStreamTransport keeps compatibility with the original client transport.
// It is the ClientStreamTransport implementation returned by NewClientStreamTransport.
type clientStreamTransport struct {
	// streamIDToConn maps a stream ID to the connection obtained for it in Init.
	// Entries are removed by Close. Every access is guarded by m.
	streamIDToConn  map[uint32]multiplexed.MuxConn
	// m guards streamIDToConn.
	m               *sync.RWMutex
	// multiplexedPool is the default pool Init draws connections from, unless a
	// RoundTripOption replaces RoundTripOptions.Multiplexed.
	multiplexedPool multiplexed.Pool
}
    85  // Init inits clientStreamTransport. It gets a connection from the multiplexing pool. A stream is
    86  // corresponding to a virtual connection, which provides the interface for the stream.
    87  func (c *clientStreamTransport) Init(ctx context.Context, roundTripOpts ...RoundTripOption) error {
    88  	opts, err := c.getOptions(ctx, roundTripOpts...)
    89  	if err != nil {
    90  		return err
    91  	}
    92  	// If ctx has been canceled or timeout, just return.
    93  	if ctx.Err() == context.Canceled {
    94  		return errs.NewFrameError(errs.RetClientCanceled,
    95  			"client canceled before tcp dial: "+ctx.Err().Error())
    96  	}
    97  	if ctx.Err() == context.DeadlineExceeded {
    98  		return errs.NewFrameError(errs.RetClientTimeout,
    99  			"client timeout before tcp dial: "+ctx.Err().Error())
   100  	}
   101  	msg := opts.Msg
   102  	streamID := msg.StreamID()
   104  	getOpts := multiplexed.NewGetOptions()
   105  	getOpts.WithVID(streamID)
   106  	fp, ok := opts.FramerBuilder.(multiplexed.FrameParser)
   107  	if !ok {
   108  		return errs.NewFrameError(errs.RetClientConnectFail,
   109  			"frame builder does not implement multiplexed.FrameParser")
   110  	}
   111  	getOpts.WithFrameParser(fp)
   112  	getOpts.WithDialTLS(opts.TLSCertFile, opts.TLSKeyFile, opts.CACertFile, opts.TLSServerName)
   113  	getOpts.WithLocalAddr(opts.LocalAddr)
   114  	conn, err := opts.Multiplexed.GetMuxConn(ctx, opts.Network, opts.Address, getOpts)
   115  	if err != nil {
   116  		return errs.NewFrameError(errs.RetClientConnectFail,
   117  			"tcp client transport multiplexd pool: "+err.Error())
   118  	}
   119  	msg.WithRemoteAddr(conn.RemoteAddr())
   120  	msg.WithLocalAddr(conn.LocalAddr())
   121  	c.m.Lock()
   122  	c.streamIDToConn[streamID] = conn
   123  	c.m.Unlock()
   124  	return nil
   125  }
   127  // Send sends stream data and provides interface for stream.
   128  func (c *clientStreamTransport) Send(ctx context.Context, req []byte, roundTripOpts ...RoundTripOption) error {
   129  	msg := codec.Message(ctx)
   130  	streamID := msg.StreamID()
   131  	// StreamID is uniquely generated by stream client.
   132  	c.m.RLock()
   133  	cc := c.streamIDToConn[streamID]
   134  	c.m.RUnlock()
   135  	if cc == nil {
   136  		return errs.NewFrameError(errs.RetServerSystemErr, "Connection is Closed")
   137  	}
   138  	if err := cc.Write(req); err != nil {
   139  		return err
   140  	}
   141  	return nil
   142  }
   144  // Recv receives stream data and provides interface for stream.
   145  func (c *clientStreamTransport) Recv(ctx context.Context, roundTripOpts ...RoundTripOption) ([]byte, error) {
   146  	cc, err := c.getConnect(ctx, roundTripOpts...)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   151  	select {
   152  	case <-ctx.Done():
   153  		if ctx.Err() == context.Canceled {
   154  			return nil, errs.NewFrameError(errs.RetClientCanceled,
   155  				"tcp client transport canceled before Write: "+ctx.Err().Error())
   156  		}
   157  		if ctx.Err() == context.DeadlineExceeded {
   158  			return nil, errs.NewFrameError(errs.RetClientTimeout,
   159  				"tcp client transport timeout before Write: "+ctx.Err().Error())
   160  		}
   161  	default:
   162  	}
   163  	return cc.Read()
   164  }
   166  // Close closes connections and cleans up.
   167  func (c *clientStreamTransport) Close(ctx context.Context) {
   168  	msg := codec.Message(ctx)
   169  	streamID := msg.StreamID()
   170  	c.m.Lock()
   171  	defer c.m.Unlock()
   172  	if conn, ok := c.streamIDToConn[streamID]; ok {
   173  		conn.Close()
   174  		delete(c.streamIDToConn, streamID)
   175  	}
   176  }
   178  // getOptions inits RoundTripOptions and does some basic check.
   179  func (c *clientStreamTransport) getOptions(ctx context.Context,
   180  	roundTripOpts ...RoundTripOption) (*RoundTripOptions, error) {
   181  	opts := &RoundTripOptions{
   182  		Multiplexed: c.multiplexedPool,
   183  	}
   185  	// use roundTripOpts to modify opts.
   186  	for _, o := range roundTripOpts {
   187  		o(opts)
   188  	}
   190  	if opts.Multiplexed == nil {
   191  		return nil, errs.NewFrameError(errs.RetClientConnectFail,
   192  			"tcp client transport: multiplexd pool empty")
   193  	}
   195  	if opts.FramerBuilder == nil {
   196  		return nil, errs.NewFrameError(errs.RetClientConnectFail,
   197  			"tcp client transport: framer builder empty")
   198  	}
   200  	if opts.Msg == nil {
   201  		return nil, errs.NewFrameError(errs.RetClientConnectFail,
   202  			"tcp client transport: message empty")
   203  	}
   204  	return opts, nil
   205  }
   207  func (c *clientStreamTransport) getConnect(ctx context.Context,
   208  	roundTripOpts ...RoundTripOption) (multiplexed.MuxConn, error) {
   209  	msg := codec.Message(ctx)
   210  	streamID := msg.StreamID()
   211  	c.m.RLock()
   212  	cc := c.streamIDToConn[streamID]
   213  	c.m.RUnlock()
   214  	if cc == nil {
   215  		return nil, errs.NewFrameError(errs.RetServerSystemErr, "Stream is not inited yet")
   216  	}
   217  	return cc, nil
   218  }