inet.af/netstack@v0.0.0-20220214151720-7585b01ddccf/tcpip/adapters/gonet/gonet.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package gonet provides a Go net package compatible wrapper for a tcpip stack. 16 package gonet 17 18 import ( 19 "bytes" 20 "context" 21 "errors" 22 "fmt" 23 "io" 24 "net" 25 "time" 26 27 "inet.af/netstack/sync" 28 "inet.af/netstack/tcpip" 29 "inet.af/netstack/tcpip/buffer" 30 "inet.af/netstack/tcpip/stack" 31 "inet.af/netstack/tcpip/transport/tcp" 32 "inet.af/netstack/tcpip/transport/udp" 33 "inet.af/netstack/waiter" 34 ) 35 36 var ( 37 errCanceled = errors.New("operation canceled") 38 errWouldBlock = errors.New("operation would block") 39 ) 40 41 // timeoutError is how the net package reports timeouts. 42 type timeoutError struct{} 43 44 func (e *timeoutError) Error() string { return "i/o timeout" } 45 func (e *timeoutError) Timeout() bool { return true } 46 func (e *timeoutError) Temporary() bool { return true } 47 48 // A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements 49 // net.Listener. 50 type TCPListener struct { 51 stack *stack.Stack 52 ep tcpip.Endpoint 53 wq *waiter.Queue 54 cancel chan struct{} 55 } 56 57 // NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint. 58 func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener { 59 return &TCPListener{ 60 stack: s, 61 ep: ep, 62 wq: wq, 63 cancel: make(chan struct{}), 64 } 65 } 66 67 // ListenTCP creates a new TCPListener. 68 func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) { 69 // Create a TCP endpoint, bind it, then start listening. 70 var wq waiter.Queue 71 ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) 72 if err != nil { 73 return nil, errors.New(err.String()) 74 } 75 76 if err := ep.Bind(addr); err != nil { 77 ep.Close() 78 return nil, &net.OpError{ 79 Op: "bind", 80 Net: "tcp", 81 Addr: fullToTCPAddr(addr), 82 Err: errors.New(err.String()), 83 } 84 } 85 86 if err := ep.Listen(10); err != nil { 87 ep.Close() 88 return nil, &net.OpError{ 89 Op: "listen", 90 Net: "tcp", 91 Addr: fullToTCPAddr(addr), 92 Err: errors.New(err.String()), 93 } 94 } 95 96 return NewTCPListener(s, &wq, ep), nil 97 } 98 99 // Close implements net.Listener.Close. 100 func (l *TCPListener) Close() error { 101 l.ep.Close() 102 return nil 103 } 104 105 // Shutdown stops the HTTP server. 106 func (l *TCPListener) Shutdown() { 107 l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) 108 close(l.cancel) // broadcast cancellation 109 } 110 111 // Addr implements net.Listener.Addr. 112 func (l *TCPListener) Addr() net.Addr { 113 a, err := l.ep.GetLocalAddress() 114 if err != nil { 115 return nil 116 } 117 return fullToTCPAddr(a) 118 } 119 120 type deadlineTimer struct { 121 // mu protects the fields below. 122 mu sync.Mutex 123 124 readTimer *time.Timer 125 readCancelCh chan struct{} 126 writeTimer *time.Timer 127 writeCancelCh chan struct{} 128 } 129 130 func (d *deadlineTimer) init() { 131 d.readCancelCh = make(chan struct{}) 132 d.writeCancelCh = make(chan struct{}) 133 } 134 135 func (d *deadlineTimer) readCancel() <-chan struct{} { 136 d.mu.Lock() 137 c := d.readCancelCh 138 d.mu.Unlock() 139 return c 140 } 141 func (d *deadlineTimer) writeCancel() <-chan struct{} { 142 d.mu.Lock() 143 c := d.writeCancelCh 144 d.mu.Unlock() 145 return c 146 } 147 148 // setDeadline contains the shared logic for setting a deadline. 149 // 150 // cancelCh and timer must be pointers to deadlineTimer.readCancelCh and 151 // deadlineTimer.readTimer or deadlineTimer.writeCancelCh and 152 // deadlineTimer.writeTimer. 153 // 154 // setDeadline must only be called while holding d.mu. 155 func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) { 156 if *timer != nil && !(*timer).Stop() { 157 *cancelCh = make(chan struct{}) 158 } 159 160 // Create a new channel if we already closed it due to setting an already 161 // expired time. We won't race with the timer because we already handled 162 // that above. 163 select { 164 case <-*cancelCh: 165 *cancelCh = make(chan struct{}) 166 default: 167 } 168 169 // "A zero value for t means I/O operations will not time out." 170 // - net.Conn.SetDeadline 171 if t.IsZero() { 172 return 173 } 174 175 timeout := t.Sub(time.Now()) 176 if timeout <= 0 { 177 close(*cancelCh) 178 return 179 } 180 181 // Timer.Stop returns whether or not the AfterFunc has started, but 182 // does not indicate whether or not it has completed. Make a copy of 183 // the cancel channel to prevent this code from racing with the next 184 // call of setDeadline replacing *cancelCh. 185 ch := *cancelCh 186 *timer = time.AfterFunc(timeout, func() { 187 close(ch) 188 }) 189 } 190 191 // SetReadDeadline implements net.Conn.SetReadDeadline and 192 // net.PacketConn.SetReadDeadline. 193 func (d *deadlineTimer) SetReadDeadline(t time.Time) error { 194 d.mu.Lock() 195 d.setDeadline(&d.readCancelCh, &d.readTimer, t) 196 d.mu.Unlock() 197 return nil 198 } 199 200 // SetWriteDeadline implements net.Conn.SetWriteDeadline and 201 // net.PacketConn.SetWriteDeadline. 202 func (d *deadlineTimer) SetWriteDeadline(t time.Time) error { 203 d.mu.Lock() 204 d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) 205 d.mu.Unlock() 206 return nil 207 } 208 209 // SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline. 210 func (d *deadlineTimer) SetDeadline(t time.Time) error { 211 d.mu.Lock() 212 d.setDeadline(&d.readCancelCh, &d.readTimer, t) 213 d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) 214 d.mu.Unlock() 215 return nil 216 } 217 218 // A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn 219 // interface. 220 type TCPConn struct { 221 deadlineTimer 222 223 wq *waiter.Queue 224 ep tcpip.Endpoint 225 226 // readMu serializes reads and implicitly protects read. 227 // 228 // Lock ordering: 229 // If both readMu and deadlineTimer.mu are to be used in a single 230 // request, readMu must be acquired before deadlineTimer.mu. 231 readMu sync.Mutex 232 233 // read contains bytes that have been read from the endpoint, 234 // but haven't yet been returned. 235 read buffer.View 236 } 237 238 // NewTCPConn creates a new TCPConn. 239 func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { 240 c := &TCPConn{ 241 wq: wq, 242 ep: ep, 243 } 244 c.deadlineTimer.init() 245 return c 246 } 247 248 // Accept implements net.Conn.Accept. 249 func (l *TCPListener) Accept() (net.Conn, error) { 250 n, wq, err := l.ep.Accept(nil) 251 252 if _, ok := err.(*tcpip.ErrWouldBlock); ok { 253 // Create wait queue entry that notifies a channel. 254 waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) 255 l.wq.EventRegister(&waitEntry) 256 defer l.wq.EventUnregister(&waitEntry) 257 258 for { 259 n, wq, err = l.ep.Accept(nil) 260 261 if _, ok := err.(*tcpip.ErrWouldBlock); !ok { 262 break 263 } 264 265 select { 266 case <-l.cancel: 267 return nil, errCanceled 268 case <-notifyCh: 269 } 270 } 271 } 272 273 if err != nil { 274 return nil, &net.OpError{ 275 Op: "accept", 276 Net: "tcp", 277 Addr: l.Addr(), 278 Err: errors.New(err.String()), 279 } 280 } 281 282 return NewTCPConn(wq, n), nil 283 } 284 285 type opErrorer interface { 286 newOpError(op string, err error) *net.OpError 287 } 288 289 // commonRead implements the common logic between net.Conn.Read and 290 // net.PacketConn.ReadFrom. 291 func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) { 292 select { 293 case <-deadline: 294 return 0, errorer.newOpError("read", &timeoutError{}) 295 default: 296 } 297 298 w := tcpip.SliceWriter(b) 299 opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} 300 res, err := ep.Read(&w, opts) 301 302 if _, ok := err.(*tcpip.ErrWouldBlock); ok { 303 // Create wait queue entry that notifies a channel. 304 waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) 305 wq.EventRegister(&waitEntry) 306 defer wq.EventUnregister(&waitEntry) 307 for { 308 res, err = ep.Read(&w, opts) 309 if _, ok := err.(*tcpip.ErrWouldBlock); !ok { 310 break 311 } 312 select { 313 case <-deadline: 314 return 0, errorer.newOpError("read", &timeoutError{}) 315 case <-notifyCh: 316 } 317 } 318 } 319 320 if _, ok := err.(*tcpip.ErrClosedForReceive); ok { 321 return 0, io.EOF 322 } 323 324 if err != nil { 325 return 0, errorer.newOpError("read", errors.New(err.String())) 326 } 327 328 if addr != nil { 329 *addr = res.RemoteAddr 330 } 331 return res.Count, nil 332 } 333 334 // Read implements net.Conn.Read. 335 func (c *TCPConn) Read(b []byte) (int, error) { 336 c.readMu.Lock() 337 defer c.readMu.Unlock() 338 339 deadline := c.readCancel() 340 341 n, err := commonRead(b, c.ep, c.wq, deadline, nil, c) 342 if n != 0 { 343 c.ep.ModerateRecvBuf(n) 344 } 345 return n, err 346 } 347 348 // Write implements net.Conn.Write. 349 func (c *TCPConn) Write(b []byte) (int, error) { 350 deadline := c.writeCancel() 351 352 // Check if deadlineTimer has already expired. 353 select { 354 case <-deadline: 355 return 0, c.newOpError("write", &timeoutError{}) 356 default: 357 } 358 359 // We must handle two soft failure conditions simultaneously: 360 // 1. Write may write nothing and return *tcpip.ErrWouldBlock. 361 // If this happens, we need to register for notifications if we have 362 // not already and wait to try again. 363 // 2. Write may write fewer than the full number of bytes and return 364 // without error. In this case we need to try writing the remaining 365 // bytes again. I do not need to register for notifications. 366 // 367 // What is more, these two soft failure conditions can be interspersed. 368 // There is no guarantee that all of the condition #1s will occur before 369 // all of the condition #2s or visa-versa. 370 var ( 371 r bytes.Reader 372 nbytes int 373 entry waiter.Entry 374 ch <-chan struct{} 375 ) 376 for nbytes != len(b) { 377 r.Reset(b[nbytes:]) 378 n, err := c.ep.Write(&r, tcpip.WriteOptions{}) 379 nbytes += int(n) 380 switch err.(type) { 381 case nil: 382 case *tcpip.ErrWouldBlock: 383 if ch == nil { 384 entry, ch = waiter.NewChannelEntry(waiter.WritableEvents) 385 c.wq.EventRegister(&entry) 386 defer c.wq.EventUnregister(&entry) 387 } else { 388 // Don't wait immediately after registration in case more data 389 // became available between when we last checked and when we setup 390 // the notification. 391 select { 392 case <-deadline: 393 return nbytes, c.newOpError("write", &timeoutError{}) 394 case <-ch: 395 continue 396 } 397 } 398 default: 399 return nbytes, c.newOpError("write", errors.New(err.String())) 400 } 401 } 402 return nbytes, nil 403 } 404 405 // Close implements net.Conn.Close. 406 func (c *TCPConn) Close() error { 407 c.ep.Close() 408 return nil 409 } 410 411 // CloseRead shuts down the reading side of the TCP connection. Most callers 412 // should just use Close. 413 // 414 // A TCP Half-Close is performed the same as CloseRead for *net.TCPConn. 415 func (c *TCPConn) CloseRead() error { 416 if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil { 417 return c.newOpError("close", errors.New(terr.String())) 418 } 419 return nil 420 } 421 422 // CloseWrite shuts down the writing side of the TCP connection. Most callers 423 // should just use Close. 424 // 425 // A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn. 426 func (c *TCPConn) CloseWrite() error { 427 if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil { 428 return c.newOpError("close", errors.New(terr.String())) 429 } 430 return nil 431 } 432 433 // LocalAddr implements net.Conn.LocalAddr. 434 func (c *TCPConn) LocalAddr() net.Addr { 435 a, err := c.ep.GetLocalAddress() 436 if err != nil { 437 return nil 438 } 439 return fullToTCPAddr(a) 440 } 441 442 // RemoteAddr implements net.Conn.RemoteAddr. 443 func (c *TCPConn) RemoteAddr() net.Addr { 444 a, err := c.ep.GetRemoteAddress() 445 if err != nil { 446 return nil 447 } 448 return fullToTCPAddr(a) 449 } 450 451 func (c *TCPConn) newOpError(op string, err error) *net.OpError { 452 return &net.OpError{ 453 Op: op, 454 Net: "tcp", 455 Source: c.LocalAddr(), 456 Addr: c.RemoteAddr(), 457 Err: err, 458 } 459 } 460 461 func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr { 462 return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} 463 } 464 465 func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { 466 return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} 467 } 468 469 // DialTCP creates a new TCPConn connected to the specified address. 470 func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { 471 return DialContextTCP(context.Background(), s, addr, network) 472 } 473 474 // DialTCPWithBind creates a new TCPConn connected to the specified 475 // remoteAddress with its local address bound to localAddr. 476 func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { 477 // Create TCP endpoint, then connect. 478 var wq waiter.Queue 479 ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) 480 if err != nil { 481 return nil, errors.New(err.String()) 482 } 483 484 // Create wait queue entry that notifies a channel. 485 // 486 // We do this unconditionally as Connect will always return an error. 487 waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents) 488 wq.EventRegister(&waitEntry) 489 defer wq.EventUnregister(&waitEntry) 490 491 select { 492 case <-ctx.Done(): 493 return nil, ctx.Err() 494 default: 495 } 496 497 // Bind before connect if requested. 498 if localAddr != (tcpip.FullAddress{}) { 499 if err = ep.Bind(localAddr); err != nil { 500 return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err) 501 } 502 } 503 504 err = ep.Connect(remoteAddr) 505 if _, ok := err.(*tcpip.ErrConnectStarted); ok { 506 select { 507 case <-ctx.Done(): 508 ep.Close() 509 return nil, ctx.Err() 510 case <-notifyCh: 511 } 512 513 err = ep.LastError() 514 } 515 if err != nil { 516 ep.Close() 517 return nil, &net.OpError{ 518 Op: "connect", 519 Net: "tcp", 520 Addr: fullToTCPAddr(remoteAddr), 521 Err: errors.New(err.String()), 522 } 523 } 524 525 return NewTCPConn(&wq, ep), nil 526 } 527 528 // DialContextTCP creates a new TCPConn connected to the specified address 529 // with the option of adding cancellation and timeouts. 530 func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { 531 return DialTCPWithBind(ctx, s, tcpip.FullAddress{} /* localAddr */, addr /* remoteAddr */, network) 532 } 533 534 // A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements 535 // net.Conn and net.PacketConn. 536 type UDPConn struct { 537 deadlineTimer 538 539 stack *stack.Stack 540 ep tcpip.Endpoint 541 wq *waiter.Queue 542 } 543 544 // NewUDPConn creates a new UDPConn. 545 func NewUDPConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn { 546 c := &UDPConn{ 547 stack: s, 548 ep: ep, 549 wq: wq, 550 } 551 c.deadlineTimer.init() 552 return c 553 } 554 555 // DialUDP creates a new UDPConn. 556 // 557 // If laddr is nil, a local address is automatically chosen. 558 // 559 // If raddr is nil, the UDPConn is left unconnected. 560 func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) { 561 var wq waiter.Queue 562 ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) 563 if err != nil { 564 return nil, errors.New(err.String()) 565 } 566 567 if laddr != nil { 568 if err := ep.Bind(*laddr); err != nil { 569 ep.Close() 570 return nil, &net.OpError{ 571 Op: "bind", 572 Net: "udp", 573 Addr: fullToUDPAddr(*laddr), 574 Err: errors.New(err.String()), 575 } 576 } 577 } 578 579 c := NewUDPConn(s, &wq, ep) 580 581 if raddr != nil { 582 if err := c.ep.Connect(*raddr); err != nil { 583 c.ep.Close() 584 return nil, &net.OpError{ 585 Op: "connect", 586 Net: "udp", 587 Addr: fullToUDPAddr(*raddr), 588 Err: errors.New(err.String()), 589 } 590 } 591 } 592 593 return c, nil 594 } 595 596 func (c *UDPConn) newOpError(op string, err error) *net.OpError { 597 return c.newRemoteOpError(op, nil, err) 598 } 599 600 func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { 601 return &net.OpError{ 602 Op: op, 603 Net: "udp", 604 Source: c.LocalAddr(), 605 Addr: remote, 606 Err: err, 607 } 608 } 609 610 // RemoteAddr implements net.Conn.RemoteAddr. 611 func (c *UDPConn) RemoteAddr() net.Addr { 612 a, err := c.ep.GetRemoteAddress() 613 if err != nil { 614 return nil 615 } 616 return fullToUDPAddr(a) 617 } 618 619 // Read implements net.Conn.Read 620 func (c *UDPConn) Read(b []byte) (int, error) { 621 bytesRead, _, err := c.ReadFrom(b) 622 return bytesRead, err 623 } 624 625 // ReadFrom implements net.PacketConn.ReadFrom. 626 func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) { 627 deadline := c.readCancel() 628 629 var addr tcpip.FullAddress 630 n, err := commonRead(b, c.ep, c.wq, deadline, &addr, c) 631 if err != nil { 632 return 0, nil, err 633 } 634 return n, fullToUDPAddr(addr), nil 635 } 636 637 func (c *UDPConn) Write(b []byte) (int, error) { 638 return c.WriteTo(b, nil) 639 } 640 641 // WriteTo implements net.PacketConn.WriteTo. 642 func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { 643 deadline := c.writeCancel() 644 645 // Check if deadline has already expired. 646 select { 647 case <-deadline: 648 return 0, c.newRemoteOpError("write", addr, &timeoutError{}) 649 default: 650 } 651 652 // If we're being called by Write, there is no addr 653 writeOptions := tcpip.WriteOptions{} 654 if addr != nil { 655 ua := addr.(*net.UDPAddr) 656 writeOptions.To = &tcpip.FullAddress{ 657 Addr: tcpip.Address(ua.IP), 658 Port: uint16(ua.Port), 659 } 660 } 661 662 var r bytes.Reader 663 r.Reset(b) 664 n, err := c.ep.Write(&r, writeOptions) 665 if _, ok := err.(*tcpip.ErrWouldBlock); ok { 666 // Create wait queue entry that notifies a channel. 667 waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents) 668 c.wq.EventRegister(&waitEntry) 669 defer c.wq.EventUnregister(&waitEntry) 670 for { 671 select { 672 case <-deadline: 673 return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) 674 case <-notifyCh: 675 } 676 677 n, err = c.ep.Write(&r, writeOptions) 678 if _, ok := err.(*tcpip.ErrWouldBlock); !ok { 679 break 680 } 681 } 682 } 683 684 if err == nil { 685 return int(n), nil 686 } 687 688 return int(n), c.newRemoteOpError("write", addr, errors.New(err.String())) 689 } 690 691 // Close implements net.PacketConn.Close. 692 func (c *UDPConn) Close() error { 693 c.ep.Close() 694 return nil 695 } 696 697 // LocalAddr implements net.PacketConn.LocalAddr. 698 func (c *UDPConn) LocalAddr() net.Addr { 699 a, err := c.ep.GetLocalAddress() 700 if err != nil { 701 return nil 702 } 703 return fullToUDPAddr(a) 704 }