github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/tcpip/transport/tcp/endpoint_state.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  
    21  	"github.com/sagernet/gvisor/pkg/atomicbitops"
    22  	"github.com/sagernet/gvisor/pkg/sync"
    23  	"github.com/sagernet/gvisor/pkg/tcpip"
    24  	"github.com/sagernet/gvisor/pkg/tcpip/header"
    25  	"github.com/sagernet/gvisor/pkg/tcpip/ports"
    26  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    27  )
    28  
    29  // beforeSave is invoked by stateify.
    30  func (e *Endpoint) beforeSave() {
    31  	// Stop incoming packets.
    32  	e.segmentQueue.freeze()
    33  
    34  	e.mu.Lock()
    35  	defer e.mu.Unlock()
    36  
    37  	epState := e.EndpointState()
    38  	switch {
    39  	case epState == StateInitial || epState == StateBound:
    40  	case epState.connected() || epState.handshake():
    41  		if !e.route.HasSaveRestoreCapability() {
    42  			if !e.route.HasDisconnectOkCapability() {
    43  				panic(&tcpip.ErrSaveRejection{
    44  					Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort),
    45  				})
    46  			}
    47  			e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
    48  			e.mu.Unlock()
    49  			e.Close()
    50  			e.mu.Lock()
    51  		}
    52  		fallthrough
    53  	case epState == StateListen:
    54  		// Nothing to do.
    55  	case epState.closed():
    56  		// Nothing to do.
    57  	default:
    58  		panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
    59  	}
    60  
    61  	e.stack.RegisterResumableEndpoint(e)
    62  }
    63  
    64  // saveEndpoints is invoked by stateify.
    65  func (a *acceptQueue) saveEndpoints() []*Endpoint {
    66  	acceptedEndpoints := make([]*Endpoint, a.endpoints.Len())
    67  	for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
    68  		acceptedEndpoints[i] = e.Value.(*Endpoint)
    69  	}
    70  	return acceptedEndpoints
    71  }
    72  
    73  // loadEndpoints is invoked by stateify.
    74  func (a *acceptQueue) loadEndpoints(_ context.Context, acceptedEndpoints []*Endpoint) {
    75  	for _, ep := range acceptedEndpoints {
    76  		a.endpoints.PushBack(ep)
    77  	}
    78  }
    79  
    80  // saveState is invoked by stateify.
    81  func (e *Endpoint) saveState() EndpointState {
    82  	return e.EndpointState()
    83  }
    84  
    85  // Endpoint loading must be done in the following ordering by their state, to
    86  // avoid dangling connecting w/o listening peer, and to avoid conflicts in port
    87  // reservation.
    88  var connectedLoading sync.WaitGroup
    89  var listenLoading sync.WaitGroup
    90  var connectingLoading sync.WaitGroup
    91  
    92  // Bound endpoint loading happens last.
    93  
    94  // loadState is invoked by stateify.
    95  func (e *Endpoint) loadState(_ context.Context, epState EndpointState) {
    96  	// This is to ensure that the loading wait groups include all applicable
    97  	// endpoints before any asynchronous calls to the Wait() methods.
    98  	// For restore purposes we treat TimeWait like a connected endpoint.
    99  	if epState.connected() || epState == StateTimeWait {
   100  		connectedLoading.Add(1)
   101  	}
   102  	switch {
   103  	case epState == StateListen:
   104  		listenLoading.Add(1)
   105  	case epState.connecting():
   106  		connectingLoading.Add(1)
   107  	}
   108  	// Directly update the state here rather than using e.setEndpointState
   109  	// as the endpoint is still being loaded and the stack reference is not
   110  	// yet initialized.
   111  	e.state.Store(uint32(epState))
   112  }
   113  
   114  // afterLoad is invoked by stateify.
   115  func (e *Endpoint) afterLoad(ctx context.Context) {
   116  	// RacyLoad() can be used because we are initializing e.
   117  	e.origEndpointState = e.state.RacyLoad()
   118  	// Restore the endpoint to InitialState as it will be moved to
   119  	// its origEndpointState during Restore.
   120  	e.state = atomicbitops.FromUint32(uint32(StateInitial))
   121  	stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
   122  }
   123  
   124  // Restore implements tcpip.RestoredEndpoint.Restore.
   125  func (e *Endpoint) Restore(s *stack.Stack) {
   126  	if !e.EndpointState().closed() {
   127  		e.keepalive.timer.init(s.Clock(), timerHandler(e, e.keepaliveTimerExpired))
   128  	}
   129  	if snd := e.snd; snd != nil {
   130  		snd.resendTimer.init(s.Clock(), timerHandler(e, e.snd.retransmitTimerExpired))
   131  		snd.reorderTimer.init(s.Clock(), timerHandler(e, e.snd.rc.reorderTimerExpired))
   132  		snd.probeTimer.init(s.Clock(), timerHandler(e, e.snd.probeTimerExpired))
   133  		snd.corkTimer.init(s.Clock(), timerHandler(e, e.snd.corkTimerExpired))
   134  	}
   135  	e.stack = s
   136  	e.protocol = protocolFromStack(s)
   137  	e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
   138  	e.segmentQueue.thaw()
   139  
   140  	bind := func() {
   141  		e.mu.Lock()
   142  		defer e.mu.Unlock()
   143  		addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort})
   144  		if err != nil {
   145  			panic("unable to parse BindAddr: " + err.String())
   146  		}
   147  		portRes := ports.Reservation{
   148  			Networks:     e.effectiveNetProtos,
   149  			Transport:    ProtocolNumber,
   150  			Addr:         addr.Addr,
   151  			Port:         addr.Port,
   152  			Flags:        e.boundPortFlags,
   153  			BindToDevice: e.boundBindToDevice,
   154  			Dest:         e.boundDest,
   155  		}
   156  		if ok := e.stack.ReserveTuple(portRes); !ok {
   157  			panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
   158  		}
   159  		e.isPortReserved = true
   160  
   161  		// Mark endpoint as bound.
   162  		e.setEndpointState(StateBound)
   163  	}
   164  
   165  	epState := EndpointState(e.origEndpointState)
   166  	switch {
   167  	case epState.connected():
   168  		bind()
   169  		if e.connectingAddress.BitLen() == 0 {
   170  			e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress
   171  			// This endpoint is accepted by netstack but not yet by
   172  			// the app. If the endpoint is IPv6 but the remote
   173  			// address is IPv4, we need to connect as IPv6 so that
   174  			// dual-stack mode can be properly activated.
   175  			if e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.RemoteAddress.BitLen() != header.IPv6AddressSizeBits {
   176  				e.connectingAddress = tcpip.AddrFrom16Slice(append(
   177  					[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff},
   178  					e.TransportEndpointInfo.ID.RemoteAddress.AsSlice()...,
   179  				))
   180  			}
   181  		}
   182  		// Reset the scoreboard to reinitialize the sack information as
   183  		// we do not restore SACK information.
   184  		e.scoreboard.Reset()
   185  		e.mu.Lock()
   186  		err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false /* handshake */)
   187  		if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
   188  			panic("endpoint connecting failed: " + err.String())
   189  		}
   190  		e.state.Store(e.origEndpointState)
   191  		// For FIN-WAIT-2 and TIME-WAIT we need to start the appropriate timers so
   192  		// that the socket is closed correctly.
   193  		switch epState {
   194  		case StateFinWait2:
   195  			e.finWait2Timer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, e.finWait2TimerExpired)
   196  		case StateTimeWait:
   197  			e.timeWaitTimer = e.stack.Clock().AfterFunc(e.getTimeWaitDuration(), e.timeWaitTimerExpired)
   198  		}
   199  
   200  		if e.ops.GetCorkOption() {
   201  			// Rearm the timer if TCP_CORK is enabled which will
   202  			// drain all the segments in the queue after restore.
   203  			e.snd.corkTimer.enable(MinRTO)
   204  		}
   205  		e.mu.Unlock()
   206  		connectedLoading.Done()
   207  	case epState == StateListen:
   208  		tcpip.AsyncLoading.Add(1)
   209  		go func() {
   210  			connectedLoading.Wait()
   211  			bind()
   212  			e.acceptMu.Lock()
   213  			backlog := e.acceptQueue.capacity
   214  			e.acceptMu.Unlock()
   215  			if err := e.Listen(backlog); err != nil {
   216  				panic("endpoint listening failed: " + err.String())
   217  			}
   218  			e.LockUser()
   219  			if e.shutdownFlags != 0 {
   220  				e.shutdownLocked(e.shutdownFlags)
   221  			}
   222  			e.UnlockUser()
   223  			listenLoading.Done()
   224  			tcpip.AsyncLoading.Done()
   225  		}()
   226  	case epState == StateConnecting:
   227  		// Initial SYN hasn't been sent yet so initiate a connect.
   228  		tcpip.AsyncLoading.Add(1)
   229  		go func() {
   230  			connectedLoading.Wait()
   231  			listenLoading.Wait()
   232  			bind()
   233  			err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort})
   234  			if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
   235  				panic("endpoint connecting failed: " + err.String())
   236  			}
   237  			connectingLoading.Done()
   238  			tcpip.AsyncLoading.Done()
   239  		}()
   240  	case epState == StateSynSent || epState == StateSynRecv:
   241  		connectedLoading.Wait()
   242  		listenLoading.Wait()
   243  		// Initial SYN has been sent/received so we should bind the
   244  		// ports start the retransmit timer for the SYNs and let it
   245  		// naturally complete the connection.
   246  		bind()
   247  		e.mu.Lock()
   248  		defer e.mu.Unlock()
   249  		e.setEndpointState(epState)
   250  		r, err := e.stack.FindRoute(e.boundNICID, e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.RemoteAddress, e.effectiveNetProtos[0], false /* multicastLoop */)
   251  		if err != nil {
   252  			panic(fmt.Sprintf("FindRoute failed when restoring endpoint w/ ID: %+v", e.ID))
   253  		}
   254  		e.route = r
   255  		timer, err := newBackoffTimer(e.stack.Clock(), InitialRTO, MaxRTO, timerHandler(e, e.h.retransmitHandlerLocked))
   256  		if err != nil {
   257  			panic(fmt.Sprintf("newBackOffTimer(_, %s, %s, _) failed: %s", InitialRTO, MaxRTO, err))
   258  		}
   259  		e.h.retransmitTimer = timer
   260  		connectingLoading.Done()
   261  	case epState == StateBound:
   262  		tcpip.AsyncLoading.Add(1)
   263  		go func() {
   264  			connectedLoading.Wait()
   265  			listenLoading.Wait()
   266  			connectingLoading.Wait()
   267  			bind()
   268  			tcpip.AsyncLoading.Done()
   269  		}()
   270  	case epState == StateClose:
   271  		e.isPortReserved = false
   272  		e.state.Store(uint32(StateClose))
   273  		e.stack.CompleteTransportEndpointCleanup(e)
   274  		tcpip.DeleteDanglingEndpoint(e)
   275  	case epState == StateError:
   276  		e.state.Store(uint32(StateError))
   277  		e.stack.CompleteTransportEndpointCleanup(e)
   278  		tcpip.DeleteDanglingEndpoint(e)
   279  	}
   280  }
   281  
// Resume implements tcpip.ResumableEndpoint.Resume.
//
// It thaws the segment queue, undoing the freeze applied by beforeSave.
func (e *Endpoint) Resume() {
	e.segmentQueue.thaw()
}