github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/common/transport/std_tcp.go (about)

     1  package transport
     2  
     3  import (
     4  	"errors"
     5  	"github.com/nyan233/littlerpc/core/common/logger"
     6  	"net"
     7  	"sync"
     8  	"sync/atomic"
     9  	"syscall"
    10  )
    11  
    12  const (
    13  	StdTCPClient int = iota
    14  	StdTCPServer
    15  )
    16  
    17  type StdNetTcpEngine struct {
    18  	mu sync.Mutex
    19  	// 指示是客户端模式还是服务器
    20  	mode      int
    21  	onOpen    func(conn ConnAdapter)
    22  	onRead    func(conn ConnAdapter)
    23  	onMessage func(conn ConnAdapter, data []byte)
    24  	onClose   func(conn ConnAdapter, err error)
    25  	addrs     []string
    26  	listeners []net.Listener
    27  	readBuf   sync.Pool
    28  	closed    int32
    29  }
    30  
    31  func NewStdTcpServer(config NetworkServerConfig) ServerBuilder {
    32  	return &StdNetTcpEngine{
    33  		listeners: make([]net.Listener, len(config.Addrs)),
    34  		addrs:     config.Addrs,
    35  		mode:      StdTCPServer,
    36  		readBuf: sync.Pool{
    37  			New: func() interface{} {
    38  				tmp := make([]byte, ReadBufferSize)
    39  				return &tmp
    40  			},
    41  		},
    42  		onOpen:    func(conn ConnAdapter) {},
    43  		onMessage: func(conn ConnAdapter, data []byte) {},
    44  		onClose:   func(conn ConnAdapter, err error) {},
    45  	}
    46  }
    47  
    48  func NewStdTcpClient() ClientBuilder {
    49  	return &StdNetTcpEngine{
    50  		mode: StdTCPClient,
    51  		readBuf: sync.Pool{
    52  			New: func() interface{} {
    53  				tmp := make([]byte, ReadBufferSize)
    54  				return &tmp
    55  			},
    56  		},
    57  		onOpen:    func(conn ConnAdapter) {},
    58  		onMessage: func(conn ConnAdapter, data []byte) {},
    59  		onClose:   func(conn ConnAdapter, err error) {},
    60  	}
    61  }
    62  
    63  func (s *StdNetTcpEngine) NewConn(config NetworkClientConfig) (ConnAdapter, error) {
    64  	conn, err := net.Dial("tcp", config.ServerAddr)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	return s.connService(conn), nil
    69  }
    70  
    71  func (s *StdNetTcpEngine) Server() ServerEngine {
    72  	return s
    73  }
    74  
    75  func (s *StdNetTcpEngine) Client() ClientEngine {
    76  	return s
    77  }
    78  
    79  func (s *StdNetTcpEngine) EventDriveInter() EventDriveInter {
    80  	return s
    81  }
    82  
    83  func (s *StdNetTcpEngine) Start() error {
    84  	if atomic.LoadInt32(&s.closed) == 1 {
    85  		return errors.New("wsEngine already closed")
    86  	}
    87  	if s.mode == StdTCPClient {
    88  		return nil
    89  	}
    90  	var wg sync.WaitGroup
    91  	wg.Add(len(s.listeners))
    92  	for k, v := range s.addrs {
    93  		lIndex := k
    94  		addr := v
    95  		go func() {
    96  			listener, err := net.Listen("tcp", addr)
    97  			if err != nil {
    98  				panic(err)
    99  			}
   100  			wg.Done()
   101  			s.mu.Lock()
   102  			s.listeners[lIndex] = listener
   103  			s.mu.Unlock()
   104  			for {
   105  				conn, err := listener.Accept()
   106  				if err != nil {
   107  					logger.DefaultLogger.Warn("std-tcp engine accept conn failed, err = %v", err)
   108  					break
   109  				}
   110  				s.connService(conn)
   111  			}
   112  		}()
   113  	}
   114  	wg.Wait()
   115  	return nil
   116  }
   117  
   118  func (s *StdNetTcpEngine) connService(conn net.Conn) *nioConn {
   119  	nc := &nioConn{Conn: conn}
   120  	s.onOpen(nc)
   121  	go func() {
   122  		var (
   123  			buf = make([]byte, 0)
   124  		)
   125  		for {
   126  			if atomic.LoadInt32(&s.closed) == 1 {
   127  				s.onClose(nc, errors.New("eventLoop already closed"))
   128  				_ = nc.Close()
   129  				break
   130  			}
   131  			if s.onRead != nil {
   132  				_, err := conn.Read(buf)
   133  				if err != nil {
   134  					s.onClose(nc, err)
   135  					_ = nc.Close()
   136  					break
   137  				}
   138  				s.onRead(nc)
   139  				continue
   140  			}
   141  			readBuf := s.readBuf.Get().(*[]byte)
   142  			readN, err := conn.Read(*readBuf)
   143  			if err != nil {
   144  				s.readBuf.Put(readBuf)
   145  				s.onClose(nc, err)
   146  				_ = nc.Close()
   147  				break
   148  			}
   149  			s.onMessage(nc, (*readBuf)[:readN])
   150  			s.readBuf.Put(readBuf)
   151  		}
   152  	}()
   153  	return nc
   154  }
   155  
   156  func (s *StdNetTcpEngine) Stop() error {
   157  	if !atomic.CompareAndSwapInt32(&s.closed, 0, 1) {
   158  		return errors.New("wsEngine already closed")
   159  	}
   160  	s.mu.Lock()
   161  	defer s.mu.Unlock()
   162  	for _, v := range s.listeners {
   163  		_ = v.Close()
   164  	}
   165  	return nil
   166  }
   167  
   168  func (s *StdNetTcpEngine) OnRead(f func(conn ConnAdapter)) {
   169  	s.onRead = f
   170  }
   171  
   172  func (s *StdNetTcpEngine) OnMessage(f func(conn ConnAdapter, data []byte)) {
   173  	s.onMessage = f
   174  }
   175  
   176  func (s *StdNetTcpEngine) OnOpen(f func(conn ConnAdapter)) {
   177  	s.onOpen = f
   178  }
   179  
   180  func (s *StdNetTcpEngine) OnClose(f func(conn ConnAdapter, err error)) {
   181  	s.onClose = f
   182  }
   183  
   184  type nioConn struct {
   185  	net.Conn
   186  	source atomic.Value
   187  }
   188  
   189  // 保证OnRead只调用一次Read
   190  func (c *nioConn) Read(p []byte) (n int, err error) {
   191  	readN, err := c.Conn.Read(p)
   192  	if err != nil {
   193  		return readN, err
   194  	}
   195  	return readN, syscall.EWOULDBLOCK
   196  }
   197  
   198  func (c *nioConn) SetSource(s interface{}) {
   199  	c.source.Store(s)
   200  }
   201  
   202  func (c *nioConn) Source() interface{} {
   203  	return c.source.Load()
   204  }