github.com/weedge/lib@v0.0.0-20230424045628-a36dcc1d90e4/poller/server_tcp.go (about)

     1  package poller
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"sync"
     7  	"sync/atomic"
     8  	"syscall"
     9  	"time"
    10  
    11  	"github.com/weedge/lib/log"
    12  )
    13  
// Server is a TCP server driven by an event poller (epoll/kqueue) and,
// optionally, io_uring rings for asynchronous accept/read/write.
type Server struct {
	options        *options                    // server configuration parameters
	readBufferPool *sync.Pool                  // pooled read buffers, options.readBufferLen bytes each
	handler        Handler                     // user-registered connection/data/close callbacks
	ioEventQueues  []chan *eventInfo           // IO event queues, one consumer goroutine each (events hashed by fd)
	ioQueueNum     int                         // number of IO event queues (== len(ioEventQueues))
	conns          sync.Map                    // established connections, keyed by connect fd (int) -> *Conn
	connsNum       int64                       // number of established connections (accessed atomically)
	stop           chan struct{}               // closed to broadcast server shutdown
	listenFD       int                         // listen fd
	pollerFD       int                         // event poller fd (epoll/kqueue, poll, select)
	iourings       []*ioUring                  // io_uring async event rings (nil when not in uring mode)
	asyncEventCb   map[EventType]EventCallBack // async event callback registry
	looperWg       sync.WaitGroup              // waits for the main io_uring looper goroutines
}
    30  
    31  // NewServer
    32  // init server to start
    33  func NewServer(address string, handler Handler, opts ...Option) (*Server, error) {
    34  	options := getOptions(opts...)
    35  
    36  	// init read buffer pool
    37  	readBufferPool := &sync.Pool{
    38  		New: func() interface{} {
    39  			b := make([]byte, options.readBufferLen)
    40  			return b
    41  		},
    42  	}
    43  
    44  	// listen
    45  	lfd, err := listen(address, options.listenBacklog)
    46  	if err != nil {
    47  		log.Error(err)
    48  		return nil, err
    49  	}
    50  
    51  	// init poller(epoll/kqueue)
    52  	var pollerFD int
    53  	pollerFD, err = createPoller()
    54  	if err != nil {
    55  		log.Error(err)
    56  		return nil, err
    57  	}
    58  
    59  	// init io_uring setup
    60  	var rings []*ioUring
    61  	if options.ioMode == IOModeUring || options.ioMode == IOModeEpollUring {
    62  		rings = make([]*ioUring, options.ioUringNum)
    63  		for i := 0; i < options.ioUringNum; i++ {
    64  			ring, err := newIoUring(options.ioUringEntries, options.ioUringParams)
    65  			if err != nil {
    66  				log.Errorf("newIoUring %d err %s", i, err.Error())
    67  				return nil, err
    68  			}
    69  
    70  			// register eventfd
    71  			if options.ioMode == IOModeEpollUring {
    72  				err = ring.RegisterEventFd()
    73  				if err != nil {
    74  					log.Errorf("ring.RegisterEventFd %d err %s", i, err.Error())
    75  					ring.CloseRing()
    76  					return nil, err
    77  				}
    78  			}
    79  
    80  			rings[i] = ring
    81  		} // end for
    82  	}
    83  
    84  	// init io event channel(queue)
    85  	ioEventQueues := make([]chan *eventInfo, options.ioGNum)
    86  	for i := range ioEventQueues {
    87  		ioEventQueues[i] = make(chan *eventInfo, options.ioEventQueueLen)
    88  	}
    89  
    90  	return &Server{
    91  		options:        options,
    92  		readBufferPool: readBufferPool,
    93  		handler:        handler,
    94  		ioEventQueues:  ioEventQueues,
    95  		ioQueueNum:     options.ioGNum,
    96  		conns:          sync.Map{},
    97  		connsNum:       0,
    98  		stop:           make(chan struct{}),
    99  		listenFD:       lfd,
   100  		pollerFD:       pollerFD,
   101  		iourings:       rings,
   102  		asyncEventCb:   map[EventType]EventCallBack{},
   103  	}, nil
   104  }
   105  
   106  // GetConn
   107  // get connect by connect fd from session connect sync Map
   108  func (s *Server) GetConn(fd int32) (*Conn, bool) {
   109  	value, ok := s.conns.Load(fd)
   110  	if !ok {
   111  		return nil, false
   112  	}
   113  	return value.(*Conn), true
   114  }
   115  
