github.com/cnotch/ipchub@v1.1.0/network/socket/listener/listener.go (about)

     1  /**********************************************************************************
     2  * Copyright (c) 2009-2017 Misakai Ltd.
     3  * This program is free software: you can redistribute it and/or modify it under the
     4  * terms of the GNU Affero General Public License as published by the  Free Software
     5  * Foundation, either version 3 of the License, or(at your option) any later version.
     6  *
     7  * This program is distributed  in the hope that it  will be useful, but WITHOUT ANY
     8  * WARRANTY;  without even  the implied warranty of MERCHANTABILITY or FITNESS FOR A
     9  * PARTICULAR PURPOSE.  See the GNU Affero General Public License  for  more details.
    10  *
    11  * You should have  received a copy  of the  GNU Affero General Public License along
    12  * with this program. If not, see<http://www.gnu.org/licenses/>.
    13  ************************************************************************************/
    14  
    15  package listener
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/tls"
    20  	"fmt"
    21  	"io"
    22  	"net"
    23  	"sync"
    24  	"time"
    25  )
    26  
    27  // Server represents a server which can serve requests.
    28  type Server interface {
    29  	Serve(listener net.Listener)
    30  }
    31  
    32  // Matcher matches a connection based on its content.
    33  type Matcher func(io.Reader) bool
    34  
    35  // SettingsHandler 处理连接使用前的设置
    36  type SettingsHandler func(net.Conn)
    37  
    38  // ErrorHandler handles an error and notifies the listener on whether
    39  // it should continue serving.
    40  type ErrorHandler func(error) bool
    41  
    42  var _ net.Error = ErrNotMatched{}
    43  
    44  // ErrNotMatched is returned whenever a connection is not matched by any of
    45  // the matchers registered in the multiplexer.
    46  type ErrNotMatched struct {
    47  	c net.Conn
    48  }
    49  
    50  func (e ErrNotMatched) Error() string {
    51  	return fmt.Sprintf("Unable to match connection %v", e.c.RemoteAddr())
    52  }
    53  
    54  // Temporary implements the net.Error interface.
    55  func (e ErrNotMatched) Temporary() bool { return true }
    56  
    57  // Timeout implements the net.Error interface.
    58  func (e ErrNotMatched) Timeout() bool { return false }
    59  
    60  type errListenerClosed string
    61  
    62  func (e errListenerClosed) Error() string   { return string(e) }
    63  func (e errListenerClosed) Temporary() bool { return false }
    64  func (e errListenerClosed) Timeout() bool   { return false }
    65  
    66  // ErrListenerClosed is returned from muxListener.Accept when the underlying
    67  // listener is closed.
    68  var ErrListenerClosed = errListenerClosed("mux: listener closed")
    69  
    70  // for readability of readTimeout
    71  var noTimeout time.Duration
    72  
    73  // New announces on the local network address laddr. The syntax of laddr is
    74  // "host:port", like "127.0.0.1:8080". If host is omitted, as in ":8080",
    75  // New listens on all available interfaces instead of just the interface
    76  // with the given host address. Listening on a hostname is not recommended
    77  // because this creates a socket for at most one of its IP addresses.
    78  func New(address string, config *tls.Config) (*Listener, error) {
    79  	l, err := net.Listen("tcp", address)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	// If we have a TLS configuration provided, wrap the listener in TLS
    85  	if config != nil {
    86  		l = tls.NewListener(l, config)
    87  	}
    88  
    89  	return &Listener{
    90  		root:            l,
    91  		bufferSize:      1024,
    92  		errorHandler:    func(_ error) bool { return true },
    93  		closing:         make(chan struct{}),
    94  		readTimeout:     noTimeout,
    95  		settingsHandler: func(_ net.Conn) {},
    96  	}, nil
    97  }
    98  
    99  type processor struct {
   100  	matchers []Matcher
   101  	listen   muxListener
   102  }
   103  
   104  // Listener represents a listener used for multiplexing protocols.
   105  type Listener struct {
   106  	root            net.Listener
   107  	bufferSize      int
   108  	errorHandler    ErrorHandler
   109  	closing         chan struct{}
   110  	matchers        []processor
   111  	readTimeout     time.Duration
   112  	settingsHandler SettingsHandler
   113  }
   114  
   115  // Accept waits for and returns the next connection to the listener.
   116  func (m *Listener) Accept() (net.Conn, error) {
   117  	return m.root.Accept()
   118  }
   119  
   120  // ServeAsync adds a protocol based on the matcher and serves it.
   121  func (m *Listener) ServeAsync(matcher Matcher, serve func(l net.Listener) error) {
   122  	l := m.Match(matcher)
   123  	go serve(l)
   124  }
   125  
   126  // Match returns a net.Listener that sees (i.e., accepts) only
   127  // the connections matched by at least one of the matcher.
   128  func (m *Listener) Match(matchers ...Matcher) net.Listener {
   129  	ml := muxListener{
   130  		Listener:    m.root,
   131  		connections: make(chan net.Conn, m.bufferSize),
   132  	}
   133  	m.matchers = append(m.matchers, processor{matchers: matchers, listen: ml})
   134  	return ml
   135  }
   136  
   137  // SetReadTimeout sets a timeout for the read of matchers.
   138  func (m *Listener) SetReadTimeout(t time.Duration) {
   139  	m.readTimeout = t
   140  }
   141  
   142  // Serve starts multiplexing the listener.
   143  func (m *Listener) Serve() error {
   144  	var wg sync.WaitGroup
   145  
   146  	defer func() {
   147  		close(m.closing)
   148  		wg.Wait()
   149  
   150  		for _, sl := range m.matchers {
   151  			close(sl.listen.connections)
   152  			// Drain the connections enqueued for the listener.
   153  			for c := range sl.listen.connections {
   154  				_ = c.Close()
   155  			}
   156  		}
   157  	}()
   158  
   159  	for {
   160  		c, err := m.root.Accept()
   161  		if err != nil {
   162  			if !m.handleErr(err) {
   163  				return err
   164  			}
   165  			continue
   166  		}
   167  
   168  		wg.Add(1)
   169  		go m.serve(c, m.closing, &wg)
   170  	}
   171  }
   172  
   173  func (m *Listener) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
   174  	defer wg.Done()
   175  
   176  	m.settingsHandler(c)
   177  
   178  	muc := newConn(c)
   179  	if m.readTimeout > noTimeout {
   180  		_ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
   181  	}
   182  	for _, sl := range m.matchers {
   183  		for _, processor := range sl.matchers {
   184  			matched := processor(muc.startSniffing())
   185  			if matched {
   186  				muc.doneSniffing()
   187  				if m.readTimeout > noTimeout {
   188  					_ = c.SetReadDeadline(time.Time{})
   189  				}
   190  				select {
   191  				case sl.listen.connections <- muc:
   192  				case <-donec:
   193  					_ = c.Close()
   194  				}
   195  				return
   196  			}
   197  		}
   198  	}
   199  
   200  	_ = c.Close()
   201  	err := ErrNotMatched{c: c}
   202  	if !m.handleErr(err) {
   203  		_ = m.root.Close()
   204  	}
   205  }
   206  
   207  // HandleSettings 处理连接设置的函数,给予调用者一个干预系统级设置的机会
   208  func (m *Listener) HandleSettings(h SettingsHandler) {
   209  	if h != nil {
   210  		m.settingsHandler = h
   211  	}
   212  }
   213  
   214  // HandleError registers an error handler that handles listener errors.
   215  func (m *Listener) HandleError(h ErrorHandler) {
   216  	m.errorHandler = h
   217  }
   218  
   219  func (m *Listener) handleErr(err error) bool {
   220  	if !m.errorHandler(err) {
   221  		return false
   222  	}
   223  
   224  	if ne, ok := err.(net.Error); ok {
   225  		return ne.Temporary()
   226  	}
   227  
   228  	return false
   229  }
   230  
   231  // Close closes the listener
   232  func (m *Listener) Close() error {
   233  	return m.root.Close()
   234  }
   235  
   236  // Addr returns the listener's network address.
   237  func (m *Listener) Addr() net.Addr {
   238  	return m.root.Addr()
   239  }
   240  
   241  // ------------------------------------------------------------------------------------
   242  
   243  type muxListener struct {
   244  	net.Listener
   245  	connections chan net.Conn
   246  }
   247  
   248  func (l muxListener) Accept() (net.Conn, error) {
   249  	c, ok := <-l.connections
   250  	if !ok {
   251  		return nil, ErrListenerClosed
   252  	}
   253  	return c, nil
   254  }
   255  
   256  // ------------------------------------------------------------------------------------
   257  
   258  // Conn wraps a net.Conn and provides transparent sniffing of connection data.
   259  type Conn struct {
   260  	net.Conn
   261  	sniffer sniffer
   262  	reader io.Reader
   263  }
   264  
   265  // NewConn creates a new sniffed connection.
   266  func newConn(c net.Conn) *Conn {
   267  	m := &Conn{
   268  		Conn:   c,
   269  		sniffer: sniffer{source: c},
   270  	}
   271  
   272  	m.sniffer.conn = m
   273  	m.reader = &m.sniffer
   274  	return m
   275  }
   276  
   277  // Read reads the block of data from the underlying buffer.
   278  func (m *Conn) Read(p []byte) (int, error) {
   279  	return m.reader.Read(p)
   280  }
   281  
   282  func (m *Conn) startSniffing() io.Reader {
   283  	m.sniffer.reset(true)
   284  	return &m.sniffer
   285  }
   286  
   287  func (m *Conn) doneSniffing() {
   288  	m.sniffer.reset(false)
   289  }
   290  
   291  // ------------------------------------------------------------------------------------
   292  
   293  // Sniffer represents a io.Reader which can peek incoming bytes and reset back to normal.
   294  type sniffer struct {
   295  	conn       *Conn
   296  	source     io.Reader
   297  	buffer     bytes.Buffer
   298  	bufferRead int
   299  	bufferSize int
   300  	sniffing   bool
   301  	lastErr    error
   302  }
   303  
   304  // Read reads data from the buffer.
   305  func (s *sniffer) Read(p []byte) (int, error) {
   306  	if s.bufferSize > s.bufferRead {
   307  		bn := copy(p, s.buffer.Bytes()[s.bufferRead:s.bufferSize])
   308  		s.bufferRead += bn
   309  		return bn, s.lastErr
   310  	} else if !s.sniffing && s.buffer.Cap() != 0 {
   311  		s.buffer = bytes.Buffer{}
   312  		s.conn.reader = s.conn.Conn // 重置到直接从Conn读取,减少判断
   313  	}
   314  
   315  	sn, sErr := s.source.Read(p)
   316  	if sn > 0 && s.sniffing {
   317  		s.lastErr = sErr
   318  		if wn, wErr := s.buffer.Write(p[:sn]); wErr != nil {
   319  			return wn, wErr
   320  		}
   321  	}
   322  	return sn, sErr
   323  }
   324  
   325  // Reset resets the buffer.
   326  func (s *sniffer) reset(snif bool) {
   327  	s.sniffing = snif
   328  	s.bufferRead = 0
   329  	s.bufferSize = s.buffer.Len()
   330  }