github.com/FlowerWrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/adapters/gonet/gonet.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package gonet provides a Go net package compatible wrapper for a tcpip stack. 16 package gonet 17 18 import ( 19 "context" 20 "errors" 21 "io" 22 "net" 23 "sync" 24 "time" 25 26 "github.com/FlowerWrong/netstack/tcpip" 27 "github.com/FlowerWrong/netstack/tcpip/buffer" 28 "github.com/FlowerWrong/netstack/tcpip/stack" 29 "github.com/FlowerWrong/netstack/tcpip/transport/tcp" 30 "github.com/FlowerWrong/netstack/tcpip/transport/udp" 31 "github.com/FlowerWrong/netstack/waiter" 32 ) 33 34 var ( 35 errCanceled = errors.New("operation canceled") 36 errWouldBlock = errors.New("operation would block") 37 ) 38 39 // timeoutError is how the net package reports timeouts. 40 type timeoutError struct{} 41 42 func (e *timeoutError) Error() string { return "i/o timeout" } 43 func (e *timeoutError) Timeout() bool { return true } 44 func (e *timeoutError) Temporary() bool { return true } 45 46 // A Listener is a wrapper around a tcpip endpoint that implements 47 // net.Listener. 48 type Listener struct { 49 stack *stack.Stack 50 ep tcpip.Endpoint 51 wq *waiter.Queue 52 cancel chan struct{} 53 } 54 55 // NewListener creates a new Listener. 56 func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) { 57 // Create TCP endpoint, bind it, then start listening. 58 var wq waiter.Queue 59 ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) 60 if err != nil { 61 return nil, errors.New(err.String()) 62 } 63 64 if err := ep.Bind(addr); err != nil { 65 ep.Close() 66 return nil, &net.OpError{ 67 Op: "bind", 68 Net: "tcp", 69 Addr: fullToTCPAddr(addr), 70 Err: errors.New(err.String()), 71 } 72 } 73 74 if err := ep.Listen(10); err != nil { 75 ep.Close() 76 return nil, &net.OpError{ 77 Op: "listen", 78 Net: "tcp", 79 Addr: fullToTCPAddr(addr), 80 Err: errors.New(err.String()), 81 } 82 } 83 84 return &Listener{ 85 stack: s, 86 ep: ep, 87 wq: &wq, 88 cancel: make(chan struct{}), 89 }, nil 90 } 91 92 // Close implements net.Listener.Close. 93 func (l *Listener) Close() error { 94 l.ep.Close() 95 return nil 96 } 97 98 // Shutdown stops the HTTP server. 99 func (l *Listener) Shutdown() { 100 l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) 101 close(l.cancel) // broadcast cancellation 102 } 103 104 // Addr implements net.Listener.Addr. 105 func (l *Listener) Addr() net.Addr { 106 a, err := l.ep.GetLocalAddress() 107 if err != nil { 108 return nil 109 } 110 return fullToTCPAddr(a) 111 } 112 113 type deadlineTimer struct { 114 // mu protects the fields below. 115 mu sync.Mutex 116 117 readTimer *time.Timer 118 readCancelCh chan struct{} 119 writeTimer *time.Timer 120 writeCancelCh chan struct{} 121 } 122 123 func (d *deadlineTimer) init() { 124 d.readCancelCh = make(chan struct{}) 125 d.writeCancelCh = make(chan struct{}) 126 } 127 128 func (d *deadlineTimer) readCancel() <-chan struct{} { 129 d.mu.Lock() 130 c := d.readCancelCh 131 d.mu.Unlock() 132 return c 133 } 134 func (d *deadlineTimer) writeCancel() <-chan struct{} { 135 d.mu.Lock() 136 c := d.writeCancelCh 137 d.mu.Unlock() 138 return c 139 } 140 141 // setDeadline contains the shared logic for setting a deadline. 142 // 143 // cancelCh and timer must be pointers to deadlineTimer.readCancelCh and 144 // deadlineTimer.readTimer or deadlineTimer.writeCancelCh and 145 // deadlineTimer.writeTimer. 146 // 147 // setDeadline must only be called while holding d.mu. 148 func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) { 149 if *timer != nil && !(*timer).Stop() { 150 *cancelCh = make(chan struct{}) 151 } 152 153 // Create a new channel if we already closed it due to setting an already 154 // expired time. We won't race with the timer because we already handled 155 // that above. 156 select { 157 case <-*cancelCh: 158 *cancelCh = make(chan struct{}) 159 default: 160 } 161 162 // "A zero value for t means I/O operations will not time out." 163 // - net.Conn.SetDeadline 164 if t.IsZero() { 165 return 166 } 167 168 timeout := t.Sub(time.Now()) 169 if timeout <= 0 { 170 close(*cancelCh) 171 return 172 } 173 174 // Timer.Stop returns whether or not the AfterFunc has started, but 175 // does not indicate whether or not it has completed. Make a copy of 176 // the cancel channel to prevent this code from racing with the next 177 // call of setDeadline replacing *cancelCh. 178 ch := *cancelCh 179 *timer = time.AfterFunc(timeout, func() { 180 close(ch) 181 }) 182 } 183 184 // SetReadDeadline implements net.Conn.SetReadDeadline and 185 // net.PacketConn.SetReadDeadline. 186 func (d *deadlineTimer) SetReadDeadline(t time.Time) error { 187 d.mu.Lock() 188 d.setDeadline(&d.readCancelCh, &d.readTimer, t) 189 d.mu.Unlock() 190 return nil 191 } 192 193 // SetWriteDeadline implements net.Conn.SetWriteDeadline and 194 // net.PacketConn.SetWriteDeadline. 195 func (d *deadlineTimer) SetWriteDeadline(t time.Time) error { 196 d.mu.Lock() 197 d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) 198 d.mu.Unlock() 199 return nil 200 } 201 202 // SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline. 203 func (d *deadlineTimer) SetDeadline(t time.Time) error { 204 d.mu.Lock() 205 d.setDeadline(&d.readCancelCh, &d.readTimer, t) 206 d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) 207 d.mu.Unlock() 208 return nil 209 } 210 211 // A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn 212 // interface. 213 type Conn struct { 214 deadlineTimer 215 216 wq *waiter.Queue 217 ep tcpip.Endpoint 218 219 // readMu serializes reads and implicitly protects read. 220 // 221 // Lock ordering: 222 // If both readMu and deadlineTimer.mu are to be used in a single 223 // request, readMu must be acquired before deadlineTimer.mu. 224 readMu sync.Mutex 225 226 // read contains bytes that have been read from the endpoint, 227 // but haven't yet been returned. 228 read buffer.View 229 } 230 231 // NewConn creates a new Conn. 232 func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn { 233 c := &Conn{ 234 wq: wq, 235 ep: ep, 236 } 237 c.deadlineTimer.init() 238 return c 239 } 240 241 // Accept implements net.Conn.Accept. 242 func (l *Listener) Accept() (net.Conn, error) { 243 n, wq, err := l.ep.Accept() 244 245 if err == tcpip.ErrWouldBlock { 246 // Create wait queue entry that notifies a channel. 247 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 248 l.wq.EventRegister(&waitEntry, waiter.EventIn) 249 defer l.wq.EventUnregister(&waitEntry) 250 251 for { 252 n, wq, err = l.ep.Accept() 253 254 if err != tcpip.ErrWouldBlock { 255 break 256 } 257 258 select { 259 case <-l.cancel: 260 return nil, errCanceled 261 case <-notifyCh: 262 } 263 } 264 } 265 266 if err != nil { 267 return nil, &net.OpError{ 268 Op: "accept", 269 Net: "tcp", 270 Addr: l.Addr(), 271 Err: errors.New(err.String()), 272 } 273 } 274 275 return NewConn(wq, n), nil 276 } 277 278 type opErrorer interface { 279 newOpError(op string, err error) *net.OpError 280 } 281 282 // commonRead implements the common logic between net.Conn.Read and 283 // net.PacketConn.ReadFrom. 284 func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) { 285 select { 286 case <-deadline: 287 return nil, errorer.newOpError("read", &timeoutError{}) 288 default: 289 } 290 291 read, _, err := ep.Read(addr) 292 293 if err == tcpip.ErrWouldBlock { 294 if dontWait { 295 return nil, errWouldBlock 296 } 297 // Create wait queue entry that notifies a channel. 298 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 299 wq.EventRegister(&waitEntry, waiter.EventIn) 300 defer wq.EventUnregister(&waitEntry) 301 for { 302 read, _, err = ep.Read(addr) 303 if err != tcpip.ErrWouldBlock { 304 break 305 } 306 select { 307 case <-deadline: 308 return nil, errorer.newOpError("read", &timeoutError{}) 309 case <-notifyCh: 310 } 311 } 312 } 313 314 if err == tcpip.ErrClosedForReceive { 315 return nil, io.EOF 316 } 317 318 if err != nil { 319 return nil, errorer.newOpError("read", errors.New(err.String())) 320 } 321 322 return read, nil 323 } 324 325 // Read implements net.Conn.Read. 326 func (c *Conn) Read(b []byte) (int, error) { 327 c.readMu.Lock() 328 defer c.readMu.Unlock() 329 330 deadline := c.readCancel() 331 332 numRead := 0 333 for numRead != len(b) { 334 if len(c.read) == 0 { 335 var err error 336 c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0) 337 if err != nil { 338 if numRead != 0 { 339 return numRead, nil 340 } 341 return numRead, err 342 } 343 } 344 n := copy(b[numRead:], c.read) 345 c.read.TrimFront(n) 346 numRead += n 347 if len(c.read) == 0 { 348 c.read = nil 349 } 350 } 351 return numRead, nil 352 } 353 354 // Write implements net.Conn.Write. 355 func (c *Conn) Write(b []byte) (int, error) { 356 deadline := c.writeCancel() 357 358 // Check if deadlineTimer has already expired. 359 select { 360 case <-deadline: 361 return 0, c.newOpError("write", &timeoutError{}) 362 default: 363 } 364 365 v := buffer.NewViewFromBytes(b) 366 367 // We must handle two soft failure conditions simultaneously: 368 // 1. Write may write nothing and return tcpip.ErrWouldBlock. 369 // If this happens, we need to register for notifications if we have 370 // not already and wait to try again. 371 // 2. Write may write fewer than the full number of bytes and return 372 // without error. In this case we need to try writing the remaining 373 // bytes again. I do not need to register for notifications. 374 // 375 // What is more, these two soft failure conditions can be interspersed. 376 // There is no guarantee that all of the condition #1s will occur before 377 // all of the condition #2s or visa-versa. 378 var ( 379 err *tcpip.Error 380 nbytes int 381 reg bool 382 notifyCh chan struct{} 383 ) 384 for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) { 385 if err == tcpip.ErrWouldBlock { 386 if !reg { 387 // Only register once. 388 reg = true 389 390 // Create wait queue entry that notifies a channel. 391 var waitEntry waiter.Entry 392 waitEntry, notifyCh = waiter.NewChannelEntry(nil) 393 c.wq.EventRegister(&waitEntry, waiter.EventOut) 394 defer c.wq.EventUnregister(&waitEntry) 395 } else { 396 // Don't wait immediately after registration in case more data 397 // became available between when we last checked and when we setup 398 // the notification. 399 select { 400 case <-deadline: 401 return nbytes, c.newOpError("write", &timeoutError{}) 402 case <-notifyCh: 403 } 404 } 405 } 406 407 var n int64 408 var resCh <-chan struct{} 409 n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) 410 nbytes += int(n) 411 v.TrimFront(int(n)) 412 413 if resCh != nil { 414 select { 415 case <-deadline: 416 return nbytes, c.newOpError("write", &timeoutError{}) 417 case <-resCh: 418 } 419 420 n, _, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) 421 nbytes += int(n) 422 v.TrimFront(int(n)) 423 } 424 } 425 426 if err == nil { 427 return nbytes, nil 428 } 429 430 return nbytes, c.newOpError("write", errors.New(err.String())) 431 } 432 433 // Close implements net.Conn.Close. 434 func (c *Conn) Close() error { 435 c.ep.Close() 436 return nil 437 } 438 439 // CloseRead shuts down the reading side of the TCP connection. Most callers 440 // should just use Close. 441 // 442 // A TCP Half-Close is performed the same as CloseRead for *net.TCPConn. 443 func (c *Conn) CloseRead() error { 444 if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil { 445 return c.newOpError("close", errors.New(terr.String())) 446 } 447 return nil 448 } 449 450 // CloseWrite shuts down the writing side of the TCP connection. Most callers 451 // should just use Close. 452 // 453 // A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn. 454 func (c *Conn) CloseWrite() error { 455 if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil { 456 return c.newOpError("close", errors.New(terr.String())) 457 } 458 return nil 459 } 460 461 // LocalAddr implements net.Conn.LocalAddr. 462 func (c *Conn) LocalAddr() net.Addr { 463 a, err := c.ep.GetLocalAddress() 464 if err != nil { 465 return nil 466 } 467 return fullToTCPAddr(a) 468 } 469 470 // RemoteAddr implements net.Conn.RemoteAddr. 471 func (c *Conn) RemoteAddr() net.Addr { 472 a, err := c.ep.GetRemoteAddress() 473 if err != nil { 474 return nil 475 } 476 return fullToTCPAddr(a) 477 } 478 479 func (c *Conn) newOpError(op string, err error) *net.OpError { 480 return &net.OpError{ 481 Op: op, 482 Net: "tcp", 483 Source: c.LocalAddr(), 484 Addr: c.RemoteAddr(), 485 Err: err, 486 } 487 } 488 489 func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr { 490 return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} 491 } 492 493 func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { 494 return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} 495 } 496 497 // DialTCP creates a new TCP Conn connected to the specified address. 498 func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { 499 return DialContextTCP(context.Background(), s, addr, network) 500 } 501 502 // DialContextTCP creates a new TCP Conn connected to the specified address 503 // with the option of adding cancellation and timeouts. 504 func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { 505 // Create TCP endpoint, then connect. 506 var wq waiter.Queue 507 ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) 508 if err != nil { 509 return nil, errors.New(err.String()) 510 } 511 512 // Create wait queue entry that notifies a channel. 513 // 514 // We do this unconditionally as Connect will always return an error. 515 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 516 wq.EventRegister(&waitEntry, waiter.EventOut) 517 defer wq.EventUnregister(&waitEntry) 518 519 select { 520 case <-ctx.Done(): 521 return nil, ctx.Err() 522 default: 523 } 524 525 err = ep.Connect(addr) 526 if err == tcpip.ErrConnectStarted { 527 select { 528 case <-ctx.Done(): 529 ep.Close() 530 return nil, ctx.Err() 531 case <-notifyCh: 532 } 533 534 err = ep.GetSockOpt(tcpip.ErrorOption{}) 535 } 536 if err != nil { 537 ep.Close() 538 return nil, &net.OpError{ 539 Op: "connect", 540 Net: "tcp", 541 Addr: fullToTCPAddr(addr), 542 Err: errors.New(err.String()), 543 } 544 } 545 546 return NewConn(&wq, ep), nil 547 } 548 549 // A PacketConn is a wrapper around a tcpip endpoint that implements 550 // net.PacketConn. 551 type PacketConn struct { 552 deadlineTimer 553 554 stack *stack.Stack 555 ep tcpip.Endpoint 556 wq *waiter.Queue 557 } 558 559 // DialUDP creates a new PacketConn. 560 // 561 // If laddr is nil, a local address is automatically chosen. 562 // 563 // If raddr is nil, the PacketConn is left unconnected. 564 func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { 565 var wq waiter.Queue 566 ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) 567 if err != nil { 568 return nil, errors.New(err.String()) 569 } 570 571 if laddr != nil { 572 if err := ep.Bind(*laddr); err != nil { 573 ep.Close() 574 return nil, &net.OpError{ 575 Op: "bind", 576 Net: "udp", 577 Addr: fullToUDPAddr(*laddr), 578 Err: errors.New(err.String()), 579 } 580 } 581 } 582 583 c := PacketConn{ 584 stack: s, 585 ep: ep, 586 wq: &wq, 587 } 588 c.deadlineTimer.init() 589 590 if raddr != nil { 591 if err := c.ep.Connect(*raddr); err != nil { 592 c.ep.Close() 593 return nil, &net.OpError{ 594 Op: "connect", 595 Net: "udp", 596 Addr: fullToUDPAddr(*raddr), 597 Err: errors.New(err.String()), 598 } 599 } 600 } 601 602 return &c, nil 603 } 604 605 func (c *PacketConn) newOpError(op string, err error) *net.OpError { 606 return c.newRemoteOpError(op, nil, err) 607 } 608 609 func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { 610 return &net.OpError{ 611 Op: op, 612 Net: "udp", 613 Source: c.LocalAddr(), 614 Addr: remote, 615 Err: err, 616 } 617 } 618 619 // RemoteAddr implements net.Conn.RemoteAddr. 620 func (c *PacketConn) RemoteAddr() net.Addr { 621 a, err := c.ep.GetRemoteAddress() 622 if err != nil { 623 return nil 624 } 625 return fullToTCPAddr(a) 626 } 627 628 // Read implements net.Conn.Read 629 func (c *PacketConn) Read(b []byte) (int, error) { 630 bytesRead, _, err := c.ReadFrom(b) 631 return bytesRead, err 632 } 633 634 // ReadFrom implements net.PacketConn.ReadFrom. 635 func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { 636 deadline := c.readCancel() 637 638 var addr tcpip.FullAddress 639 read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false) 640 if err != nil { 641 return 0, nil, err 642 } 643 644 return copy(b, read), fullToUDPAddr(addr), nil 645 } 646 647 func (c *PacketConn) Write(b []byte) (int, error) { 648 return c.WriteTo(b, nil) 649 } 650 651 // WriteTo implements net.PacketConn.WriteTo. 652 func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { 653 deadline := c.writeCancel() 654 655 // Check if deadline has already expired. 656 select { 657 case <-deadline: 658 return 0, c.newRemoteOpError("write", addr, &timeoutError{}) 659 default: 660 } 661 662 // If we're being called by Write, there is no addr 663 wopts := tcpip.WriteOptions{} 664 if addr != nil { 665 ua := addr.(*net.UDPAddr) 666 wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)} 667 } 668 669 v := buffer.NewView(len(b)) 670 copy(v, b) 671 672 n, resCh, err := c.ep.Write(tcpip.SlicePayload(v), wopts) 673 if resCh != nil { 674 select { 675 case <-deadline: 676 return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) 677 case <-resCh: 678 } 679 680 n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) 681 } 682 683 if err == tcpip.ErrWouldBlock { 684 // Create wait queue entry that notifies a channel. 685 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 686 c.wq.EventRegister(&waitEntry, waiter.EventOut) 687 defer c.wq.EventUnregister(&waitEntry) 688 for { 689 select { 690 case <-deadline: 691 return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) 692 case <-notifyCh: 693 } 694 695 n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) 696 if err != tcpip.ErrWouldBlock { 697 break 698 } 699 } 700 } 701 702 if err == nil { 703 return int(n), nil 704 } 705 706 return int(n), c.newRemoteOpError("write", addr, errors.New(err.String())) 707 } 708 709 // Close implements net.PacketConn.Close. 710 func (c *PacketConn) Close() error { 711 c.ep.Close() 712 return nil 713 } 714 715 // LocalAddr implements net.PacketConn.LocalAddr. 716 func (c *PacketConn) LocalAddr() net.Addr { 717 a, err := c.ep.GetLocalAddress() 718 if err != nil { 719 return nil 720 } 721 return fullToUDPAddr(a) 722 }