// Run runs the server:
// the acceptor accepts connections from the listen fd,
// the event looper dispatches events,
// the IO consume handlers run the business logic,
// and connection sessions are periodically checked for timeout.
// startIOEventLooper is the main looper and blocks the caller.
func (s *Server) Run() {
	log.Info("start server runing...")
	// register epoll<->io_uring eventfd notification (epoll+uring mode only)
	s.rigisterEpollIouringEvent()

	// monitoring goroutines (each is a no-op when its ticker is disabled)
	s.report()
	s.checkTimeout()

	// start server
	s.startAcceptor()
	s.startIOConsumeHandler()
	s.startIOEventLooper()
}
   135  
   136  // rigisterEpollIouringEvent
   137  // rigister epoll iouring eventfd, wait eventfd iouring cq read ready event,
   138  // notify iouring to get cqe
   139  func (s *Server) rigisterEpollIouringEvent() {
   140  	if s.options.ioMode != IOModeEpollUring {
   141  		return
   142  	}
   143  	log.Info("start notify iouring cq event by epoll rigistered iouring eventfd")
   144  
   145  	err := s.rigisterIoUringEvent()
   146  	if err != nil {
   147  		log.Errorf("rigisterIoUringEvent err %s", err.Error())
   148  		return
   149  	}
   150  
   151  	go s.startNotifyIoUringCQEvent()
   152  }
   153  
   154  // rigisterIoUringEvent
   155  // add read event to epoll item list for registered iouring eventfd (just read)
   156  func (s *Server) rigisterIoUringEvent() (err error) {
   157  	for i := 0; i < s.options.ioUringNum; i++ {
   158  		log.Debugf("pollerFD %d eventfd %d add epoll readable event", s.pollerFD, s.iourings[i].eventfd)
   159  		err = addReadEvent(s.pollerFD, s.iourings[i].eventfd)
   160  		if err != nil {
   161  			return
   162  		}
   163  	}
   164  	return
   165  }
   166  
   167  // CloseIoUring
   168  // remove rigistered  eventfd iouring event, free iouring mmap
   169  func (s *Server) CloseIoUring() {
   170  	for i := 0; i < len(s.iourings); i++ {
   171  		delEventFD(s.pollerFD, s.iourings[i].eventfd)
   172  		s.iourings[i].CloseRing()
   173  	}
   174  }
   175  
// Stop stops the server: it broadcasts the shutdown signal by closing
// s.stop, closes the IO event queues so consumer goroutines drain and
// exit, and frees the io_uring resources.
//
// NOTE(review): dispatcher goroutines send into ioEventQueues via
// handleEvent; if one is mid-dispatch when the queues are closed here,
// that send panics. Confirm Stop is only invoked once the dispatchers
// have observed s.stop, or guard the send.
func (s *Server) Stop() {
	close(s.stop)
	for _, queue := range s.ioEventQueues {
		close(queue)
	}

	s.CloseIoUring()
}
   187  
   188  // GetConnsNum
   189  func (s *Server) GetConnsNum() int64 {
   190  	return atomic.LoadInt64(&s.connsNum)
   191  }
   192  
   193  // startAcceptor
   194  // setup accept connect goroutine
   195  func (s *Server) startAcceptor() {
   196  	if len(s.iourings) != 0 {
   197  		go s.asyncBlockAccept()
   198  		log.Info("start trigger async block accept")
   199  		return
   200  	}
   201  
   202  	for i := 0; i < s.options.acceptGNum; i++ {
   203  		go s.accept()
   204  	}
   205  	log.Infof("start accept by %d goroutine", s.options.acceptGNum)
   206  }
   207  
   208  // accept
   209  // block accept connect from listen fd
   210  // save non block connect fd session and OnConnect logic handle
   211  func (s *Server) accept() {
   212  	for {
   213  		select {
   214  		case <-s.stop:
   215  			return
   216  		default:
   217  			cfd, socketAddr, err := accept(s.listenFD, s.options.keepaliveInterval)
   218  			if err != nil {
   219  				log.Error(err)
   220  				continue
   221  			}
   222  			addr := getAddr(socketAddr)
   223  
   224  			conn := newConn(s.pollerFD, cfd, addr, s)
   225  			s.conns.Store(cfd, conn)
   226  			atomic.AddInt64(&s.connsNum, 1)
   227  
   228  			err = addReadEvent(s.pollerFD, cfd)
   229  			if err != nil {
   230  				log.Error(err)
   231  				conn.Close()
   232  				continue
   233  			}
   234  
   235  			s.handler.OnConnect(conn)
   236  		}
   237  	}
   238  }
   239  
