go.dedis.ch/onet/v4@v4.0.0-pre1/network/tcp.go (about) 1 package network 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "io" 7 "net" 8 "strings" 9 "sync" 10 "time" 11 12 "go.dedis.ch/onet/v4/log" 13 "golang.org/x/xerrors" 14 ) 15 16 // a connection will return an io.EOF after networkTimeout if nothing has been 17 // received. sends and connects will timeout using this timeout as well. 18 var timeout = 1 * time.Minute 19 20 // dialTimeout is the timeout for connecting to an end point. 21 var dialTimeout = 1 * time.Minute 22 23 // Global lock for 'timeout' (because also used in 'tcp_test.go') 24 // Using a 'RWMutex' to be as efficient as possible, because it will be used 25 // quite a lot in 'Receive()'. 26 var timeoutLock = sync.RWMutex{} 27 28 // MaxPacketSize limits the amount of memory that is allocated before a packet 29 // is checked and thrown away if it's not legit. If you need more than 10MB 30 // packets, increase this value. 31 var MaxPacketSize = Size(10 * 1024 * 1024) 32 33 // NewTCPAddress returns a new Address that has type PlainTCP with the given 34 // address addr. 35 func NewTCPAddress(addr string) Address { 36 return NewAddress(PlainTCP, addr) 37 } 38 39 // NewTCPRouter returns a new Router using TCPHost as the underlying Host. 40 func NewTCPRouter(sid *ServerIdentity, suite Suite) (*Router, error) { 41 r, err := NewTCPRouterWithListenAddr(sid, suite, "") 42 if err != nil { 43 return nil, xerrors.Errorf("tcp router: %v", err) 44 } 45 return r, nil 46 } 47 48 // NewTCPRouterWithListenAddr returns a new Router using TCPHost with the 49 // given listen address as the underlying Host. 50 func NewTCPRouterWithListenAddr(sid *ServerIdentity, suite Suite, 51 listenAddr string) (*Router, error) { 52 h, err := NewTCPHostWithListenAddr(sid, suite, listenAddr) 53 if err != nil { 54 return nil, xerrors.Errorf("tcp router: %v", err) 55 } 56 r := NewRouter(sid, h) 57 return r, nil 58 } 59 60 // SetTCPDialTimeout sets the dialing timeout for the TCP connection. The 61 // default is one minute. This function is not thread-safe. 62 func SetTCPDialTimeout(dur time.Duration) { 63 dialTimeout = dur 64 } 65 66 // TCPConn implements the Conn interface using plain, unencrypted TCP. 67 type TCPConn struct { 68 // The connection used 69 conn net.Conn 70 71 // the suite used to unmarshal messages 72 suite Suite 73 74 // closed indicator 75 closed bool 76 closedMut sync.Mutex 77 // So we only handle one receiving packet at a time 78 receiveMutex sync.Mutex 79 // So we only handle one sending packet at a time 80 sendMutex sync.Mutex 81 82 counterSafe 83 84 // a hook to let us test dead servers 85 receiveRawTest func() ([]byte, error) 86 } 87 88 // NewTCPConn will open a TCPConn to the given address. 89 // In case of an error it returns a nil TCPConn and the error. 90 func NewTCPConn(addr Address, suite Suite) (conn *TCPConn, err error) { 91 netAddr := addr.NetworkAddress() 92 for i := 1; i <= MaxRetryConnect; i++ { 93 var c net.Conn 94 c, err = net.DialTimeout("tcp", netAddr, dialTimeout) 95 if err == nil { 96 conn = &TCPConn{ 97 conn: c, 98 suite: suite, 99 } 100 return 101 } 102 err = xerrors.Errorf("dial: %v", err) 103 if i < MaxRetryConnect { 104 time.Sleep(WaitRetry) 105 } 106 } 107 if err == nil { 108 err = xerrors.Errorf("timeout: %w", ErrTimeout) 109 } 110 return 111 } 112 113 // Receive get the bytes from the connection then decodes the buffer. 114 // It returns the Envelope containing the message, 115 // or EmptyEnvelope and an error if something wrong happened. 116 func (c *TCPConn) Receive() (env *Envelope, e error) { 117 buff, err := c.receiveRaw() 118 if err != nil { 119 return nil, xerrors.Errorf("receiving: %w", err) 120 } 121 122 id, body, err := Unmarshal(buff, c.suite) 123 return &Envelope{ 124 MsgType: id, 125 Msg: body, 126 Size: Size(len(buff)), 127 }, err 128 } 129 130 func (c *TCPConn) receiveRaw() ([]byte, error) { 131 if c.receiveRawTest != nil { 132 return c.receiveRawTest() 133 } 134 return c.receiveRawProd() 135 } 136 137 // receiveRawProd reads the size of the message, then the 138 // whole message. It returns the raw message as slice of bytes. 139 // If there is no message available, it blocks until one becomes 140 // available. 141 // In case of an error it returns a nil slice and the error. 142 func (c *TCPConn) receiveRawProd() ([]byte, error) { 143 c.receiveMutex.Lock() 144 defer c.receiveMutex.Unlock() 145 timeoutLock.RLock() 146 c.conn.SetReadDeadline(time.Now().Add(timeout)) 147 timeoutLock.RUnlock() 148 // First read the size 149 var total Size 150 if err := binary.Read(c.conn, globalOrder, &total); err != nil { 151 return nil, xerrors.Errorf("buffer read: %w", handleError(err)) 152 } 153 if total > MaxPacketSize { 154 return nil, xerrors.Errorf("%v sends too big packet: %v>%v", 155 c.conn.RemoteAddr().String(), total, MaxPacketSize) 156 } 157 158 b := make([]byte, total) 159 var read Size 160 var buffer bytes.Buffer 161 for read < total { 162 // Read the size of the next packet. 163 timeoutLock.RLock() 164 c.conn.SetReadDeadline(time.Now().Add(timeout)) 165 timeoutLock.RUnlock() 166 n, err := c.conn.Read(b) 167 // Quit if there is an error. 168 if err != nil { 169 c.updateRx(4 + uint64(read)) 170 return nil, xerrors.Errorf("reading: %w", handleError(err)) 171 } 172 // Append the read bytes into the buffer. 173 if _, err := buffer.Write(b[:n]); err != nil { 174 log.Error("Couldn't write to buffer:", err) 175 } 176 read += Size(n) 177 b = b[n:] 178 } 179 180 // register how many bytes we read. (4 is for the frame size 181 // that we read up above). 182 c.updateRx(4 + uint64(read)) 183 return buffer.Bytes(), nil 184 } 185 186 // Send converts the NetworkMessage into an ApplicationMessage 187 // and sends it using send(). 188 // It returns the number of bytes sent and an error if anything was wrong. 189 func (c *TCPConn) Send(msg Message) (uint64, error) { 190 c.sendMutex.Lock() 191 defer c.sendMutex.Unlock() 192 193 b, err := Marshal(msg) 194 if err != nil { 195 return 0, xerrors.Errorf("Error marshaling message: %s", err.Error()) 196 } 197 len, err := c.sendRaw(b) 198 if err != nil { 199 return len, xerrors.Errorf("sending: %w", err) 200 } 201 return len, nil 202 } 203 204 // sendRaw writes the number of bytes of the message to the network then the 205 // whole message b in slices of size maxChunkSize. 206 // In case of an error it aborts. 207 func (c *TCPConn) sendRaw(b []byte) (uint64, error) { 208 timeoutLock.RLock() 209 c.conn.SetWriteDeadline(time.Now().Add(timeout)) 210 timeoutLock.RUnlock() 211 212 // First write the size 213 packetSize := Size(len(b)) 214 if err := binary.Write(c.conn, globalOrder, packetSize); err != nil { 215 return 0, xerrors.Errorf("buffer write: %v", err) 216 } 217 // Then send everything through the connection 218 // Send chunk by chunk 219 log.Lvl5("Sending from", c.conn.LocalAddr(), "to", c.conn.RemoteAddr()) 220 var sent Size 221 for sent < packetSize { 222 n, err := c.conn.Write(b[sent:]) 223 if err != nil { 224 sentLen := 4 + uint64(sent) 225 c.updateTx(sentLen) 226 return sentLen, xerrors.Errorf("sending: %w", handleError(err)) 227 } 228 sent += Size(n) 229 } 230 // update stats on the connection. Plus 4 for the uint32 for the frame size. 231 sentLen := 4 + uint64(sent) 232 c.updateTx(sentLen) 233 return sentLen, nil 234 } 235 236 // Remote returns the name of the peer at the end point of 237 // the connection. 238 func (c *TCPConn) Remote() Address { 239 return Address(c.conn.RemoteAddr().String()) 240 } 241 242 // Local returns the local address and port. 243 func (c *TCPConn) Local() Address { 244 return NewTCPAddress(c.conn.LocalAddr().String()) 245 } 246 247 // Type returns PlainTCP. 248 func (c *TCPConn) Type() ConnType { 249 return PlainTCP 250 } 251 252 // Close the connection. 253 // Returns error if it couldn't close the connection. 254 func (c *TCPConn) Close() error { 255 c.closedMut.Lock() 256 defer c.closedMut.Unlock() 257 if c.closed == true { 258 return xerrors.Errorf("closing: %w", ErrClosed) 259 } 260 err := c.conn.Close() 261 c.closed = true 262 if err != nil { 263 return xerrors.Errorf("closing: %w", handleError(err)) 264 } 265 return nil 266 } 267 268 // handleError translates the network-layer error to a set of errors 269 // used in our packages. 270 func handleError(err error) error { 271 if strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "broken pipe") { 272 return ErrClosed 273 } else if strings.Contains(err.Error(), "canceled") { 274 return ErrCanceled 275 } else if err == io.EOF || strings.Contains(err.Error(), "EOF") { 276 return ErrEOF 277 } 278 279 netErr, ok := err.(net.Error) 280 if !ok { 281 return ErrUnknown 282 } 283 if netErr.Timeout() { 284 return ErrTimeout 285 } 286 287 log.Errorf("Unknown error caught: %s", err.Error()) 288 return ErrUnknown 289 } 290 291 // TCPListener implements the Host-interface using Tcp as a communication 292 // channel. 293 type TCPListener struct { 294 // the underlying golang/net listener. 295 listener net.Listener 296 // the close channel used to indicate to the listener we want to quit. 297 quit chan bool 298 // quitListener is a channel to indicate to the closing function that the 299 // listener has actually really quit. 300 quitListener chan bool 301 listeningLock sync.Mutex 302 listening bool 303 304 // closed tells the listen routine to return immediately if a 305 // Stop() has been called. 306 closed bool 307 308 // actual listening addr which might differ from initial address in 309 // case of ":0"-address. 310 addr net.Addr 311 312 // Is this a TCP or a TLS listener? 313 conntype ConnType 314 315 // suite that is given to each incoming connection 316 suite Suite 317 } 318 319 // NewTCPListener returns a TCPListener. This function binds globally using 320 // the port of 'addr'. 321 // It returns the listener and an error if one occurred during 322 // the binding. 323 // A subsequent call to Address() gives the actual listening 324 // address which is different if you gave it a ":0"-address. 325 func NewTCPListener(addr Address, s Suite) (*TCPListener, error) { 326 l, err := NewTCPListenerWithListenAddr(addr, s, "") 327 if err != nil { 328 return nil, xerrors.Errorf("tcp listener: %v", err) 329 } 330 return l, nil 331 } 332 333 // NewTCPListenerWithListenAddr returns a TCPListener. This function binds to the 334 // given 'listenAddr'. If it is empty, the function binds globally using 335 // the port of 'addr'. 336 // It returns the listener and an error if one occurred during 337 // the binding. 338 // A subsequent call to Address() gives the actual listening 339 // address which is different if you gave it a ":0"-address. 340 func NewTCPListenerWithListenAddr(addr Address, 341 s Suite, listenAddr string) (*TCPListener, error) { 342 if addr.ConnType() != PlainTCP && addr.ConnType() != TLS { 343 return nil, xerrors.New("TCPListener can only listen on TCP and TLS addresses") 344 } 345 t := &TCPListener{ 346 conntype: addr.ConnType(), 347 quit: make(chan bool), 348 quitListener: make(chan bool), 349 suite: s, 350 } 351 listenOn, err := getListenAddress(addr, listenAddr) 352 if err != nil { 353 return nil, xerrors.Errorf("listener: %v", err) 354 } 355 for i := 0; i < MaxRetryConnect; i++ { 356 ln, err := net.Listen("tcp", listenOn) 357 if err == nil { 358 t.listener = ln 359 break 360 } else if i == MaxRetryConnect-1 { 361 return nil, xerrors.New("Error opening listener: " + err.Error()) 362 } 363 time.Sleep(WaitRetry) 364 } 365 t.addr = t.listener.Addr() 366 return t, nil 367 } 368 369 // Listen starts to listen for incoming connections and calls fn for every 370 // connection-request it receives. 371 // If the connection is closed, an error will be returned. 372 func (t *TCPListener) Listen(fn func(Conn)) error { 373 receiver := func(tc Conn) { 374 go fn(tc) 375 } 376 err := t.listen(receiver) 377 if err != nil { 378 return xerrors.Errorf("listening: %v", err) 379 } 380 return nil 381 } 382 383 // listen is the private function that takes a function that takes a TCPConn. 384 // That way we can control what to do of the TCPConn before returning it to the 385 // function given by the user. fn is called in the same routine. 386 func (t *TCPListener) listen(fn func(Conn)) error { 387 t.listeningLock.Lock() 388 if t.closed == true { 389 t.listeningLock.Unlock() 390 return nil 391 } 392 t.listening = true 393 t.listeningLock.Unlock() 394 for { 395 conn, err := t.listener.Accept() 396 if err != nil { 397 select { 398 case <-t.quit: 399 t.quitListener <- true 400 return nil 401 default: 402 } 403 continue 404 } 405 c := TCPConn{ 406 conn: conn, 407 suite: t.suite, 408 } 409 fn(&c) 410 } 411 } 412 413 // Stop the listener. It waits till all connections are closed 414 // and returned from. 415 // If there is no listener it will return an error. 416 func (t *TCPListener) Stop() error { 417 // lets see if we launched a listening routing 418 t.listeningLock.Lock() 419 defer t.listeningLock.Unlock() 420 421 close(t.quit) 422 423 if t.listener != nil { 424 if err := t.listener.Close(); err != nil { 425 if handleError(err) != ErrClosed { 426 return xerrors.Errorf("closing: %w", handleError(err)) 427 } 428 } 429 } 430 var stop bool 431 if t.listening { 432 for !stop { 433 select { 434 case <-t.quitListener: 435 stop = true 436 case <-time.After(time.Millisecond * 50): 437 continue 438 } 439 } 440 } 441 442 t.quit = make(chan bool) 443 t.listening = false 444 t.closed = true 445 return nil 446 } 447 448 // Address returns the listening address. 449 func (t *TCPListener) Address() Address { 450 t.listeningLock.Lock() 451 defer t.listeningLock.Unlock() 452 return NewAddress(t.conntype, t.addr.String()) 453 } 454 455 // Listening returns whether it's already listening. 456 func (t *TCPListener) Listening() bool { 457 t.listeningLock.Lock() 458 defer t.listeningLock.Unlock() 459 return t.listening 460 } 461 462 // getListenAddress returns the address the listener should listen 463 // on given the server's address (addr) and the address it was told to listen 464 // on (listenAddr), which could be empty. 465 // Rules: 466 // 1. If there is no listenAddr, bind globally with addr. 467 // 2. If there is only an IP in listenAddr, take the port from addr. 468 // 3. If there is an IP:Port in listenAddr, take only listenAddr. 469 // Otherwise return an error. 470 func getListenAddress(addr Address, listenAddr string) (string, error) { 471 // If no `listenAddr`, bind globally. 472 if listenAddr == "" { 473 return GlobalBind(addr.NetworkAddress()) 474 } 475 _, port, err := net.SplitHostPort(addr.NetworkAddress()) 476 if err != nil { 477 return "", xerrors.Errorf("invalid address: %v", err) 478 } 479 480 // If 'listenAddr' only contains the host, combine it with the port 481 // of 'addr'. 482 splitted := strings.Split(listenAddr, ":") 483 if len(splitted) == 1 && port != "" { 484 return splitted[0] + ":" + port, nil 485 } 486 487 // If host and port in `listenAddr`, choose this one. 488 hostListen, portListen, err := net.SplitHostPort(listenAddr) 489 if err != nil { 490 return "", xerrors.Errorf("invalid address: %v", err) 491 } 492 if hostListen != "" && portListen != "" { 493 return listenAddr, nil 494 } 495 496 return "", xerrors.Errorf("Invalid combination of 'addr' (%s) and 'listenAddr' (%s)", addr.NetworkAddress(), listenAddr) 497 } 498 499 // TCPHost implements the Host interface using TCP connections. 500 type TCPHost struct { 501 suite Suite 502 sid *ServerIdentity 503 *TCPListener 504 } 505 506 // NewTCPHost returns a new Host using TCP connection based type. 507 func NewTCPHost(sid *ServerIdentity, s Suite) (*TCPHost, error) { 508 host, err := NewTCPHostWithListenAddr(sid, s, "") 509 if err != nil { 510 return nil, xerrors.Errorf("tcp host: %v", err) 511 } 512 return host, nil 513 } 514 515 // NewTCPHostWithListenAddr returns a new Host using TCP connection based type 516 // listening on the given address. 517 func NewTCPHostWithListenAddr(sid *ServerIdentity, s Suite, 518 listenAddr string) (*TCPHost, error) { 519 h := &TCPHost{ 520 suite: s, 521 sid: sid, 522 } 523 var err error 524 if sid.Address.ConnType() == TLS { 525 h.TCPListener, err = NewTLSListenerWithListenAddr(sid, s, listenAddr) 526 } else { 527 h.TCPListener, err = NewTCPListenerWithListenAddr(sid.Address, s, listenAddr) 528 } 529 if err != nil { 530 return nil, xerrors.Errorf("tcp host: %v", err) 531 } 532 return h, nil 533 } 534 535 // Connect can only connect to PlainTCP connections. 536 // It will return an error if it is not a PlainTCP-connection-type. 537 func (t *TCPHost) Connect(si *ServerIdentity) (Conn, error) { 538 switch si.Address.ConnType() { 539 case PlainTCP: 540 c, err := NewTCPConn(si.Address, t.suite) 541 if err != nil { 542 return nil, xerrors.Errorf("tcp connection: %v", err) 543 } 544 return c, nil 545 case TLS: 546 c, err := NewTLSConn(t.sid, si, t.suite) 547 if err != nil { 548 return nil, xerrors.Errorf("tcp connection: %v", err) 549 } 550 return c, nil 551 case InvalidConnType: 552 return nil, xerrors.New("This address is not correctly formatted: " + si.Address.String()) 553 } 554 return nil, xerrors.Errorf("TCPHost %s can't handle this type of connection: %s", si.Address, si.Address.ConnType()) 555 }