go.dedis.ch/onet/v4@v4.0.0-pre1/network/tcp.go (about)

     1  package network
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"io"
     7  	"net"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"go.dedis.ch/onet/v4/log"
    13  	"golang.org/x/xerrors"
    14  )
    15  
    16  // a connection will return an io.EOF after networkTimeout if nothing has been
    17  // received. sends and connects will timeout using this timeout as well.
    18  var timeout = 1 * time.Minute
    19  
    20  // dialTimeout is the timeout for connecting to an end point.
    21  var dialTimeout = 1 * time.Minute
    22  
    23  // Global lock for 'timeout' (because also used in 'tcp_test.go')
    24  // Using a 'RWMutex' to be as efficient as possible, because it will be used
    25  // quite a lot in 'Receive()'.
    26  var timeoutLock = sync.RWMutex{}
    27  
    28  // MaxPacketSize limits the amount of memory that is allocated before a packet
    29  // is checked and thrown away if it's not legit. If you need more than 10MB
    30  // packets, increase this value.
    31  var MaxPacketSize = Size(10 * 1024 * 1024)
    32  
    33  // NewTCPAddress returns a new Address that has type PlainTCP with the given
    34  // address addr.
    35  func NewTCPAddress(addr string) Address {
    36  	return NewAddress(PlainTCP, addr)
    37  }
    38  
    39  // NewTCPRouter returns a new Router using TCPHost as the underlying Host.
    40  func NewTCPRouter(sid *ServerIdentity, suite Suite) (*Router, error) {
    41  	r, err := NewTCPRouterWithListenAddr(sid, suite, "")
    42  	if err != nil {
    43  		return nil, xerrors.Errorf("tcp router: %v", err)
    44  	}
    45  	return r, nil
    46  }
    47  
    48  // NewTCPRouterWithListenAddr returns a new Router using TCPHost with the
    49  // given listen address as the underlying Host.
    50  func NewTCPRouterWithListenAddr(sid *ServerIdentity, suite Suite,
    51  	listenAddr string) (*Router, error) {
    52  	h, err := NewTCPHostWithListenAddr(sid, suite, listenAddr)
    53  	if err != nil {
    54  		return nil, xerrors.Errorf("tcp router: %v", err)
    55  	}
    56  	r := NewRouter(sid, h)
    57  	return r, nil
    58  }
    59  
    60  // SetTCPDialTimeout sets the dialing timeout for the TCP connection. The
    61  // default is one minute. This function is not thread-safe.
    62  func SetTCPDialTimeout(dur time.Duration) {
    63  	dialTimeout = dur
    64  }
    65  
