trpc.group/trpc-go/trpc-go@v1.0.2/transport/server_transport_tcp.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package transport
    15  
    16  import (
    17  	"context"
    18  	"io"
    19  	"math"
    20  	"net"
    21  	"strings"
    22  	"sync"
    23  	"time"
    24  
    25  	"github.com/panjf2000/ants/v2"
    26  
    27  	"trpc.group/trpc-go/trpc-go/codec"
    28  	"trpc.group/trpc-go/trpc-go/errs"
    29  	"trpc.group/trpc-go/trpc-go/internal/addrutil"
    30  	"trpc.group/trpc-go/trpc-go/internal/report"
    31  	"trpc.group/trpc-go/trpc-go/internal/writev"
    32  	"trpc.group/trpc-go/trpc-go/log"
    33  	"trpc.group/trpc-go/trpc-go/rpcz"
    34  	"trpc.group/trpc-go/trpc-go/transport/internal/frame"
    35  )
    36  
    37  const defaultBufferSize = 128 * 1024
    38  
    39  type handleParam struct {
    40  	req   []byte
    41  	c     *tcpconn
    42  	start time.Time
    43  }
    44  
    45  func (p *handleParam) reset() {
    46  	p.req = nil
    47  	p.c = nil
    48  	p.start = time.Time{}
    49  }
    50  
    51  var handleParamPool = &sync.Pool{
    52  	New: func() interface{} { return new(handleParam) },
    53  }
    54  
    55  func createRoutinePool(size int) *ants.PoolWithFunc {
    56  	if size <= 0 {
    57  		size = math.MaxInt32
    58  	}
    59  	pool, err := ants.NewPoolWithFunc(size, func(args interface{}) {
    60  		param, ok := args.(*handleParam)
    61  		if !ok {
    62  			log.Tracef("routine pool args type error, shouldn't happen!")
    63  			return
    64  		}
    65  		report.TCPServerAsyncGoroutineScheduleDelay.Set(float64(time.Since(param.start).Microseconds()))
    66  		if param.c == nil {
    67  			log.Tracef("routine pool tcpconn is nil, shouldn't happen!")
    68  			return
    69  		}
    70  		param.c.handleSync(param.req)
    71  		param.reset()
    72  		handleParamPool.Put(param)
    73  	})
    74  	if err != nil {
    75  		log.Tracef("routine pool create error:%v", err)
    76  		return nil
    77  	}
    78  	return pool
    79  }
    80  
    81  func (s *serverTransport) serveTCP(ctx context.Context, ln net.Listener, opts *ListenServeOptions) error {
    82  	var once sync.Once
    83  	closeListener := func() { ln.Close() }
    84  	defer once.Do(closeListener)
    85  	// Create a goroutine to watch ctx.Done() channel.
    86  	// Once Server.Close(), TCP listener should be closed immediately and won't accept any new connection.
    87  	go func() {
    88  		<-ctx.Done()
    89  		log.Tracef("recv server close event")
    90  		once.Do(closeListener)
    91  	}()
    92  	// Create a goroutine pool if ServerAsync enabled.
    93  	var pool *ants.PoolWithFunc
    94  	if opts.ServerAsync {
    95  		pool = createRoutinePool(opts.Routines)
    96  	}
    97  	for tempDelay := time.Duration(0); ; {
    98  		rwc, err := ln.Accept()
    99  		if err != nil {
   100  			if ne, ok := err.(net.Error); ok && ne.Temporary() {
   101  				tempDelay = doTempDelay(tempDelay)
   102  				continue
   103  			}
   104  			select {
   105  			case <-ctx.Done(): // If this error is triggered by the user, such as during a restart,
   106  				return err // it is possible to directly return the error, causing the current listener to exit.
   107  			default:
   108  				// Restricted access to the internal/poll.ErrNetClosing type necessitates comparing a string literal.
   109  				const accept, closeError = "accept", "use of closed network connection"
   110  				const msg = "the server transport, listening on %s, encountered an error: %+v; this error was handled" +
   111  					" gracefully by the framework to prevent abnormal termination, serving as a reference for" +
   112  					" investigating acceptance errors that can't be filtered by the Temporary interface"
   113  				if e, ok := err.(*net.OpError); ok && e.Op == accept && strings.Contains(e.Err.Error(), closeError) {
   114  					log.Infof("listener with address %s is closed", ln.Addr())
   115  					return err
   116  				}
   117  				log.Errorf(msg, ln.Addr(), err)
   118  				continue
   119  			}
   120  		}
   121  		tempDelay = 0
   122  		if tcpConn, ok := rwc.(*net.TCPConn); ok {
   123  			if err := tcpConn.SetKeepAlive(true); err != nil {
   124  				log.Tracef("tcp conn set keepalive error:%v", err)
   125  			}
   126  			if s.opts.KeepAlivePeriod > 0 {
   127  				if err := tcpConn.SetKeepAlivePeriod(s.opts.KeepAlivePeriod); err != nil {
   128  					log.Tracef("tcp conn set keepalive period error:%v", err)
   129  				}
   130  			}
   131  		}
   132  		tc := &tcpconn{
   133  			conn:        s.newConn(ctx, opts),
   134  			rwc:         rwc,
   135  			fr:          opts.FramerBuilder.New(codec.NewReader(rwc)),
   136  			remoteAddr:  rwc.RemoteAddr(),
   137  			localAddr:   rwc.LocalAddr(),
   138  			serverAsync: opts.ServerAsync,
   139  			writev:      opts.Writev,
   140  			st:          s,
   141  			pool:        pool,
   142  		}
   143  		// Start goroutine sending with writev.
   144  		if tc.writev {
   145  			tc.buffer = writev.NewBuffer()
   146  			tc.closeNotify = make(chan struct{}, 1)
   147  			tc.buffer.Start(tc.rwc, tc.closeNotify)
   148  		}
   149  		// To avoid over writing packages, checks whether should we copy packages by Framer and
   150  		// some other configurations.
   151  		tc.copyFrame = frame.ShouldCopy(opts.CopyFrame, tc.serverAsync, codec.IsSafeFramer(tc.fr))
   152  		key := addrutil.AddrToKey(tc.localAddr, tc.remoteAddr)
   153  		s.m.Lock()
   154  		s.addrToConn[key] = tc
   155  		s.m.Unlock()
   156  		go tc.serve()
   157  	}
   158  }
   159  
   160  func doTempDelay(tempDelay time.Duration) time.Duration {
   161  	if tempDelay == 0 {
   162  		tempDelay = 5 * time.Millisecond
   163  	} else {
   164  		tempDelay *= 2
   165  	}
   166  	if max := 1 * time.Second; tempDelay > max {
   167  		tempDelay = max
   168  	}
   169  	time.Sleep(tempDelay)
   170  	return tempDelay
   171  }
   172  
   173  // tcpconn is the connection which is established when server accept a client connecting request.
   174  type tcpconn struct {
   175  	*conn
   176  	rwc         net.Conn
   177  	fr          codec.Framer
   178  	localAddr   net.Addr
   179  	remoteAddr  net.Addr
   180  	serverAsync bool
   181  	writev      bool
   182  	copyFrame   bool
   183  	closeOnce   sync.Once
   184  	st          *serverTransport
   185  	pool        *ants.PoolWithFunc
   186  	buffer      *writev.Buffer
   187  	closeNotify chan struct{}
   188  }
   189  
   190  // close closes socket and cleans up.
   191  func (c *tcpconn) close() {
   192  	c.closeOnce.Do(func() {
   193  		// Send error msg to handler.
   194  		ctx, msg := codec.WithNewMessage(context.Background())
   195  		msg.WithLocalAddr(c.localAddr)
   196  		msg.WithRemoteAddr(c.remoteAddr)
   197  		e := &errs.Error{
   198  			Type: errs.ErrorTypeFramework,
   199  			Code: errs.RetServerSystemErr,
   200  			Desc: "trpc",
   201  			Msg:  "Server connection closed",
   202  		}
   203  		msg.WithServerRspErr(e)
   204  		// The connection closing message is handed over to handler.
   205  		if err := c.conn.handleClose(ctx); err != nil {
   206  			log.Trace("transport: notify connection close failed", err)
   207  		}
   208  		// Notify to stop writev sending goroutine.
   209  		if c.writev {
   210  			close(c.closeNotify)
   211  		}
   212  
   213  		// Remove cache in server stream transport.
   214  		key := addrutil.AddrToKey(c.localAddr, c.remoteAddr)
   215  		c.st.m.Lock()
   216  		delete(c.st.addrToConn, key)
   217  		c.st.m.Unlock()
   218  
   219  		// Finally, close the socket connection.
   220  		c.rwc.Close()
   221  	})
   222  }
   223  
   224  // write encapsulates tcp conn write.
   225  func (c *tcpconn) write(p []byte) (int, error) {
   226  	if c.writev {
   227  		return c.buffer.Write(p)
   228  	}
   229  	return c.rwc.Write(p)
   230  }
   231  
   232  func (c *tcpconn) serve() {
   233  	defer c.close()
   234  	for {
   235  		// Check if upstream has closed.
   236  		select {
   237  		case <-c.ctx.Done():
   238  			return
   239  		default:
   240  		}
   241  
   242  		if c.idleTimeout > 0 {
   243  			now := time.Now()
   244  			// SetReadDeadline has poor performance, so, update timeout every 5 seconds.
   245  			if now.Sub(c.lastVisited) > 5*time.Second {
   246  				c.lastVisited = now
   247  				err := c.rwc.SetReadDeadline(now.Add(c.idleTimeout))
   248  				if err != nil {
   249  					log.Trace("transport: tcpconn SetReadDeadline fail ", err)
   250  					return
   251  				}
   252  			}
   253  		}
   254  
   255  		req, err := c.fr.ReadFrame()
   256  		if err != nil {
   257  			if err == io.EOF {
   258  				report.TCPServerTransportReadEOF.Incr() // client has closed the connections.
   259  				return
   260  			}
   261  			// Server closes the connection if client sends no package in last idle timeout.
   262  			if e, ok := err.(net.Error); ok && e.Timeout() {
   263  				report.TCPServerTransportIdleTimeout.Incr()
   264  				return
   265  			}
   266  			report.TCPServerTransportReadFail.Incr()
   267  			log.Trace("transport: tcpconn serve ReadFrame fail ", err)
   268  			return
   269  		}
   270  		report.TCPServerTransportReceiveSize.Set(float64(len(req)))
   271  		// if framer is not concurrent safe, copy the data to avoid over writing.
   272  		if c.copyFrame {
   273  			reqCopy := make([]byte, len(req))
   274  			copy(reqCopy, req)
   275  			req = reqCopy
   276  		}
   277  
   278  		c.handle(req)
   279  	}
   280  }
   281  
   282  func (c *tcpconn) handle(req []byte) {
   283  	if !c.serverAsync || c.pool == nil {
   284  		c.handleSync(req)
   285  		return
   286  	}
   287  
   288  	// Using sync.pool to dispatch package processing goroutine parameters can reduce a memory
   289  	// allocation and slightly promote performance.
   290  	args := handleParamPool.Get().(*handleParam)
   291  	args.req = req
   292  	args.c = c
   293  	args.start = time.Now()
   294  	if err := c.pool.Invoke(args); err != nil {
   295  		report.TCPServerTransportJobQueueFullFail.Incr()
   296  		log.Trace("transport: tcpconn serve routine pool put job queue fail ", err)
   297  		c.handleSyncWithErr(req, errs.ErrServerRoutinePoolBusy)
   298  	}
   299  }
   300  
   301  func (c *tcpconn) handleSync(req []byte) {
   302  	c.handleSyncWithErr(req, nil)
   303  }
   304  
   305  func (c *tcpconn) handleSyncWithErr(req []byte, e error) {
   306  	ctx, msg := codec.WithNewMessage(context.Background())
   307  	defer codec.PutBackMessage(msg)
   308  	msg.WithServerRspErr(e)
   309  	// Record local addr and remote addr to context.
   310  	msg.WithLocalAddr(c.localAddr)
   311  	msg.WithRemoteAddr(c.remoteAddr)
   312  
   313  	span, ender, ctx := rpcz.NewSpanContext(ctx, "server")
   314  	span.SetAttribute(rpcz.TRPCAttributeRequestSize, len(req))
   315  
   316  	rsp, err := c.conn.handle(ctx, req)
   317  
   318  	defer func() {
   319  		span.SetAttribute(rpcz.TRPCAttributeRPCName, msg.ServerRPCName())
   320  		if err == nil {
   321  			span.SetAttribute(rpcz.TRPCAttributeError, msg.ServerRspErr())
   322  		} else {
   323  			span.SetAttribute(rpcz.TRPCAttributeError, err)
   324  		}
   325  		ender.End()
   326  	}()
   327  	if err != nil {
   328  		if err != errs.ErrServerNoResponse {
   329  			report.TCPServerTransportHandleFail.Incr()
   330  			log.Trace("transport: tcpconn serve handle fail ", err)
   331  			c.close()
   332  			return
   333  		}
   334  		// On stream RPC, server does not need to write rsp, just returns.
   335  		return
   336  	}
   337  	report.TCPServerTransportSendSize.Set(float64(len(rsp)))
   338  	span.SetAttribute(rpcz.TRPCAttributeResponseSize, len(rsp))
   339  	{
   340  		// common RPC write rsp.
   341  		_, ender := span.NewChild("SendMessage")
   342  		_, err = c.write(rsp)
   343  		ender.End()
   344  	}
   345  
   346  	if err != nil {
   347  		report.TCPServerTransportWriteFail.Incr()
   348  		log.Trace("transport: tcpconn write fail ", err)
   349  		c.close()
   350  	}
   351  }