github.com/cmd-stream/base-go@v0.0.0-20230813145615-dd6ac24c16f5/server/conn_receiver.go (about)

     1  package server
     2  
     3  import (
     4  	"net"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/cmd-stream/base-go"
     9  )
    10  
    11  const (
    12  	inProgress int = iota
    13  	shutdown
    14  	closed
    15  )
    16  
    17  // NewConnReceiver creates a new ConnReceiver.
    18  func NewConnReceiver(conf ConnReceiverConf, listener base.Listener,
    19  	conns chan net.Conn) *ConnReceiver {
    20  	return &ConnReceiver{
    21  		conf:     conf,
    22  		listener: listener,
    23  		conns:    conns,
    24  		stopped:  make(chan struct{}),
    25  	}
    26  }
    27  
    28  // ConnReceiver accepts incoming connections on the listener and adds them to
    29  // the conns channel.
    30  //
    31  // It can wait for the first connection for a limited amount of time, after
    32  // which, it stops. Also ConnReceiver implements the jointwork.Task interface,
    33  // so it + Workers may do the job together.
    34  type ConnReceiver struct {
    35  	conf     ConnReceiverConf
    36  	listener base.Listener
    37  	conns    chan net.Conn
    38  	state    int
    39  	stopped  chan struct{}
    40  	mu       sync.Mutex
    41  }
    42  
    43  func (r *ConnReceiver) Run() (err error) {
    44  	defer func() {
    45  		r.postRun()
    46  	}()
    47  	if err = r.acceptFirstConn(); err != nil {
    48  		return r.correctErr(err)
    49  	}
    50  	return r.correctErr(r.acceptConns())
    51  }
    52  
    53  // Shutdown stops ConnReceiver - the Run() method returns nil, which allows
    54  // Workers to finish their work.
    55  func (r *ConnReceiver) Shutdown() (err error) {
    56  	return r.terminate(shutdown)
    57  }
    58  
    59  // Stop stops ConnReceiver - the Run() method returns ErrClosed.
    60  func (r *ConnReceiver) Stop() (err error) {
    61  	return r.terminate(closed)
    62  }
    63  
    64  func (r *ConnReceiver) acceptFirstConn() (err error) {
    65  	if r.conf.FirstConnTimeout != 0 {
    66  		defer func() {
    67  			if err == nil {
    68  				err = r.listener.SetDeadline(time.Time{})
    69  			}
    70  		}()
    71  		err = r.listener.SetDeadline(time.Now().Add(r.conf.FirstConnTimeout))
    72  		if err != nil {
    73  			return err
    74  		}
    75  	}
    76  	conn, err := r.listener.Accept()
    77  	if err != nil {
    78  		return err
    79  	}
    80  	return r.queueConn(conn)
    81  }
    82  
    83  func (r *ConnReceiver) acceptConns() (err error) {
    84  	var conn net.Conn
    85  	for {
    86  		conn, err = r.listener.Accept()
    87  		if err != nil {
    88  			return
    89  		}
    90  		if err = r.queueConn(conn); err != nil {
    91  			return
    92  		}
    93  	}
    94  }
    95  
    96  func (r *ConnReceiver) queueConn(conn net.Conn) error {
    97  	select {
    98  	case <-r.stopped:
    99  		if err := conn.Close(); err != nil {
   100  			panic(err)
   101  		}
   102  		return ErrClosed
   103  	case r.conns <- conn:
   104  		return nil
   105  	}
   106  }
   107  
   108  func (r *ConnReceiver) terminate(state int) (err error) {
   109  	r.mu.Lock()
   110  	defer r.mu.Unlock()
   111  	if r.state == inProgress {
   112  		r.state = state
   113  		if err = r.listener.Close(); err != nil {
   114  			r.state = inProgress
   115  			return
   116  		}
   117  		close(r.stopped)
   118  	}
   119  	return
   120  }
   121  
   122  func (r *ConnReceiver) correctErr(err error) error {
   123  	r.mu.Lock()
   124  	defer r.mu.Unlock()
   125  	switch r.state {
   126  	case inProgress:
   127  		return err
   128  	case shutdown:
   129  		return nil
   130  	case closed:
   131  		return ErrClosed
   132  	default:
   133  		panic("unexpected state")
   134  	}
   135  }
   136  
   137  func (r *ConnReceiver) postRun() {
   138  	r.mu.Lock()
   139  	defer r.mu.Unlock()
   140  	close(r.conns)
   141  	if r.state == shutdown {
   142  		return
   143  	}
   144  	for conn := range r.conns {
   145  		if err := conn.Close(); err != nil {
   146  			panic(err)
   147  		}
   148  	}
   149  }