// TCPConn implements the Conn interface using plain, unencrypted TCP.
// Every message travels as one frame: a Size-typed length prefix encoded
// with globalOrder, followed by the marshaled message.
type TCPConn struct {
	// The connection used
	conn net.Conn

	// the suite used to unmarshal messages
	suite Suite

	// closed indicator: set by Close, guarded by closedMut. A second Close
	// returns an error wrapping ErrClosed.
	closed    bool
	closedMut sync.Mutex
	// So we only handle one receiving packet at a time
	receiveMutex sync.Mutex
	// So we only handle one sending packet at a time
	sendMutex sync.Mutex

	// counterSafe presumably provides the Rx/Tx byte counters updated via
	// updateRx/updateTx; it is defined elsewhere in the package (TODO confirm).
	counterSafe

	// a hook to let us test dead servers: when non-nil, receiveRaw returns
	// its result instead of reading from conn.
	receiveRawTest func() ([]byte, error)
}
    87  
    88  // NewTCPConn will open a TCPConn to the given address.
    89  // In case of an error it returns a nil TCPConn and the error.
    90  func NewTCPConn(addr Address, suite Suite) (conn *TCPConn, err error) {
    91  	netAddr := addr.NetworkAddress()
    92  	for i := 1; i <= MaxRetryConnect; i++ {
    93  		var c net.Conn
    94  		c, err = net.DialTimeout("tcp", netAddr, dialTimeout)
    95  		if err == nil {
    96  			conn = &TCPConn{
    97  				conn:  c,
    98  				suite: suite,
    99  			}
   100  			return
   101  		}
   102  		err = xerrors.Errorf("dial: %v", err)
   103  		if i < MaxRetryConnect {
   104  			time.Sleep(WaitRetry)
   105  		}
   106  	}
   107  	if err == nil {
   108  		err = xerrors.Errorf("timeout: %w", ErrTimeout)
   109  	}
   110  	return
   111  }
   112  
   113  // Receive get the bytes from the connection then decodes the buffer.
   114  // It returns the Envelope containing the message,
   115  // or EmptyEnvelope and an error if something wrong happened.
   116  func (c *TCPConn) Receive() (env *Envelope, e error) {
   117  	buff, err := c.receiveRaw()
   118  	if err != nil {
   119  		return nil, xerrors.Errorf("receiving: %w", err)
   120  	}
   121  
   122  	id, body, err := Unmarshal(buff, c.suite)
   123  	return &Envelope{
   124  		MsgType: id,
   125  		Msg:     body,
   126  		Size:    Size(len(buff)),
   127  	}, err
   128  }
   129  
   130  func (c *TCPConn) receiveRaw() ([]byte, error) {
   131  	if c.receiveRawTest != nil {
   132  		return c.receiveRawTest()
   133  	}
   134  	return c.receiveRawProd()
   135  }
   136  
   137  // receiveRawProd reads the size of the message, then the
   138  // whole message. It returns the raw message as slice of bytes.
   139  // If there is no message available, it blocks until one becomes
   140  // available.
   141  // In case of an error it returns a nil slice and the error.
   142  func (c *TCPConn) receiveRawProd() ([]byte, error) {
   143  	c.receiveMutex.Lock()
   144  	defer c.receiveMutex.Unlock()
   145  	timeoutLock.RLock()
   146  	c.conn.SetReadDeadline(time.Now().Add(timeout))
   147  	timeoutLock.RUnlock()
   148  	// First read the size
   149  	var total Size
   150  	if err := binary.Read(c.conn, globalOrder, &total); err != nil {
   151  		return nil, xerrors.Errorf("buffer read: %w", handleError(err))
   152  	}
   153  	if total > MaxPacketSize {
   154  		return nil, xerrors.Errorf("%v sends too big packet: %v>%v",
   155  			c.conn.RemoteAddr().String(), total, MaxPacketSize)
   156  	}
   157  
   158  	b := make([]byte, total)
   159  	var read Size
   160  	var buffer bytes.Buffer
   161  	for read < total {
   162  		// Read the size of the next packet.
   163  		timeoutLock.RLock()
   164  		c.conn.SetReadDeadline(time.Now().Add(timeout))
   165  		timeoutLock.RUnlock()
   166  		n, err := c.conn.Read(b)
   167  		// Quit if there is an error.
   168  		if err != nil {
   169  			c.updateRx(4 + uint64(read))
   170  			return nil, xerrors.Errorf("reading: %w", handleError(err))
   171  		}
   172  		// Append the read bytes into the buffer.
   173  		if _, err := buffer.Write(b[:n]); err != nil {
   174  			log.Error("Couldn't write to buffer:", err)
   175  		}
   176  		read += Size(n)
   177  		b = b[n:]
   178  	}
   179  
   180  	// register how many bytes we read. (4 is for the frame size
   181  	// that we read up above).
   182  	c.updateRx(4 + uint64(read))
   183  	return buffer.Bytes(), nil
   184  }
   185  
   186  // Send converts the NetworkMessage into an ApplicationMessage
   187  // and sends it using send().
   188  // It returns the number of bytes sent and an error if anything was wrong.
   189  func (c *TCPConn) Send(msg Message) (uint64, error) {
   190  	c.sendMutex.Lock()
   191  	defer c.sendMutex.Unlock()
   192  
   193  	b, err := Marshal(msg)
   194  	if err != nil {
   195  		return 0, xerrors.Errorf("Error marshaling  message: %s", err.Error())
   196  	}
   197  	len, err := c.sendRaw(b)
   198  	if err != nil {
   199  		return len, xerrors.Errorf("sending: %w", err)
   200  	}
   201  	return len, nil
   202  }
   203  
   204  // sendRaw writes the number of bytes of the message to the network then the
   205  // whole message b in slices of size maxChunkSize.
   206  // In case of an error it aborts.
   207  func (c *TCPConn) sendRaw(b []byte) (uint64, error) {
   208  	timeoutLock.RLock()
   209  	c.conn.SetWriteDeadline(time.Now().Add(timeout))
   210  	timeoutLock.RUnlock()
   211  
   212  	// First write the size
   213  	packetSize := Size(len(b))
   214  	if err := binary.Write(c.conn, globalOrder, packetSize); err != nil {
   215  		return 0, xerrors.Errorf("buffer write: %v", err)
   216  	}
   217  	// Then send everything through the connection
   218  	// Send chunk by chunk
   219  	log.Lvl5("Sending from", c.conn.LocalAddr(), "to", c.conn.RemoteAddr())
   220  	var sent Size
   221  	for sent < packetSize {
   222  		n, err := c.conn.Write(b[sent:])
   223  		if err != nil {
   224  			sentLen := 4 + uint64(sent)
   225  			c.updateTx(sentLen)
   226  			return sentLen, xerrors.Errorf("sending: %w", handleError(err))
   227  		}
   228  		sent += Size(n)
   229  	}
   230  	// update stats on the connection. Plus 4 for the uint32 for the frame size.
   231  	sentLen := 4 + uint64(sent)
   232  	c.updateTx(sentLen)
   233  	return sentLen, nil
   234  }
   235  
   236  // Remote returns the name of the peer at the end point of
   237  // the connection.
   238  func (c *TCPConn) Remote() Address {
   239  	return Address(c.conn.RemoteAddr().String())
   240  }
   241  
   242  // Local returns the local address and port.
   243  func (c *TCPConn) Local() Address {
   244  	return NewTCPAddress(c.conn.LocalAddr().String())
   245  }
   246  
   247  // Type returns PlainTCP.
   248  func (c *TCPConn) Type() ConnType {
   249  	return PlainTCP
   250  }
   251  
   252  // Close the connection.
   253  // Returns error if it couldn't close the connection.
   254  func (c *TCPConn) Close() error {
   255  	c.closedMut.Lock()
   256  	defer c.closedMut.Unlock()
   257  	if c.closed == true {
   258  		return xerrors.Errorf("closing: %w", ErrClosed)
   259  	}
   260  	err := c.conn.Close()
   261  	c.closed = true
   262  	if err != nil {
   263  		return xerrors.Errorf("closing: %w", handleError(err))
   264  	}
   265  	return nil
   266  }
   267  
   268  // handleError translates the network-layer error to a set of errors
   269  // used in our packages.
   270  func handleError(err error) error {
   271  	if strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "broken pipe") {
   272  		return ErrClosed
   273  	} else if strings.Contains(err.Error(), "canceled") {
   274  		return ErrCanceled
   275  	} else if err == io.EOF || strings.Contains(err.Error(), "EOF") {
   276  		return ErrEOF
   277  	}
   278  
   279  	netErr, ok := err.(net.Error)
   280  	if !ok {
   281  		return ErrUnknown
   282  	}
   283  	if netErr.Timeout() {
   284  		return ErrTimeout
   285  	}
   286  
   287  	log.Errorf("Unknown error caught: %s", err.Error())
   288  	return ErrUnknown
   289  }
   290  
