trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/server_transport_tcp.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || freebsd || dragonfly || darwin
    15  // +build linux freebsd dragonfly darwin
    16  
    17  package tnet
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"math"
    24  	"net"
    25  	"os"
    26  	"strconv"
    27  	"sync"
    28  	"time"
    29  
    30  	"github.com/panjf2000/ants/v2"
    31  	"trpc.group/trpc-go/tnet"
    32  	"trpc.group/trpc-go/tnet/tls"
    33  	"trpc.group/trpc-go/trpc-go/internal/reuseport"
    34  
    35  	"trpc.group/trpc-go/trpc-go/codec"
    36  	"trpc.group/trpc-go/trpc-go/errs"
    37  	"trpc.group/trpc-go/trpc-go/internal/addrutil"
    38  	"trpc.group/trpc-go/trpc-go/internal/report"
    39  	intertls "trpc.group/trpc-go/trpc-go/internal/tls"
    40  	"trpc.group/trpc-go/trpc-go/log"
    41  	"trpc.group/trpc-go/trpc-go/transport"
    42  	"trpc.group/trpc-go/trpc-go/transport/internal/frame"
    43  )
    44  
    45  type task struct {
    46  	req    []byte
    47  	handle handler
    48  	start  time.Time
    49  }
    50  
    51  type handler = func(req []byte)
    52  
    53  func (t *task) reset() {
    54  	t.req = nil
    55  	t.handle = nil
    56  	t.start = time.Time{}
    57  }
    58  
    59  var taskPool = &sync.Pool{
    60  	New: func() interface{} { return new(task) },
    61  }
    62  
    63  func newTask(req []byte, handle handler) *task {
    64  	t := taskPool.Get().(*task)
    65  	t.req = req
    66  	t.handle = handle
    67  	t.start = time.Now()
    68  	return t
    69  }
    70  
    71  // createRoutinePool creates a goroutines pool to avoid the performance overhead caused
    72  // by frequent creation and destruction of goroutines. It also helps to control the number
    73  // of concurrent goroutines, which can prevent sudden spikes in traffic by implementing
    74  // throttling mechanisms.
    75  func createRoutinePool(size int) *ants.PoolWithFunc {
    76  	if size <= 0 {
    77  		size = math.MaxInt32
    78  	}
    79  	pf := func(args interface{}) {
    80  		t, ok := args.(*task)
    81  		if !ok {
    82  			log.Tracef("routine pool args type error, shouldn't happen!")
    83  			return
    84  		}
    85  		report.TCPServerAsyncGoroutineScheduleDelay.Set(float64(time.Since(t.start).Microseconds()))
    86  		t.handle(t.req)
    87  		t.reset()
    88  		taskPool.Put(t)
    89  	}
    90  	pool, err := ants.NewPoolWithFunc(size, pf)
    91  	if err != nil {
    92  		log.Tracef("routine pool create error: %v", err)
    93  		return nil
    94  	}
    95  	return pool
    96  }
    97  
    98  func (s *serverTransport) getTCPListener(opts *transport.ListenServeOptions) (net.Listener, error) {
    99  	if opts.Listener != nil {
   100  		return opts.Listener, nil
   101  	}
   102  
   103  	// During graceful restart, the relevant information has
   104  	// already been stored in environment variables.
   105  	v, _ := os.LookupEnv(transport.EnvGraceRestart)
   106  	ok, _ := strconv.ParseBool(v)
   107  	if ok {
   108  		pln, err := transport.GetPassedListener(opts.Network, opts.Address)
   109  		if err != nil {
   110  			return nil, err
   111  		}
   112  		listener, ok := pln.(net.Listener)
   113  		if !ok {
   114  			return nil, errors.New("invalid net.Listener")
   115  		}
   116  		return listener, nil
   117  	}
   118  	var listener net.Listener
   119  	if s.opts.ReusePort {
   120  		var err error
   121  		listener, err = reuseport.Listen(opts.Network, opts.Address)
   122  		if err != nil {
   123  			return nil, fmt.Errorf("%s reuseport error: %w", opts.Network, err)
   124  		}
   125  		return listener, nil
   126  	}
   127  	return tnet.Listen(opts.Network, opts.Address)
   128  }
   129  
   130  func (s *serverTransport) listenAndServeTCP(ctx context.Context, opts *transport.ListenServeOptions) error {
   131  	// Create a goroutine pool if ServerAsync enabled.
   132  	var pool *ants.PoolWithFunc
   133  	if opts.ServerAsync {
   134  		pool = createRoutinePool(opts.Routines)
   135  	}
   136  
   137  	listener, err := s.getTCPListener(opts)
   138  	if err != nil {
   139  		return fmt.Errorf("trpc-tnet-transport get TCP listener fail, %w", err)
   140  	}
   141  	if err := transport.SaveListener(listener); err != nil {
   142  		return fmt.Errorf("save tnet listener failed: %w", err)
   143  	}
   144  
   145  	if opts.TLSCertFile != "" && opts.TLSKeyFile != "" {
   146  		return s.startTLSService(ctx, listener, pool, opts)
   147  	}
   148  	return s.startService(ctx, listener, pool, opts)
   149  }
   150  
   151  func (s *serverTransport) startService(
   152  	ctx context.Context,
   153  	listener net.Listener,
   154  	pool *ants.PoolWithFunc,
   155  	opts *transport.ListenServeOptions,
   156  ) error {
   157  	go func() {
   158  		<-opts.StopListening
   159  		listener.Close()
   160  	}()
   161  	tnetOpts := []tnet.Option{
   162  		tnet.WithOnTCPOpened(func(conn tnet.Conn) error {
   163  			tc := s.onConnOpened(conn, pool, opts)
   164  			conn.SetMetaData(tc)
   165  			return nil
   166  		}),
   167  		tnet.WithOnTCPClosed(func(conn tnet.Conn) error {
   168  			s.onConnClosed(conn, opts.Handler)
   169  			return nil
   170  		}),
   171  		tnet.WithTCPIdleTimeout(opts.IdleTimeout),
   172  		tnet.WithTCPKeepAlive(s.opts.KeepAlivePeriod),
   173  	}
   174  	svr, err := tnet.NewTCPService(
   175  		listener,
   176  		func(conn tnet.Conn) error {
   177  			m := conn.GetMetaData()
   178  			return handleTCP(m)
   179  		},
   180  		tnetOpts...)
   181  	if err != nil {
   182  		return fmt.Errorf("trpc-tnet-transport NewTCPService fail, %w", err)
   183  	}
   184  	go svr.Serve(ctx)
   185  	return nil
   186  }
   187  
   188  func (s *serverTransport) startTLSService(
   189  	ctx context.Context,
   190  	listener net.Listener,
   191  	pool *ants.PoolWithFunc,
   192  	opts *transport.ListenServeOptions,
   193  ) error {
   194  	conf, err := intertls.GetServerConfig(opts.CACertFile, opts.TLSCertFile, opts.TLSKeyFile)
   195  	if err != nil {
   196  		return fmt.Errorf("get tls config fail: %w", err)
   197  	}
   198  
   199  	tlsOpts := []tls.ServerOption{
   200  		tls.WithOnOpened(func(conn tls.Conn) error {
   201  			tc := s.onConnOpened(conn, pool, opts)
   202  			conn.SetMetaData(tc)
   203  			return nil
   204  		}),
   205  		tls.WithOnClosed(func(conn tls.Conn) error {
   206  			s.onConnClosed(conn, opts.Handler)
   207  			return nil
   208  		}),
   209  		tls.WithServerTLSConfig(conf),
   210  		tls.WithServerIdleTimeout(opts.IdleTimeout),
   211  		tls.WithTCPKeepAlive(s.opts.KeepAlivePeriod),
   212  	}
   213  	svr, err := tls.NewService(
   214  		listener,
   215  		func(conn tls.Conn) error {
   216  			m := conn.GetMetaData()
   217  			return handleTCP(m)
   218  		},
   219  		tlsOpts...)
   220  	if err != nil {
   221  		return fmt.Errorf("trpc-tnet-transport TLS NewService fail, %w", err)
   222  	}
   223  	go svr.Serve(ctx)
   224  	return nil
   225  }
   226  
   227  // onConnOpened is triggered after a successful connection is established with the client.
   228  func (s *serverTransport) onConnOpened(conn net.Conn, pool *ants.PoolWithFunc,
   229  	opts *transport.ListenServeOptions) *tcpConn {
   230  	tc := &tcpConn{
   231  		rawConn:     conn,
   232  		pool:        pool,
   233  		handler:     opts.Handler,
   234  		serverAsync: opts.ServerAsync,
   235  		framer:      opts.FramerBuilder.New(conn),
   236  	}
   237  	// To avoid overwriting packets, check whether we should copy packages by Framer and some other configurations.
   238  	tc.copyFrame = frame.ShouldCopy(opts.CopyFrame, tc.serverAsync, codec.IsSafeFramer(tc.framer))
   239  
   240  	s.storeConn(addrutil.AddrToKey(conn.LocalAddr(), conn.RemoteAddr()), tc)
   241  	return tc
   242  }
   243  
   244  // onConnClosed is triggered after the connection with the client is closed.
   245  func (s *serverTransport) onConnClosed(conn net.Conn, handler transport.Handler) {
   246  	ctx, msg := codec.WithNewMessage(context.Background())
   247  	msg.WithLocalAddr(conn.LocalAddr())
   248  	msg.WithRemoteAddr(conn.RemoteAddr())
   249  	e := &errs.Error{
   250  		Type: errs.ErrorTypeFramework,
   251  		Code: errs.RetServerSystemErr,
   252  		Desc: "trpc",
   253  		Msg:  "Server connection closed",
   254  	}
   255  	msg.WithServerRspErr(e)
   256  	if closeHandler, ok := handler.(transport.CloseHandler); ok {
   257  		if err := closeHandler.HandleClose(ctx); err != nil {
   258  			log.Trace("transport: notify connection close failed", err)
   259  		}
   260  	}
   261  
   262  	// Release the connection resources stored on the transport.
   263  	s.deleteConn(addrutil.AddrToKey(conn.LocalAddr(), conn.RemoteAddr()))
   264  }
   265  
   266  func handleTCP(conn interface{}) error {
   267  	tc, ok := conn.(*tcpConn)
   268  	if !ok {
   269  		return errors.New("bug: tcpConn type assert fail")
   270  	}
   271  	return tc.onRequest()
   272  }
   273  
   274  type tcpConn struct {
   275  	rawConn     net.Conn
   276  	framer      transport.Framer
   277  	pool        *ants.PoolWithFunc
   278  	handler     transport.Handler
   279  	serverAsync bool
   280  	copyFrame   bool
   281  }
   282  
   283  // onRequest is triggered when there is incoming data on the connection with the client.
   284  func (tc *tcpConn) onRequest() error {
   285  	req, err := tc.framer.ReadFrame()
   286  	if err != nil {
   287  		if err == tnet.ErrConnClosed {
   288  			report.TCPServerTransportReadEOF.Incr()
   289  			return err
   290  		}
   291  		report.TCPServerTransportReadFail.Incr()
   292  		log.Trace("transport: tcpConn onRequest ReadFrame fail ", err)
   293  		return err
   294  	}
   295  	if tc.copyFrame {
   296  		reqCopy := make([]byte, len(req))
   297  		copy(reqCopy, req)
   298  		req = reqCopy
   299  	}
   300  	report.TCPServerTransportReceiveSize.Set(float64(len(req)))
   301  
   302  	if !tc.serverAsync || tc.pool == nil {
   303  		tc.handleSync(req)
   304  		return nil
   305  	}
   306  
   307  	if err := tc.pool.Invoke(newTask(req, tc.handleSync)); err != nil {
   308  		report.TCPServerTransportJobQueueFullFail.Incr()
   309  		log.Trace("transport: tcpConn serve routine pool put job queue fail ", err)
   310  		tc.handleWithErr(req, errs.ErrServerRoutinePoolBusy)
   311  	}
   312  	return nil
   313  }
   314  
   315  func (tc *tcpConn) handleSync(req []byte) {
   316  	tc.handleWithErr(req, nil)
   317  }
   318  
   319  func (tc *tcpConn) handleWithErr(req []byte, e error) {
   320  	ctx, msg := codec.WithNewMessage(context.Background())
   321  	defer codec.PutBackMessage(msg)
   322  	msg.WithServerRspErr(e)
   323  	msg.WithLocalAddr(tc.rawConn.LocalAddr())
   324  	msg.WithRemoteAddr(tc.rawConn.RemoteAddr())
   325  
   326  	rsp, err := tc.handle(ctx, req)
   327  	if err != nil {
   328  		if err != errs.ErrServerNoResponse {
   329  			report.TCPServerTransportHandleFail.Incr()
   330  			log.Trace("transport: tcpConn serve handle fail ", err)
   331  			tc.close()
   332  			return
   333  		}
   334  		return
   335  	}
   336  	report.TCPServerTransportSendSize.Set(float64(len(rsp)))
   337  	if _, err = tc.rawConn.Write(rsp); err != nil {
   338  		report.TCPServerTransportWriteFail.Incr()
   339  		log.Trace("transport: tcpConn write fail ", err)
   340  		tc.close()
   341  		return
   342  	}
   343  }
   344  
   345  func (tc *tcpConn) handle(ctx context.Context, req []byte) ([]byte, error) {
   346  	return tc.handler.Handle(ctx, req)
   347  }
   348  
   349  func (tc *tcpConn) close() {
   350  	if err := tc.rawConn.Close(); err != nil {
   351  		log.Tracef("transport: tcpConn close fail %v", err)
   352  	}
   353  }