github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/adapters/gonet/gonet.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package gonet provides a Go net package compatible wrapper for a tcpip stack. 16 package gonet 17 18 import ( 19 "bytes" 20 "context" 21 "errors" 22 "io" 23 "net" 24 "time" 25 26 "github.com/SagerNet/gvisor/pkg/sync" 27 "github.com/SagerNet/gvisor/pkg/tcpip" 28 "github.com/SagerNet/gvisor/pkg/tcpip/buffer" 29 "github.com/SagerNet/gvisor/pkg/tcpip/stack" 30 "github.com/SagerNet/gvisor/pkg/tcpip/transport/tcp" 31 "github.com/SagerNet/gvisor/pkg/tcpip/transport/udp" 32 "github.com/SagerNet/gvisor/pkg/waiter" 33 ) 34 35 var ( 36 errCanceled = errors.New("operation canceled") 37 errWouldBlock = errors.New("operation would block") 38 ) 39 40 // timeoutError is how the net package reports timeouts. 41 type timeoutError struct{} 42 43 func (e *timeoutError) Error() string { return "i/o timeout" } 44 func (e *timeoutError) Timeout() bool { return true } 45 func (e *timeoutError) Temporary() bool { return true } 46 47 // A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements 48 // net.Listener. 49 type TCPListener struct { 50 stack *stack.Stack 51 ep tcpip.Endpoint 52 wq *waiter.Queue 53 cancel chan struct{} 54 } 55 56 // NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint. 57 func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener { 58 return &TCPListener{ 59 stack: s, 60 ep: ep, 61 wq: wq, 62 cancel: make(chan struct{}), 63 } 64 } 65 66 // ListenTCP creates a new TCPListener. 67 func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) { 68 // Create a TCP endpoint, bind it, then start listening. 69 var wq waiter.Queue 70 ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) 71 if err != nil { 72 return nil, errors.New(err.String()) 73 } 74 75 if err := ep.Bind(addr); err != nil { 76 ep.Close() 77 return nil, &net.OpError{ 78 Op: "bind", 79 Net: "tcp", 80 Addr: fullToTCPAddr(addr), 81 Err: errors.New(err.String()), 82 } 83 } 84 85 if err := ep.Listen(10); err != nil { 86 ep.Close() 87 return nil, &net.OpError{ 88 Op: "listen", 89 Net: "tcp", 90 Addr: fullToTCPAddr(addr), 91 Err: errors.New(err.String()), 92 } 93 } 94 95 return NewTCPListener(s, &wq, ep), nil 96 } 97 98 // Close implements net.Listener.Close. 99 func (l *TCPListener) Close() error { 100 l.ep.Close() 101 return nil 102 } 103 104 // Shutdown stops the HTTP server. 105 func (l *TCPListener) Shutdown() { 106 l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) 107 close(l.cancel) // broadcast cancellation 108 } 109 110 // Addr implements net.Listener.Addr. 111 func (l *TCPListener) Addr() net.Addr { 112 a, err := l.ep.GetLocalAddress() 113 if err != nil { 114 return nil 115 } 116 return fullToTCPAddr(a) 117 } 118 119 type deadlineTimer struct { 120 // mu protects the fields below. 121 mu sync.Mutex 122 123 readTimer *time.Timer 124 readCancelCh chan struct{} 125 writeTimer *time.Timer 126 writeCancelCh chan struct{} 127 } 128 129 func (d *deadlineTimer) init() { 130 d.readCancelCh = make(chan struct{}) 131 d.writeCancelCh = make(chan struct{}) 132 } 133 134 func (d *deadlineTimer) readCancel() <-chan struct{} { 135 d.mu.Lock() 136 c := d.readCancelCh 137 d.mu.Unlock() 138 return c 139 } 140 func (d *deadlineTimer) writeCancel() <-chan struct{} { 141 d.mu.Lock() 142 c := d.writeCancelCh 143 d.mu.Unlock() 144 return c 145 } 146 147 // setDeadline contains the shared logic for setting a deadline. 148 // 149 // cancelCh and timer must be pointers to deadlineTimer.readCancelCh and 150 // deadlineTimer.readTimer or deadlineTimer.writeCancelCh and 151 // deadlineTimer.writeTimer. 152 // 153 // setDeadline must only be called while holding d.mu. 154 func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) { 155 if *timer != nil && !(*timer).Stop() { 156 *cancelCh = make(chan struct{}) 157 } 158 159 // Create a new channel if we already closed it due to setting an already 160 // expired time. We won't race with the timer because we already handled 161 // that above. 162 select { 163 case <-*cancelCh: 164 *cancelCh = make(chan struct{}) 165 default: 166 } 167 168 // "A zero value for t means I/O operations will not time out." 169 // - net.Conn.SetDeadline 170 if t.IsZero() { 171 return 172 } 173 174 timeout := t.Sub(time.Now()) 175 if timeout <= 0 { 176 close(*cancelCh) 177 return 178 } 179 180 // Timer.Stop returns whether or not the AfterFunc has started, but 181 // does not indicate whether or not it has completed. Make a copy of 182 // the cancel channel to prevent this code from racing with the next 183 // call of setDeadline replacing *cancelCh. 184 ch := *cancelCh 185 *timer = time.AfterFunc(timeout, func() { 186 close(ch) 187 }) 188 } 189 190 // SetReadDeadline implements net.Conn.SetReadDeadline and 191 // net.PacketConn.SetReadDeadline. 192 func (d *deadlineTimer) SetReadDeadline(t time.Time) error { 193 d.mu.Lock() 194 d.setDeadline(&d.readCancelCh, &d.readTimer, t) 195 d.mu.Unlock() 196 return nil 197 } 198 199 // SetWriteDeadline implements net.Conn.SetWriteDeadline and 200 // net.PacketConn.SetWriteDeadline. 201 func (d *deadlineTimer) SetWriteDeadline(t time.Time) error { 202 d.mu.Lock() 203 d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) 204 d.mu.Unlock() 205 return nil 206 } 207 208 // SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline. 209 func (d *deadlineTimer) SetDeadline(t time.Time) error { 210 d.mu.Lock() 211 d.setDeadline(&d.readCancelCh, &d.readTimer, t) 212 d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) 213 d.mu.Unlock() 214 return nil 215 } 216 217 // A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn 218 // interface. 219 type TCPConn struct { 220 deadlineTimer 221 222 wq *waiter.Queue 223 ep tcpip.Endpoint 224 225 // readMu serializes reads and implicitly protects read. 226 // 227 // Lock ordering: 228 // If both readMu and deadlineTimer.mu are to be used in a single 229 // request, readMu must be acquired before deadlineTimer.mu. 230 readMu sync.Mutex 231 232 // read contains bytes that have been read from the endpoint, 233 // but haven't yet been returned. 234 read buffer.View 235 } 236 237 // NewTCPConn creates a new TCPConn. 238 func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { 239 c := &TCPConn{ 240 wq: wq, 241 ep: ep, 242 } 243 c.deadlineTimer.init() 244 return c 245 } 246 247 // Accept implements net.Conn.Accept. 248 func (l *TCPListener) Accept() (net.Conn, error) { 249 n, wq, err := l.ep.Accept(nil) 250 251 if _, ok := err.(*tcpip.ErrWouldBlock); ok { 252 // Create wait queue entry that notifies a channel. 253 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 254 l.wq.EventRegister(&waitEntry, waiter.ReadableEvents) 255 defer l.wq.EventUnregister(&waitEntry) 256 257 for { 258 n, wq, err = l.ep.Accept(nil) 259 260 if _, ok := err.(*tcpip.ErrWouldBlock); !ok { 261 break 262 } 263 264 select { 265 case <-l.cancel: 266 return nil, errCanceled 267 case <-notifyCh: 268 } 269 } 270 } 271 272 if err != nil { 273 return nil, &net.OpError{ 274 Op: "accept", 275 Net: "tcp", 276 Addr: l.Addr(), 277 Err: errors.New(err.String()), 278 } 279 } 280 281 return NewTCPConn(wq, n), nil 282 } 283 284 type opErrorer interface { 285 newOpError(op string, err error) *net.OpError 286 } 287 288 // commonRead implements the common logic between net.Conn.Read and 289 // net.PacketConn.ReadFrom. 290 func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) { 291 select { 292 case <-deadline: 293 return 0, errorer.newOpError("read", &timeoutError{}) 294 default: 295 } 296 297 w := tcpip.SliceWriter(b) 298 opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} 299 res, err := ep.Read(&w, opts) 300 301 if _, ok := err.(*tcpip.ErrWouldBlock); ok { 302 // Create wait queue entry that notifies a channel. 303 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 304 wq.EventRegister(&waitEntry, waiter.ReadableEvents) 305 defer wq.EventUnregister(&waitEntry) 306 for { 307 res, err = ep.Read(&w, opts) 308 if _, ok := err.(*tcpip.ErrWouldBlock); !ok { 309 break 310 } 311 select { 312 case <-deadline: 313 return 0, errorer.newOpError("read", &timeoutError{}) 314 case <-notifyCh: 315 } 316 } 317 } 318 319 if _, ok := err.(*tcpip.ErrClosedForReceive); ok { 320 return 0, io.EOF 321 } 322 323 if err != nil { 324 return 0, errorer.newOpError("read", errors.New(err.String())) 325 } 326 327 if addr != nil { 328 *addr = res.RemoteAddr 329 } 330 return res.Count, nil 331 } 332 333 // Read implements net.Conn.Read. 334 func (c *TCPConn) Read(b []byte) (int, error) { 335 c.readMu.Lock() 336 defer c.readMu.Unlock() 337 338 deadline := c.readCancel() 339 340 n, err := commonRead(b, c.ep, c.wq, deadline, nil, c) 341 if n != 0 { 342 c.ep.ModerateRecvBuf(n) 343 } 344 return n, err 345 } 346 347 // Write implements net.Conn.Write. 348 func (c *TCPConn) Write(b []byte) (int, error) { 349 deadline := c.writeCancel() 350 351 // Check if deadlineTimer has already expired. 352 select { 353 case <-deadline: 354 return 0, c.newOpError("write", &timeoutError{}) 355 default: 356 } 357 358 // We must handle two soft failure conditions simultaneously: 359 // 1. Write may write nothing and return *tcpip.ErrWouldBlock. 360 // If this happens, we need to register for notifications if we have 361 // not already and wait to try again. 362 // 2. Write may write fewer than the full number of bytes and return 363 // without error. In this case we need to try writing the remaining 364 // bytes again. I do not need to register for notifications. 365 // 366 // What is more, these two soft failure conditions can be interspersed. 367 // There is no guarantee that all of the condition #1s will occur before 368 // all of the condition #2s or visa-versa. 369 var ( 370 r bytes.Reader 371 nbytes int 372 entry waiter.Entry 373 ch <-chan struct{} 374 ) 375 for nbytes != len(b) { 376 r.Reset(b[nbytes:]) 377 n, err := c.ep.Write(&r, tcpip.WriteOptions{}) 378 nbytes += int(n) 379 switch err.(type) { 380 case nil: 381 case *tcpip.ErrWouldBlock: 382 if ch == nil { 383 entry, ch = waiter.NewChannelEntry(nil) 384 385 c.wq.EventRegister(&entry, waiter.WritableEvents) 386 defer c.wq.EventUnregister(&entry) 387 } else { 388 // Don't wait immediately after registration in case more data 389 // became available between when we last checked and when we setup 390 // the notification. 391 select { 392 case <-deadline: 393 return nbytes, c.newOpError("write", &timeoutError{}) 394 case <-ch: 395 continue 396 } 397 } 398 default: 399 return nbytes, c.newOpError("write", errors.New(err.String())) 400 } 401 } 402 return nbytes, nil 403 } 404 405 // Close implements net.Conn.Close. 406 func (c *TCPConn) Close() error { 407 c.ep.Close() 408 return nil 409 } 410 411 // CloseRead shuts down the reading side of the TCP connection. Most callers 412 // should just use Close. 413 // 414 // A TCP Half-Close is performed the same as CloseRead for *net.TCPConn. 415 func (c *TCPConn) CloseRead() error { 416 if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil { 417 return c.newOpError("close", errors.New(terr.String())) 418 } 419 return nil 420 } 421 422 // CloseWrite shuts down the writing side of the TCP connection. Most callers 423 // should just use Close. 424 // 425 // A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn. 426 func (c *TCPConn) CloseWrite() error { 427 if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil { 428 return c.newOpError("close", errors.New(terr.String())) 429 } 430 return nil 431 } 432 433 // LocalAddr implements net.Conn.LocalAddr. 434 func (c *TCPConn) LocalAddr() net.Addr { 435 a, err := c.ep.GetLocalAddress() 436 if err != nil { 437 return nil 438 } 439 return fullToTCPAddr(a) 440 } 441 442 // RemoteAddr implements net.Conn.RemoteAddr. 443 func (c *TCPConn) RemoteAddr() net.Addr { 444 a, err := c.ep.GetRemoteAddress() 445 if err != nil { 446 return nil 447 } 448 return fullToTCPAddr(a) 449 } 450 451 func (c *TCPConn) newOpError(op string, err error) *net.OpError { 452 return &net.OpError{ 453 Op: op, 454 Net: "tcp", 455 Source: c.LocalAddr(), 456 Addr: c.RemoteAddr(), 457 Err: err, 458 } 459 } 460 461 func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr { 462 return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} 463 } 464 465 func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { 466 return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} 467 } 468 469 // DialTCP creates a new TCPConn connected to the specified address. 470 func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { 471 return DialContextTCP(context.Background(), s, addr, network) 472 } 473 474 // DialContextTCP creates a new TCPConn connected to the specified address 475 // with the option of adding cancellation and timeouts. 476 func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { 477 // Create TCP endpoint, then connect. 478 var wq waiter.Queue 479 ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) 480 if err != nil { 481 return nil, errors.New(err.String()) 482 } 483 484 // Create wait queue entry that notifies a channel. 485 // 486 // We do this unconditionally as Connect will always return an error. 487 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 488 wq.EventRegister(&waitEntry, waiter.WritableEvents) 489 defer wq.EventUnregister(&waitEntry) 490 491 select { 492 case <-ctx.Done(): 493 return nil, ctx.Err() 494 default: 495 } 496 497 err = ep.Connect(addr) 498 if _, ok := err.(*tcpip.ErrConnectStarted); ok { 499 select { 500 case <-ctx.Done(): 501 ep.Close() 502 return nil, ctx.Err() 503 case <-notifyCh: 504 } 505 506 err = ep.LastError() 507 } 508 if err != nil { 509 ep.Close() 510 return nil, &net.OpError{ 511 Op: "connect", 512 Net: "tcp", 513 Addr: fullToTCPAddr(addr), 514 Err: errors.New(err.String()), 515 } 516 } 517 518 return NewTCPConn(&wq, ep), nil 519 } 520 521 // A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements 522 // net.Conn and net.PacketConn. 523 type UDPConn struct { 524 deadlineTimer 525 526 stack *stack.Stack 527 ep tcpip.Endpoint 528 wq *waiter.Queue 529 } 530 531 // NewUDPConn creates a new UDPConn. 532 func NewUDPConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn { 533 c := &UDPConn{ 534 stack: s, 535 ep: ep, 536 wq: wq, 537 } 538 c.deadlineTimer.init() 539 return c 540 } 541 542 // DialUDP creates a new UDPConn. 543 // 544 // If laddr is nil, a local address is automatically chosen. 545 // 546 // If raddr is nil, the UDPConn is left unconnected. 547 func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) { 548 var wq waiter.Queue 549 ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) 550 if err != nil { 551 return nil, errors.New(err.String()) 552 } 553 554 if laddr != nil { 555 if err := ep.Bind(*laddr); err != nil { 556 ep.Close() 557 return nil, &net.OpError{ 558 Op: "bind", 559 Net: "udp", 560 Addr: fullToUDPAddr(*laddr), 561 Err: errors.New(err.String()), 562 } 563 } 564 } 565 566 c := NewUDPConn(s, &wq, ep) 567 568 if raddr != nil { 569 if err := c.ep.Connect(*raddr); err != nil { 570 c.ep.Close() 571 return nil, &net.OpError{ 572 Op: "connect", 573 Net: "udp", 574 Addr: fullToUDPAddr(*raddr), 575 Err: errors.New(err.String()), 576 } 577 } 578 } 579 580 return c, nil 581 } 582 583 func (c *UDPConn) newOpError(op string, err error) *net.OpError { 584 return c.newRemoteOpError(op, nil, err) 585 } 586 587 func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { 588 return &net.OpError{ 589 Op: op, 590 Net: "udp", 591 Source: c.LocalAddr(), 592 Addr: remote, 593 Err: err, 594 } 595 } 596 597 // RemoteAddr implements net.Conn.RemoteAddr. 598 func (c *UDPConn) RemoteAddr() net.Addr { 599 a, err := c.ep.GetRemoteAddress() 600 if err != nil { 601 return nil 602 } 603 return fullToUDPAddr(a) 604 } 605 606 // Read implements net.Conn.Read 607 func (c *UDPConn) Read(b []byte) (int, error) { 608 bytesRead, _, err := c.ReadFrom(b) 609 return bytesRead, err 610 } 611 612 // ReadFrom implements net.PacketConn.ReadFrom. 613 func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) { 614 deadline := c.readCancel() 615 616 var addr tcpip.FullAddress 617 n, err := commonRead(b, c.ep, c.wq, deadline, &addr, c) 618 if err != nil { 619 return 0, nil, err 620 } 621 return n, fullToUDPAddr(addr), nil 622 } 623 624 func (c *UDPConn) Write(b []byte) (int, error) { 625 return c.WriteTo(b, nil) 626 } 627 628 // WriteTo implements net.PacketConn.WriteTo. 629 func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { 630 deadline := c.writeCancel() 631 632 // Check if deadline has already expired. 633 select { 634 case <-deadline: 635 return 0, c.newRemoteOpError("write", addr, &timeoutError{}) 636 default: 637 } 638 639 // If we're being called by Write, there is no addr 640 writeOptions := tcpip.WriteOptions{} 641 if addr != nil { 642 ua := addr.(*net.UDPAddr) 643 writeOptions.To = &tcpip.FullAddress{ 644 Addr: tcpip.Address(ua.IP), 645 Port: uint16(ua.Port), 646 } 647 } 648 649 var r bytes.Reader 650 r.Reset(b) 651 n, err := c.ep.Write(&r, writeOptions) 652 if _, ok := err.(*tcpip.ErrWouldBlock); ok { 653 // Create wait queue entry that notifies a channel. 654 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 655 c.wq.EventRegister(&waitEntry, waiter.WritableEvents) 656 defer c.wq.EventUnregister(&waitEntry) 657 for { 658 select { 659 case <-deadline: 660 return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) 661 case <-notifyCh: 662 } 663 664 n, err = c.ep.Write(&r, writeOptions) 665 if _, ok := err.(*tcpip.ErrWouldBlock); !ok { 666 break 667 } 668 } 669 } 670 671 if err == nil { 672 return int(n), nil 673 } 674 675 return int(n), c.newRemoteOpError("write", addr, errors.New(err.String())) 676 } 677 678 // Close implements net.PacketConn.Close. 679 func (c *UDPConn) Close() error { 680 c.ep.Close() 681 return nil 682 } 683 684 // LocalAddr implements net.PacketConn.LocalAddr. 685 func (c *UDPConn) LocalAddr() net.Addr { 686 a, err := c.ep.GetLocalAddress() 687 if err != nil { 688 return nil 689 } 690 return fullToUDPAddr(a) 691 }