// TCPListener accepts incoming connections on a plain TCP or TLS address
// (see conntype) and hands them out as Conns. It provides the listening
// half of TCPHost, which embeds it.
type TCPListener struct {
	// the underlying golang/net listener.
	listener net.Listener
	// the close channel used to indicate to the listener we want to quit.
	// Stop closes it, and replaces it with a fresh channel once done.
	quit chan bool
	// quitListener is a channel to indicate to the closing function that the
	// listener has actually really quit: listen sends on it when it leaves
	// its accept loop because of Stop, and Stop waits for that value.
	quitListener  chan bool
	// listeningLock guards listening and closed.
	listeningLock sync.Mutex
	// listening is set when listen enters its accept loop and cleared by Stop.
	listening     bool

	// closed tells the listen routine to return immediately if a
	// Stop() has been called. Guarded by listeningLock.
	closed bool

	// actual listening addr which might differ from initial address in
	// case of ":0"-address.
	addr net.Addr

	// Is this a TCP or a TLS listener?
	conntype ConnType

	// suite that is given to each incoming connection
	suite Suite
}
   318  
   319  // NewTCPListener returns a TCPListener. This function binds globally using
   320  // the port of 'addr'.
   321  // It returns the listener and an error if one occurred during
   322  // the binding.
   323  // A subsequent call to Address() gives the actual listening
   324  // address which is different if you gave it a ":0"-address.
   325  func NewTCPListener(addr Address, s Suite) (*TCPListener, error) {
   326  	l, err := NewTCPListenerWithListenAddr(addr, s, "")
   327  	if err != nil {
   328  		return nil, xerrors.Errorf("tcp listener: %v", err)
   329  	}
   330  	return l, nil
   331  }
   332  
   333  // NewTCPListenerWithListenAddr returns a TCPListener. This function binds to the
   334  // given 'listenAddr'. If it is empty, the function binds globally using
   335  // the port of 'addr'.
   336  // It returns the listener and an error if one occurred during
   337  // the binding.
   338  // A subsequent call to Address() gives the actual listening
   339  // address which is different if you gave it a ":0"-address.
   340  func NewTCPListenerWithListenAddr(addr Address,
   341  	s Suite, listenAddr string) (*TCPListener, error) {
   342  	if addr.ConnType() != PlainTCP && addr.ConnType() != TLS {
   343  		return nil, xerrors.New("TCPListener can only listen on TCP and TLS addresses")
   344  	}
   345  	t := &TCPListener{
   346  		conntype:     addr.ConnType(),
   347  		quit:         make(chan bool),
   348  		quitListener: make(chan bool),
   349  		suite:        s,
   350  	}
   351  	listenOn, err := getListenAddress(addr, listenAddr)
   352  	if err != nil {
   353  		return nil, xerrors.Errorf("listener: %v", err)
   354  	}
   355  	for i := 0; i < MaxRetryConnect; i++ {
   356  		ln, err := net.Listen("tcp", listenOn)
   357  		if err == nil {
   358  			t.listener = ln
   359  			break
   360  		} else if i == MaxRetryConnect-1 {
   361  			return nil, xerrors.New("Error opening listener: " + err.Error())
   362  		}
   363  		time.Sleep(WaitRetry)
   364  	}
   365  	t.addr = t.listener.Addr()
   366  	return t, nil
   367  }
   368  
   369  // Listen starts to listen for incoming connections and calls fn for every
   370  // connection-request it receives.
   371  // If the connection is closed, an error will be returned.
   372  func (t *TCPListener) Listen(fn func(Conn)) error {
   373  	receiver := func(tc Conn) {
   374  		go fn(tc)
   375  	}
   376  	err := t.listen(receiver)
   377  	if err != nil {
   378  		return xerrors.Errorf("listening: %v", err)
   379  	}
   380  	return nil
   381  }
   382  
   383  // listen is the private function that takes a function that takes a TCPConn.
   384  // That way we can control what to do of the TCPConn before returning it to the
   385  // function given by the user. fn is called in the same routine.
   386  func (t *TCPListener) listen(fn func(Conn)) error {
   387  	t.listeningLock.Lock()
   388  	if t.closed == true {
   389  		t.listeningLock.Unlock()
   390  		return nil
   391  	}
   392  	t.listening = true
   393  	t.listeningLock.Unlock()
   394  	for {
   395  		conn, err := t.listener.Accept()
   396  		if err != nil {
   397  			select {
   398  			case <-t.quit:
   399  				t.quitListener <- true
   400  				return nil
   401  			default:
   402  			}
   403  			continue
   404  		}
   405  		c := TCPConn{
   406  			conn:  conn,
   407  			suite: t.suite,
   408  		}
   409  		fn(&c)
   410  	}
   411  }
   412  
   413  // Stop the listener. It waits till all connections are closed
   414  // and returned from.
   415  // If there is no listener it will return an error.
   416  func (t *TCPListener) Stop() error {
   417  	// lets see if we launched a listening routing
   418  	t.listeningLock.Lock()
   419  	defer t.listeningLock.Unlock()
   420  
   421  	close(t.quit)
   422  
   423  	if t.listener != nil {
   424  		if err := t.listener.Close(); err != nil {
   425  			if handleError(err) != ErrClosed {
   426  				return xerrors.Errorf("closing: %w", handleError(err))
   427  			}
   428  		}
   429  	}
   430  	var stop bool
   431  	if t.listening {
   432  		for !stop {
   433  			select {
   434  			case <-t.quitListener:
   435  				stop = true
   436  			case <-time.After(time.Millisecond * 50):
   437  				continue
   438  			}
   439  		}
   440  	}
   441  
   442  	t.quit = make(chan bool)
   443  	t.listening = false
   444  	t.closed = true
   445  	return nil
   446  }
   447  
   448  // Address returns the listening address.
   449  func (t *TCPListener) Address() Address {
   450  	t.listeningLock.Lock()
   451  	defer t.listeningLock.Unlock()
   452  	return NewAddress(t.conntype, t.addr.String())
   453  }
   454  
   455  // Listening returns whether it's already listening.
   456  func (t *TCPListener) Listening() bool {
   457  	t.listeningLock.Lock()
   458  	defer t.listeningLock.Unlock()
   459  	return t.listening
   460  }
   461  
   462  // getListenAddress returns the address the listener should listen
   463  // on given the server's address (addr) and the address it was told to listen
   464  // on (listenAddr), which could be empty.
   465  // Rules:
   466  // 1. If there is no listenAddr, bind globally with addr.
   467  // 2. If there is only an IP in listenAddr, take the port from addr.
   468  // 3. If there is an IP:Port in listenAddr, take only listenAddr.
   469  // Otherwise return an error.
   470  func getListenAddress(addr Address, listenAddr string) (string, error) {
   471  	// If no `listenAddr`, bind globally.
   472  	if listenAddr == "" {
   473  		return GlobalBind(addr.NetworkAddress())
   474  	}
   475  	_, port, err := net.SplitHostPort(addr.NetworkAddress())
   476  	if err != nil {
   477  		return "", xerrors.Errorf("invalid address: %v", err)
   478  	}
   479  
   480  	// If 'listenAddr' only contains the host, combine it with the port
   481  	// of 'addr'.
   482  	splitted := strings.Split(listenAddr, ":")
   483  	if len(splitted) == 1 && port != "" {
   484  		return splitted[0] + ":" + port, nil
   485  	}
   486  
   487  	// If host and port in `listenAddr`, choose this one.
   488  	hostListen, portListen, err := net.SplitHostPort(listenAddr)
   489  	if err != nil {
   490  		return "", xerrors.Errorf("invalid address: %v", err)
   491  	}
   492  	if hostListen != "" && portListen != "" {
   493  		return listenAddr, nil
   494  	}
   495  
   496  	return "", xerrors.Errorf("Invalid combination of 'addr' (%s) and 'listenAddr' (%s)", addr.NetworkAddress(), listenAddr)
   497  }
   498  
