github.com/cmd-stream/base-go@v0.0.0-20230813145615-dd6ac24c16f5/server/conn_receiver.go (about) 1 package server 2 3 import ( 4 "net" 5 "sync" 6 "time" 7 8 "github.com/cmd-stream/base-go" 9 ) 10 11 const ( 12 inProgress int = iota 13 shutdown 14 closed 15 ) 16 17 // NewConnReceiver creates a new ConnReceiver. 18 func NewConnReceiver(conf ConnReceiverConf, listener base.Listener, 19 conns chan net.Conn) *ConnReceiver { 20 return &ConnReceiver{ 21 conf: conf, 22 listener: listener, 23 conns: conns, 24 stopped: make(chan struct{}), 25 } 26 } 27 28 // ConnReceiver accepts incoming connections on the listener and adds them to 29 // the conns channel. 30 // 31 // It can wait for the first connection for a limited amount of time, after 32 // which, it stops. Also ConnReceiver implements the jointwork.Task interface, 33 // so it + Workers may do the job together. 34 type ConnReceiver struct { 35 conf ConnReceiverConf 36 listener base.Listener 37 conns chan net.Conn 38 state int 39 stopped chan struct{} 40 mu sync.Mutex 41 } 42 43 func (r *ConnReceiver) Run() (err error) { 44 defer func() { 45 r.postRun() 46 }() 47 if err = r.acceptFirstConn(); err != nil { 48 return r.correctErr(err) 49 } 50 return r.correctErr(r.acceptConns()) 51 } 52 53 // Shutdown stops ConnReceiver - the Run() method returns nil, which allows 54 // Workers to finish their work. 55 func (r *ConnReceiver) Shutdown() (err error) { 56 return r.terminate(shutdown) 57 } 58 59 // Stop stops ConnReceiver - the Run() method returns ErrClosed. 60 func (r *ConnReceiver) Stop() (err error) { 61 return r.terminate(closed) 62 } 63 64 func (r *ConnReceiver) acceptFirstConn() (err error) { 65 if r.conf.FirstConnTimeout != 0 { 66 defer func() { 67 if err == nil { 68 err = r.listener.SetDeadline(time.Time{}) 69 } 70 }() 71 err = r.listener.SetDeadline(time.Now().Add(r.conf.FirstConnTimeout)) 72 if err != nil { 73 return err 74 } 75 } 76 conn, err := r.listener.Accept() 77 if err != nil { 78 return err 79 } 80 return r.queueConn(conn) 81 } 82 83 func (r *ConnReceiver) acceptConns() (err error) { 84 var conn net.Conn 85 for { 86 conn, err = r.listener.Accept() 87 if err != nil { 88 return 89 } 90 if err = r.queueConn(conn); err != nil { 91 return 92 } 93 } 94 } 95 96 func (r *ConnReceiver) queueConn(conn net.Conn) error { 97 select { 98 case <-r.stopped: 99 if err := conn.Close(); err != nil { 100 panic(err) 101 } 102 return ErrClosed 103 case r.conns <- conn: 104 return nil 105 } 106 } 107 108 func (r *ConnReceiver) terminate(state int) (err error) { 109 r.mu.Lock() 110 defer r.mu.Unlock() 111 if r.state == inProgress { 112 r.state = state 113 if err = r.listener.Close(); err != nil { 114 r.state = inProgress 115 return 116 } 117 close(r.stopped) 118 } 119 return 120 } 121 122 func (r *ConnReceiver) correctErr(err error) error { 123 r.mu.Lock() 124 defer r.mu.Unlock() 125 switch r.state { 126 case inProgress: 127 return err 128 case shutdown: 129 return nil 130 case closed: 131 return ErrClosed 132 default: 133 panic("unexpected state") 134 } 135 } 136 137 func (r *ConnReceiver) postRun() { 138 r.mu.Lock() 139 defer r.mu.Unlock() 140 close(r.conns) 141 if r.state == shutdown { 142 return 143 } 144 for conn := range r.conns { 145 if err := conn.Close(); err != nil { 146 panic(err) 147 } 148 } 149 }