// nonBlockPollAccept
// non block accept, when return EAGAIN, add/produce event poll op to sqe
// NOTE(review): intentionally empty — not implemented yet.
func (s *Server) nonBlockPollAccept() {
}
   244  
   245  // asyncBlockAccept
   246  // async add/produce block accept op to sqe
   247  func (s *Server) asyncBlockAccept() {
   248  	var rsa syscall.RawSockaddrAny
   249  	var len uint32 = syscall.SizeofSockaddrAny
   250  	s.GetIoUring(s.listenFD).addAcceptSqe(s.getAcceptCallback(&rsa), s.listenFD, &rsa, len, 0)
   251  }
   252  
   253  func (s *Server) GetIoUring(fd int) *ioUring {
   254  	return s.iourings[fd%s.options.ioUringNum]
   255  }
   256  
   257  func (s *Server) GetEventIoUring(efd int) *ioUring {
   258  	if len(s.iourings) == 0 {
   259  		return nil
   260  	}
   261  	for i := 0; i < s.options.ioUringNum; i++ {
   262  		if efd == s.iourings[i].eventfd {
   263  			return s.iourings[i]
   264  		}
   265  	}
   266  
   267  	return nil
   268  }
   269  
   270  func (s *Server) startNotifyIoUringCQEvent() {
   271  	for {
   272  		select {
   273  		case <-s.stop:
   274  			return
   275  		default:
   276  			events, err := getEvents(s.pollerFD)
   277  			if err != nil {
   278  				if err != syscall.EINTR {
   279  					log.Errorf("getEvents err %s", err.Error())
   280  				}
   281  				continue
   282  			}
   283  			for _, event := range events {
   284  				ring := s.GetEventIoUring(event.fd)
   285  				ring.cqeSignCh <- struct{}{}
   286  			}
   287  		}
   288  	}
   289  }
   290  
   291  func (s *Server) getAcceptCallback(rsa *syscall.RawSockaddrAny) EventCallBack {
   292  	return func(e *eventInfo) (err error) {
   293  		if e.cqe.Res < 0 {
   294  			err = fmt.Errorf("accept err res %d", e.cqe.Res)
   295  			return
   296  		}
   297  
   298  		cfd := int(e.cqe.Res)
   299  		err = setConnectOption(cfd, s.options.keepaliveInterval)
   300  		if err != nil {
   301  			return
   302  		}
   303  
   304  		socketAddr, err := anyToSockaddr(rsa)
   305  		if err != nil {
   306  			return
   307  		}
   308  		addr := getAddr(socketAddr)
   309  
   310  		conn := newConn(s.pollerFD, cfd, addr, s)
   311  		s.conns.Store(cfd, conn)
   312  		atomic.AddInt64(&s.connsNum, 1)
   313  		s.handler.OnConnect(conn)
   314  
   315  		// new connected client, async read data from socket
   316  		conn.AsyncBlockRead()
   317  
   318  		// re-add accept to monitor for new connections
   319  		s.asyncBlockAccept()
   320  
   321  		return
   322  	}
   323  }
   324  
   325  // startIOEventLooper main looper
   326  // from poller events or io_uring cqe event entries
   327  func (s *Server) startIOEventLooper() {
   328  	//runtime.LockOSThread()
   329  	if s.iourings == nil {
   330  		s.startIOEventPollDispatcher()
   331  		return
   332  	}
   333  
   334  	s.looperWg.Add(len(s.iourings))
   335  	for i := 0; i < len(s.iourings); i++ {
   336  		go s.startIOUringPollDispatcher(i)
   337  	}
   338  	s.looperWg.Wait()
   339  }
   340  
   341  // startIOEventPollDispatcher
   342  // get ready events from poller, distpatch to event channel(queue)
   343  func (s *Server) startIOEventPollDispatcher() {
   344  	log.Info("start io event poll dispatcher")
   345  	for {
   346  		select {
   347  		case <-s.stop:
   348  			log.Infof("stop io event poll dispatcher")
   349  			return
   350  		default:
   351  			var err error
   352  			var events []eventInfo
   353  			events, err = getEvents(s.pollerFD)
   354  			if err != nil {
   355  				if err != syscall.EINTR {
   356  					log.Error(err)
   357  				}
   358  				continue
   359  			}
   360  
   361  			// dispatch
   362  			for i := range events {
   363  				s.handleEvent(&events[i])
   364  			}
   365  		}
   366  	} // end for
   367  }
   368  
   369  // startIOUringPollDispatcher
   370  // get completed event ops from io_uring cqe event entries, distpatch to event channel(queue)
   371  func (s *Server) startIOUringPollDispatcher(id int) {
   372  	defer s.looperWg.Done()
   373  	if s.iourings[id] == nil {
   374  		return
   375  	}
   376  	log.Infof("start io_uring event op poll dispatcher id %d", id)
   377  	for {
   378  		select {
   379  		case <-s.stop:
   380  			log.Infof("stop io_uring event op poll dispatcher id %d", id)
   381  			return
   382  		default:
   383  			event, err := s.iourings[id].getEventInfo()
   384  			if err != nil {
   385  				log.Warnf("id %d iouring get events error:%s continue", id, err.Error())
   386  				continue
   387  			}
   388  			if event == nil {
   389  				continue
   390  			}
   391  
   392  			// dispatch
   393  			s.handleEvent(event)
   394  			// commit cqe is seen
   395  			s.iourings[id].cqeDone(event.cqe)
   396  		}
   397  	} // end for
   398  }
   399  