// TCPHost implements the Host interface using TCP connections.
// Listening is provided by the embedded *TCPListener (plain TCP or TLS,
// depending on the server's address type). Connect dials other servers.
type TCPHost struct {
	// suite is passed to every outgoing connection created by Connect.
	suite Suite
	// sid is this server's identity; Connect uses it when dialing TLS peers.
	sid   *ServerIdentity
	*TCPListener
}
   505  
   506  // NewTCPHost returns a new Host using TCP connection based type.
   507  func NewTCPHost(sid *ServerIdentity, s Suite) (*TCPHost, error) {
   508  	host, err := NewTCPHostWithListenAddr(sid, s, "")
   509  	if err != nil {
   510  		return nil, xerrors.Errorf("tcp host: %v", err)
   511  	}
   512  	return host, nil
   513  }
   514  
   515  // NewTCPHostWithListenAddr returns a new Host using TCP connection based type
   516  // listening on the given address.
   517  func NewTCPHostWithListenAddr(sid *ServerIdentity, s Suite,
   518  	listenAddr string) (*TCPHost, error) {
   519  	h := &TCPHost{
   520  		suite: s,
   521  		sid:   sid,
   522  	}
   523  	var err error
   524  	if sid.Address.ConnType() == TLS {
   525  		h.TCPListener, err = NewTLSListenerWithListenAddr(sid, s, listenAddr)
   526  	} else {
   527  		h.TCPListener, err = NewTCPListenerWithListenAddr(sid.Address, s, listenAddr)
   528  	}
   529  	if err != nil {
   530  		return nil, xerrors.Errorf("tcp host: %v", err)
   531  	}
   532  	return h, nil
   533  }
   534  
   535  // Connect can only connect to PlainTCP connections.
   536  // It will return an error if it is not a PlainTCP-connection-type.
   537  func (t *TCPHost) Connect(si *ServerIdentity) (Conn, error) {
   538  	switch si.Address.ConnType() {
   539  	case PlainTCP:
   540  		c, err := NewTCPConn(si.Address, t.suite)
   541  		if err != nil {
   542  			return nil, xerrors.Errorf("tcp connection: %v", err)
   543  		}
   544  		return c, nil
   545  	case TLS:
   546  		c, err := NewTLSConn(t.sid, si, t.suite)
   547  		if err != nil {
   548  			return nil, xerrors.Errorf("tcp connection: %v", err)
   549  		}
   550  		return c, nil
   551  	case InvalidConnType:
   552  		return nil, xerrors.New("This address is not correctly formatted: " + si.Address.String())
   553  	}
   554  	return nil, xerrors.Errorf("TCPHost %s can't handle this type of connection: %s", si.Address, si.Address.ConnType())
   555  }