trpc.group/trpc-go/trpc-go@v1.0.2/transport/tnet/server_transport_tcp.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || freebsd || dragonfly || darwin
    15  // +build linux freebsd dragonfly darwin
    16  
    17  package tnet
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"math"
    24  	"net"
    25  	"os"
    26  	"strconv"
    27  	"sync"
    28  	"time"
    29  
    30  	"github.com/panjf2000/ants/v2"
    31  	"trpc.group/trpc-go/tnet"
    32  	"trpc.group/trpc-go/tnet/tls"
    33  	"trpc.group/trpc-go/trpc-go/internal/reuseport"
    34  
    35  	"trpc.group/trpc-go/trpc-go/codec"
    36  	"trpc.group/trpc-go/trpc-go/errs"
    37  	"trpc.group/trpc-go/trpc-go/internal/addrutil"
    38  	"trpc.group/trpc-go/trpc-go/internal/report"
    39  	intertls "trpc.group/trpc-go/trpc-go/internal/tls"
    40  	"trpc.group/trpc-go/trpc-go/log"
    41  	"trpc.group/trpc-go/trpc-go/transport"
    42  	"trpc.group/trpc-go/trpc-go/transport/internal/frame"
    43  )
    44  
    45  type task struct {
    46  	req    []byte
    47  	handle handler
    48  	start  time.Time
    49  }
    50  
    51  type handler = func(req []byte)
    52  
    53  func (t *task) reset() {
    54  	t.req = nil
    55  	t.handle = nil
    56  	t.start = time.Time{}
    57  }
    58  
    59  var taskPool = &sync.Pool{
    60  	New: func() interface{} { return new(task) },
    61  }
    62  
    63  func newTask(req []byte, handle handler) *task {
    64  	t := taskPool.Get().(*task)
    65  	t.req = req
    66  	t.handle = handle
    67  	t.start = time.Now()
    68  	return t
    69  }
    70  
    71  // createRoutinePool creates a goroutines pool to avoid the performance overhead caused
    72  // by frequent creation and destruction of goroutines. It also helps to control the number
    73  // of concurrent goroutines, which can prevent sudden spikes in traffic by implementing
    74  // throttling mechanisms.
    75  func createRoutinePool(size int) *ants.PoolWithFunc {
    76  	if size <= 0 {
    77  		size = math.MaxInt32
    78  	}
    79  	pf := func(args interface{}) {
    80  		t, ok := args.(*task)
    81  		if !ok {
    82  			log.Tracef("routine pool args type error, shouldn't happen!")
    83  			return
    84  		}
    85  		report.TCPServerAsyncGoroutineScheduleDelay.Set(float64(time.Since(t.start).Microseconds()))
    86  		t.handle(t.req)
    87  		t.reset()
    88  		taskPool.Put(t)
    89  	}
    90  	pool, err := ants.NewPoolWithFunc(size, pf)
    91  	if err != nil {
    92  		log.Tracef("routine pool create error: %v", err)
    93  		return nil
    94  	}
    95  	return pool
    96  }
    97  
    98  func (s *serverTransport) getTCPListener(opts *transport.ListenServeOptions) (net.Listener, error) {
    99  	if opts.Listener != nil {
   100  		return opts.Listener, nil
   101  	}
   102  
   103  	// During graceful restart, the relevant information has
   104  	// already been stored in environment variables.
   105  	v, _ := os.LookupEnv(transport.EnvGraceRestart)
   106  	ok, _ := strconv.ParseBool(v)
   107  	if ok {
   108  		pln, err := transport.GetPassedListener(opts.Network, opts.Address)
   109  		if err != nil {
   110  			return nil, err
   111  		}
   112  		listener, ok := pln.(net.Listener)
   113  		if !ok {
   114  			return nil, errors.New("invalid net.Listener")
   115  		}
   116  		return listener, nil
   117  	}
   118  	var listener net.Listener
   119  	if s.opts.ReusePort {
   120  		var err error
   121  		listener, err = reuseport.Listen(opts.Network, opts.Address)
   122  		if err != nil {
   123  			return nil, fmt.Errorf("%s reuseport error: %w", opts.Network, err)
   124  		}
   125  		return listener, nil
   126  	}
   127  	return tnet.Listen(opts.Network, opts.Address)
   128  }
   129  
   130  func (s *serverTransport) listenAndServeTCP(ctx context.Context, opts *transport.ListenServeOptions) error {
   131  	// Create a goroutine pool if ServerAsync enabled.
   132  	var pool *ants.PoolWithFunc
   133  	if opts.ServerAsync {
   134  		pool = createRoutinePool(opts.Routines)
   135  	}
   136  
   137  	listener, err := s.getTCPListener(opts)
   138  	if err != nil {
   139  		return fmt.Errorf("trpc-tnet-transport get TCP listener fail, %w", err)
   140  	}
   141  	if err := transport.SaveListener(listener); err != nil {
   142  		return fmt.Errorf("save tnet listener failed: %w", err)
   143  	}
   144  
   145  	if opts.TLSCertFile != "" && opts.TLSKeyFile != "" {
   146  		return s.startTLSService(ctx, listener, pool, opts)
   147  	}
   148  	return s.startService(ctx, listener, pool, opts)
   149  }
   150  
   151  func (s *serverTransport) startService(
   152  	ctx context.Context,
   153  	listener net.Listener,
   154  	pool *ants.PoolWithFunc,
   155  	opts *transport.ListenServeOptions,
   156  ) error {
   157  	tnetOpts := []tnet.Option{
   158  		tnet.WithOnTCPOpened(func(conn tnet.Conn) error {
   159  			tc := s.onConnOpened(conn, pool, opts)
   160  			conn.SetMetaData(tc)
   161  			return nil
   162  		}),
   163  		tnet.WithOnTCPClosed(func(conn tnet.Conn) error {
   164  			s.onConnClosed(conn, opts.Handler)
   165  			return nil
   166  		}),
   167  		tnet.WithTCPIdleTimeout(opts.IdleTimeout),
   168  		tnet.WithTCPKeepAlive(s.opts.KeepAlivePeriod),
   169  	}
   170  	svr, err := tnet.NewTCPService(
   171  		listener,
   172  		func(conn tnet.Conn) error {
   173  			m := conn.GetMetaData()
   174  			return handleTCP(m)
   175  		},
   176  		tnetOpts...)
   177  	if err != nil {
   178  		return fmt.Errorf("trpc-tnet-transport NewTCPService fail, %w", err)
   179  	}
   180  	go svr.Serve(ctx)
   181  	return nil
   182  }
   183  
   184  func (s *serverTransport) startTLSService(
   185  	ctx context.Context,
   186  	listener net.Listener,
   187  	pool *ants.PoolWithFunc,
   188  	opts *transport.ListenServeOptions,
   189  ) error {
   190  	conf, err := intertls.GetServerConfig(opts.CACertFile, opts.TLSCertFile, opts.TLSKeyFile)
   191  	if err != nil {
   192  		return fmt.Errorf("get tls config fail: %w", err)
   193  	}
   194  
   195  	tlsOpts := []tls.ServerOption{
   196  		tls.WithOnOpened(func(conn tls.Conn) error {
   197  			tc := s.onConnOpened(conn, pool, opts)
   198  			conn.SetMetaData(tc)
   199  			return nil
   200  		}),
   201  		tls.WithOnClosed(func(conn tls.Conn) error {
   202  			s.onConnClosed(conn, opts.Handler)
   203  			return nil
   204  		}),
   205  		tls.WithServerTLSConfig(conf),
   206  		tls.WithServerIdleTimeout(opts.IdleTimeout),
   207  		tls.WithTCPKeepAlive(s.opts.KeepAlivePeriod),
   208  	}
   209  	svr, err := tls.NewService(
   210  		listener,
   211  		func(conn tls.Conn) error {
   212  			m := conn.GetMetaData()
   213  			return handleTCP(m)
   214  		},
   215  		tlsOpts...)
   216  	if err != nil {
   217  		return fmt.Errorf("trpc-tnet-transport TLS NewService fail, %w", err)
   218  	}
   219  	go svr.Serve(ctx)
   220  	return nil
   221  }
   222  
   223  // onConnOpened is triggered after a successful connection is established with the client.
   224  func (s *serverTransport) onConnOpened(conn net.Conn, pool *ants.PoolWithFunc,
   225  	opts *transport.ListenServeOptions) *tcpConn {
   226  	tc := &tcpConn{
   227  		rawConn:     conn,
   228  		pool:        pool,
   229  		handler:     opts.Handler,
   230  		serverAsync: opts.ServerAsync,
   231  		framer:      opts.FramerBuilder.New(conn),
   232  	}
   233  	// To avoid overwriting packets, check whether we should copy packages by Framer and some other configurations.
   234  	tc.copyFrame = frame.ShouldCopy(opts.CopyFrame, tc.serverAsync, codec.IsSafeFramer(tc.framer))
   235  
   236  	s.storeConn(addrutil.AddrToKey(conn.LocalAddr(), conn.RemoteAddr()), tc)
   237  	return tc
   238  }
   239  
   240  // onConnClosed is triggered after the connection with the client is closed.
   241  func (s *serverTransport) onConnClosed(conn net.Conn, handler transport.Handler) {
   242  	ctx, msg := codec.WithNewMessage(context.Background())
   243  	msg.WithLocalAddr(conn.LocalAddr())
   244  	msg.WithRemoteAddr(conn.RemoteAddr())
   245  	e := &errs.Error{
   246  		Type: errs.ErrorTypeFramework,
   247  		Code: errs.RetServerSystemErr,
   248  		Desc: "trpc",
   249  		Msg:  "Server connection closed",
   250  	}
   251  	msg.WithServerRspErr(e)
   252  	if closeHandler, ok := handler.(transport.CloseHandler); ok {
   253  		if err := closeHandler.HandleClose(ctx); err != nil {
   254  			log.Trace("transport: notify connection close failed", err)
   255  		}
   256  	}
   257  
   258  	// Release the connection resources stored on the transport.
   259  	s.deleteConn(addrutil.AddrToKey(conn.LocalAddr(), conn.RemoteAddr()))
   260  }
   261  
   262  func handleTCP(conn interface{}) error {
   263  	tc, ok := conn.(*tcpConn)
   264  	if !ok {
   265  		return errors.New("bug: tcpConn type assert fail")
   266  	}
   267  	return tc.onRequest()
   268  }
   269  
   270  type tcpConn struct {
   271  	rawConn     net.Conn
   272  	framer      transport.Framer
   273  	pool        *ants.PoolWithFunc
   274  	handler     transport.Handler
   275  	serverAsync bool
   276  	copyFrame   bool
   277  }
   278  
   279  // onRequest is triggered when there is incoming data on the connection with the client.
   280  func (tc *tcpConn) onRequest() error {
   281  	req, err := tc.framer.ReadFrame()
   282  	if err != nil {
   283  		if err == tnet.ErrConnClosed {
   284  			report.TCPServerTransportReadEOF.Incr()
   285  			return err
   286  		}
   287  		report.TCPServerTransportReadFail.Incr()
   288  		log.Trace("transport: tcpConn onRequest ReadFrame fail ", err)
   289  		return err
   290  	}
   291  	if tc.copyFrame {
   292  		reqCopy := make([]byte, len(req))
   293  		copy(reqCopy, req)
   294  		req = reqCopy
   295  	}
   296  	report.TCPServerTransportReceiveSize.Set(float64(len(req)))
   297  
   298  	if !tc.serverAsync || tc.pool == nil {
   299  		tc.handleSync(req)
   300  		return nil
   301  	}
   302  
   303  	if err := tc.pool.Invoke(newTask(req, tc.handleSync)); err != nil {
   304  		report.TCPServerTransportJobQueueFullFail.Incr()
   305  		log.Trace("transport: tcpConn serve routine pool put job queue fail ", err)
   306  		tc.handleWithErr(req, errs.ErrServerRoutinePoolBusy)
   307  	}
   308  	return nil
   309  }
   310  
   311  func (tc *tcpConn) handleSync(req []byte) {
   312  	tc.handleWithErr(req, nil)
   313  }
   314  
   315  func (tc *tcpConn) handleWithErr(req []byte, e error) {
   316  	ctx, msg := codec.WithNewMessage(context.Background())
   317  	defer codec.PutBackMessage(msg)
   318  	msg.WithServerRspErr(e)
   319  	msg.WithLocalAddr(tc.rawConn.LocalAddr())
   320  	msg.WithRemoteAddr(tc.rawConn.RemoteAddr())
   321  
   322  	rsp, err := tc.handle(ctx, req)
   323  	if err != nil {
   324  		if err != errs.ErrServerNoResponse {
   325  			report.TCPServerTransportHandleFail.Incr()
   326  			log.Trace("transport: tcpConn serve handle fail ", err)
   327  			tc.close()
   328  			return
   329  		}
   330  		return
   331  	}
   332  	report.TCPServerTransportSendSize.Set(float64(len(rsp)))
   333  	if _, err = tc.rawConn.Write(rsp); err != nil {
   334  		report.TCPServerTransportWriteFail.Incr()
   335  		log.Trace("transport: tcpConn write fail ", err)
   336  		tc.close()
   337  		return
   338  	}
   339  }
   340  
   341  func (tc *tcpConn) handle(ctx context.Context, req []byte) ([]byte, error) {
   342  	return tc.handler.Handle(ctx, req)
   343  }
   344  
   345  func (tc *tcpConn) close() {
   346  	if err := tc.rawConn.Close(); err != nil {
   347  		log.Tracef("transport: tcpConn close fail %v", err)
   348  	}
   349  }