// handleEvent dispatches an event to one of the IO event queues, hashing
// by fd so every event of a given connection lands on the same consumer
// goroutine and is processed in order.
//
// NOTE(review): this send panics if Stop has already closed the queues —
// confirm dispatchers exit before Stop closes ioEventQueues.
func (s *Server) handleEvent(event *eventInfo) {
	index := event.fd % s.ioQueueNum
	s.ioEventQueues[index] <- event
}
   408  
   409  // startIOConsumeHandler
   410  // setup io event consume goroutine
   411  func (s *Server) startIOConsumeHandler() {
   412  	for _, queue := range s.ioEventQueues {
   413  		go s.consumeIOEvent(queue)
   414  	}
   415  	log.Info(fmt.Sprintf("start io event consumer by %d goroutine handler", len(s.ioEventQueues)))
   416  }
   417  
   418  func (s *Server) consumeIOEvent(queue chan *eventInfo) {
   419  	if s.iourings != nil {
   420  		s.consumeIOCompletionEvent(queue)
   421  	} else {
   422  		s.consumeIOReadyEvent(queue)
   423  	}
   424  }
   425  
   426  // consumeIOCompletionEvent
   427  func (s *Server) consumeIOCompletionEvent(queue chan *eventInfo) {
   428  	for event := range queue {
   429  		// process async accept connect complete event
   430  		if event.etype == ETypeAccept {
   431  			err := event.cb(event)
   432  			if err != nil {
   433  				log.Errorf("accept event %s cb error:%s, continue next event", event, err.Error())
   434  			}
   435  			continue
   436  		}
   437  
   438  		// get connect from fd
   439  		v, ok := s.conns.Load(event.fd)
   440  		if !ok {
   441  			log.Warnf("fd %d not found in conns, event:%s , continue next event", event.fd, event)
   442  			continue
   443  		}
   444  		c := v.(*Conn)
   445  
   446  		// process async read complete event
   447  		if event.etype == ETypeRead {
   448  			err := c.processReadEvent(event)
   449  			if err != nil {
   450  				// notice: if next connect use closed cfd (TIME_WAIT stat between 2MSL eg:4m),
   451  				// read from closed cfd return EBADF
   452  				if err == syscall.EBADF {
   453  					log.Errorf("read closed connect fd %d EBADF, continue next event", event.fd)
   454  					continue
   455  				}
   456  
   457  				// no bytes available on socket, client must be disconnected
   458  				log.Warnf("process read event %s err:%s , client connect must be disconnected", event, err.Error())
   459  				// close and free connect
   460  				c.CloseConnect()
   461  				s.handler.OnClose(c, err)
   462  			}
   463  		}
   464  
   465  		// async write complete event
   466  		if event.etype == ETypeWrite {
   467  			err := c.processWirteEvent(event)
   468  			if err != nil {
   469  				log.Errorf("process write event %s err:%s, continue next event", event, err.Error())
   470  				continue
   471  			}
   472  		}
   473  	} // end for
   474  }
   475  
   476  // consumeIOReadyEvent
   477  // handle ready r/w, close, connect timeout etc event
   478  func (s *Server) consumeIOReadyEvent(queue chan *eventInfo) {
   479  	for event := range queue {
   480  		v, ok := s.conns.Load(event.fd)
   481  		if !ok {
   482  			log.Warn("not found in conns,", event.fd, event)
   483  			continue
   484  		}
   485  		c := v.(*Conn)
   486  
   487  		if event.etype == ETypeClose {
   488  			c.Close()
   489  			s.handler.OnClose(c, io.EOF)
   490  			continue
   491  		}
   492  		if event.etype == ETypeTimeout {
   493  			c.Close()
   494  			s.handler.OnClose(c, ErrReadTimeout)
   495  			continue
   496  		}
   497  
   498  		err := c.Read()
   499  		if err != nil {
   500  			// notice: if next connect use closed cfd (TIME_WAIT stat between 2MSL eg:4m),
   501  			// read from closed cfd return EBADF
   502  			if err == syscall.EBADF {
   503  				continue
   504  			}
   505  
   506  			// no bytes available on socket, client must be disconnected
   507  			log.Warnf("process sync read connect fd %d err:%s , client connect must be disconnected", c.fd, err.Error())
   508  			// close and free connect
   509  			c.Close()
   510  			s.handler.OnClose(c, err)
   511  
   512  		}
   513  	} // end for
   514  }
   515  
   516  // checkTimeout
   517  // tick to check connect time out
   518  func (s *Server) checkTimeout() {
   519  	if s.options.timeout == 0 || s.options.timeoutTicker == 0 {
   520  		return
   521  	}
   522  
   523  	log.Infof("check timeout goroutine run,check_time:%v, timeout:%v", s.options.timeoutTicker, s.options.timeout)
   524  	go func() {
   525  		ticker := time.NewTicker(s.options.timeoutTicker)
   526  		for {
   527  			select {
   528  			case <-s.stop:
   529  				return
   530  			case <-ticker.C:
   531  				s.conns.Range(func(key, value interface{}) bool {
   532  					c := value.(*Conn)
   533  					//log.Infof("check connect %+v", c)
   534  					if time.Since(c.lastReadTime) > s.options.timeout {
   535  						s.handleEvent(&eventInfo{fd: int(c.fd), etype: ETypeTimeout})
   536  					}
   537  					return true
   538  				})
   539  			}
   540  		} // end for
   541  	}()
   542  }
   543  
   544  func (s *Server) report() {
   545  	if s.options.reportTicker == 0 {
   546  		return
   547  	}
   548  
   549  	log.Infof("start report server info, report tick time %v", s.options.reportTicker)
   550  	go func() {
   551  		ticker := time.NewTicker(s.options.reportTicker)
   552  		for {
   553  			select {
   554  			case <-s.stop:
   555  				return
   556  			case <-ticker.C:
   557  				n := s.GetConnsNum()
   558  				if n > 0 {
   559  					log.Infof("current active connect num %d", n)
   560  				}
   561  			}
   562  		}
   563  	}()
   564  }