github.com/cnotch/ipchub@v1.1.0/network/socket/listener/listener.go (about) 1 /********************************************************************************** 2 * Copyright (c) 2009-2017 Misakai Ltd. 3 * This program is free software: you can redistribute it and/or modify it under the 4 * terms of the GNU Affero General Public License as published by the Free Software 5 * Foundation, either version 3 of the License, or(at your option) any later version. 6 * 7 * This program is distributed in the hope that it will be useful, but WITHOUT ANY 8 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 9 * PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 10 * 11 * You should have received a copy of the GNU Affero General Public License along 12 * with this program. If not, see<http://www.gnu.org/licenses/>. 13 ************************************************************************************/ 14 15 package listener 16 17 import ( 18 "bytes" 19 "crypto/tls" 20 "fmt" 21 "io" 22 "net" 23 "sync" 24 "time" 25 ) 26 27 // Server represents a server which can serve requests. 28 type Server interface { 29 Serve(listener net.Listener) 30 } 31 32 // Matcher matches a connection based on its content. 33 type Matcher func(io.Reader) bool 34 35 // SettingsHandler 处理连接使用前的设置 36 type SettingsHandler func(net.Conn) 37 38 // ErrorHandler handles an error and notifies the listener on whether 39 // it should continue serving. 40 type ErrorHandler func(error) bool 41 42 var _ net.Error = ErrNotMatched{} 43 44 // ErrNotMatched is returned whenever a connection is not matched by any of 45 // the matchers registered in the multiplexer. 46 type ErrNotMatched struct { 47 c net.Conn 48 } 49 50 func (e ErrNotMatched) Error() string { 51 return fmt.Sprintf("Unable to match connection %v", e.c.RemoteAddr()) 52 } 53 54 // Temporary implements the net.Error interface. 55 func (e ErrNotMatched) Temporary() bool { return true } 56 57 // Timeout implements the net.Error interface. 58 func (e ErrNotMatched) Timeout() bool { return false } 59 60 type errListenerClosed string 61 62 func (e errListenerClosed) Error() string { return string(e) } 63 func (e errListenerClosed) Temporary() bool { return false } 64 func (e errListenerClosed) Timeout() bool { return false } 65 66 // ErrListenerClosed is returned from muxListener.Accept when the underlying 67 // listener is closed. 68 var ErrListenerClosed = errListenerClosed("mux: listener closed") 69 70 // for readability of readTimeout 71 var noTimeout time.Duration 72 73 // New announces on the local network address laddr. The syntax of laddr is 74 // "host:port", like "127.0.0.1:8080". If host is omitted, as in ":8080", 75 // New listens on all available interfaces instead of just the interface 76 // with the given host address. Listening on a hostname is not recommended 77 // because this creates a socket for at most one of its IP addresses. 78 func New(address string, config *tls.Config) (*Listener, error) { 79 l, err := net.Listen("tcp", address) 80 if err != nil { 81 return nil, err 82 } 83 84 // If we have a TLS configuration provided, wrap the listener in TLS 85 if config != nil { 86 l = tls.NewListener(l, config) 87 } 88 89 return &Listener{ 90 root: l, 91 bufferSize: 1024, 92 errorHandler: func(_ error) bool { return true }, 93 closing: make(chan struct{}), 94 readTimeout: noTimeout, 95 settingsHandler: func(_ net.Conn) {}, 96 }, nil 97 } 98 99 type processor struct { 100 matchers []Matcher 101 listen muxListener 102 } 103 104 // Listener represents a listener used for multiplexing protocols. 105 type Listener struct { 106 root net.Listener 107 bufferSize int 108 errorHandler ErrorHandler 109 closing chan struct{} 110 matchers []processor 111 readTimeout time.Duration 112 settingsHandler SettingsHandler 113 } 114 115 // Accept waits for and returns the next connection to the listener. 116 func (m *Listener) Accept() (net.Conn, error) { 117 return m.root.Accept() 118 } 119 120 // ServeAsync adds a protocol based on the matcher and serves it. 121 func (m *Listener) ServeAsync(matcher Matcher, serve func(l net.Listener) error) { 122 l := m.Match(matcher) 123 go serve(l) 124 } 125 126 // Match returns a net.Listener that sees (i.e., accepts) only 127 // the connections matched by at least one of the matcher. 128 func (m *Listener) Match(matchers ...Matcher) net.Listener { 129 ml := muxListener{ 130 Listener: m.root, 131 connections: make(chan net.Conn, m.bufferSize), 132 } 133 m.matchers = append(m.matchers, processor{matchers: matchers, listen: ml}) 134 return ml 135 } 136 137 // SetReadTimeout sets a timeout for the read of matchers. 138 func (m *Listener) SetReadTimeout(t time.Duration) { 139 m.readTimeout = t 140 } 141 142 // Serve starts multiplexing the listener. 143 func (m *Listener) Serve() error { 144 var wg sync.WaitGroup 145 146 defer func() { 147 close(m.closing) 148 wg.Wait() 149 150 for _, sl := range m.matchers { 151 close(sl.listen.connections) 152 // Drain the connections enqueued for the listener. 153 for c := range sl.listen.connections { 154 _ = c.Close() 155 } 156 } 157 }() 158 159 for { 160 c, err := m.root.Accept() 161 if err != nil { 162 if !m.handleErr(err) { 163 return err 164 } 165 continue 166 } 167 168 wg.Add(1) 169 go m.serve(c, m.closing, &wg) 170 } 171 } 172 173 func (m *Listener) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { 174 defer wg.Done() 175 176 m.settingsHandler(c) 177 178 muc := newConn(c) 179 if m.readTimeout > noTimeout { 180 _ = c.SetReadDeadline(time.Now().Add(m.readTimeout)) 181 } 182 for _, sl := range m.matchers { 183 for _, processor := range sl.matchers { 184 matched := processor(muc.startSniffing()) 185 if matched { 186 muc.doneSniffing() 187 if m.readTimeout > noTimeout { 188 _ = c.SetReadDeadline(time.Time{}) 189 } 190 select { 191 case sl.listen.connections <- muc: 192 case <-donec: 193 _ = c.Close() 194 } 195 return 196 } 197 } 198 } 199 200 _ = c.Close() 201 err := ErrNotMatched{c: c} 202 if !m.handleErr(err) { 203 _ = m.root.Close() 204 } 205 } 206 207 // HandleSettings 处理连接设置的函数,给予调用者一个干预系统级设置的机会 208 func (m *Listener) HandleSettings(h SettingsHandler) { 209 if h != nil { 210 m.settingsHandler = h 211 } 212 } 213 214 // HandleError registers an error handler that handles listener errors. 215 func (m *Listener) HandleError(h ErrorHandler) { 216 m.errorHandler = h 217 } 218 219 func (m *Listener) handleErr(err error) bool { 220 if !m.errorHandler(err) { 221 return false 222 } 223 224 if ne, ok := err.(net.Error); ok { 225 return ne.Temporary() 226 } 227 228 return false 229 } 230 231 // Close closes the listener 232 func (m *Listener) Close() error { 233 return m.root.Close() 234 } 235 236 // Addr returns the listener's network address. 237 func (m *Listener) Addr() net.Addr { 238 return m.root.Addr() 239 } 240 241 // ------------------------------------------------------------------------------------ 242 243 type muxListener struct { 244 net.Listener 245 connections chan net.Conn 246 } 247 248 func (l muxListener) Accept() (net.Conn, error) { 249 c, ok := <-l.connections 250 if !ok { 251 return nil, ErrListenerClosed 252 } 253 return c, nil 254 } 255 256 // ------------------------------------------------------------------------------------ 257 258 // Conn wraps a net.Conn and provides transparent sniffing of connection data. 259 type Conn struct { 260 net.Conn 261 sniffer sniffer 262 reader io.Reader 263 } 264 265 // NewConn creates a new sniffed connection. 266 func newConn(c net.Conn) *Conn { 267 m := &Conn{ 268 Conn: c, 269 sniffer: sniffer{source: c}, 270 } 271 272 m.sniffer.conn = m 273 m.reader = &m.sniffer 274 return m 275 } 276 277 // Read reads the block of data from the underlying buffer. 278 func (m *Conn) Read(p []byte) (int, error) { 279 return m.reader.Read(p) 280 } 281 282 func (m *Conn) startSniffing() io.Reader { 283 m.sniffer.reset(true) 284 return &m.sniffer 285 } 286 287 func (m *Conn) doneSniffing() { 288 m.sniffer.reset(false) 289 } 290 291 // ------------------------------------------------------------------------------------ 292 293 // Sniffer represents a io.Reader which can peek incoming bytes and reset back to normal. 294 type sniffer struct { 295 conn *Conn 296 source io.Reader 297 buffer bytes.Buffer 298 bufferRead int 299 bufferSize int 300 sniffing bool 301 lastErr error 302 } 303 304 // Read reads data from the buffer. 305 func (s *sniffer) Read(p []byte) (int, error) { 306 if s.bufferSize > s.bufferRead { 307 bn := copy(p, s.buffer.Bytes()[s.bufferRead:s.bufferSize]) 308 s.bufferRead += bn 309 return bn, s.lastErr 310 } else if !s.sniffing && s.buffer.Cap() != 0 { 311 s.buffer = bytes.Buffer{} 312 s.conn.reader = s.conn.Conn // 重置到直接从Conn读取,减少判断 313 } 314 315 sn, sErr := s.source.Read(p) 316 if sn > 0 && s.sniffing { 317 s.lastErr = sErr 318 if wn, wErr := s.buffer.Write(p[:sn]); wErr != nil { 319 return wn, wErr 320 } 321 } 322 return sn, sErr 323 } 324 325 // Reset resets the buffer. 326 func (s *sniffer) reset(snif bool) { 327 s.sniffing = snif 328 s.bufferRead = 0 329 s.bufferSize = s.buffer.Len() 330 }