github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/transport/tcp/endpoint_state.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/nicocha30/gvisor-ligolo/pkg/atomicbitops"
    21  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
    22  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip"
    23  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header"
    24  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/ports"
    25  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/stack"
    26  )
    27  
    28  // beforeSave is invoked by stateify.
    29  func (e *endpoint) beforeSave() {
    30  	// Stop incoming packets.
    31  	e.segmentQueue.freeze()
    32  
    33  	e.mu.Lock()
    34  	defer e.mu.Unlock()
    35  
    36  	epState := e.EndpointState()
    37  	switch {
    38  	case epState == StateInitial || epState == StateBound:
    39  	case epState.connected() || epState.handshake():
    40  		if !e.route.HasSaveRestoreCapability() {
    41  			if !e.route.HasDisconnectOkCapability() {
    42  				panic(&tcpip.ErrSaveRejection{
    43  					Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort),
    44  				})
    45  			}
    46  			e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
    47  			e.mu.Unlock()
    48  			e.Close()
    49  			e.mu.Lock()
    50  		}
    51  		fallthrough
    52  	case epState == StateListen:
    53  		// Nothing to do.
    54  	case epState.closed():
    55  		// Nothing to do.
    56  	default:
    57  		panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
    58  	}
    59  }
    60  
    61  // saveEndpoints is invoked by stateify.
    62  func (a *acceptQueue) saveEndpoints() []*endpoint {
    63  	acceptedEndpoints := make([]*endpoint, a.endpoints.Len())
    64  	for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
    65  		acceptedEndpoints[i] = e.Value.(*endpoint)
    66  	}
    67  	return acceptedEndpoints
    68  }
    69  
    70  // loadEndpoints is invoked by stateify.
    71  func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) {
    72  	for _, ep := range acceptedEndpoints {
    73  		a.endpoints.PushBack(ep)
    74  	}
    75  }
    76  
    77  // saveState is invoked by stateify.
    78  func (e *endpoint) saveState() EndpointState {
    79  	return e.EndpointState()
    80  }
    81  
    82  // Endpoint loading must be done in the following ordering by their state, to
    83  // avoid dangling connecting w/o listening peer, and to avoid conflicts in port
    84  // reservation.
    85  var connectedLoading sync.WaitGroup
    86  var listenLoading sync.WaitGroup
    87  var connectingLoading sync.WaitGroup
    88  
    89  // Bound endpoint loading happens last.
    90  
    91  // loadState is invoked by stateify.
    92  func (e *endpoint) loadState(epState EndpointState) {
    93  	// This is to ensure that the loading wait groups include all applicable
    94  	// endpoints before any asynchronous calls to the Wait() methods.
    95  	// For restore purposes we treat TimeWait like a connected endpoint.
    96  	if epState.connected() || epState == StateTimeWait {
    97  		connectedLoading.Add(1)
    98  	}
    99  	switch {
   100  	case epState == StateListen:
   101  		listenLoading.Add(1)
   102  	case epState.connecting():
   103  		connectingLoading.Add(1)
   104  	}
   105  	// Directly update the state here rather than using e.setEndpointState
   106  	// as the endpoint is still being loaded and the stack reference is not
   107  	// yet initialized.
   108  	e.state.Store(uint32(epState))
   109  }
   110  
   111  // afterLoad is invoked by stateify.
   112  func (e *endpoint) afterLoad() {
   113  	// RacyLoad() can be used because we are initializing e.
   114  	e.origEndpointState = e.state.RacyLoad()
   115  	// Restore the endpoint to InitialState as it will be moved to
   116  	// its origEndpointState during Resume.
   117  	e.state = atomicbitops.FromUint32(uint32(StateInitial))
   118  	stack.StackFromEnv.RegisterRestoredEndpoint(e)
   119  }
   120  
   121  // Resume implements tcpip.ResumableEndpoint.Resume.
   122  func (e *endpoint) Resume(s *stack.Stack) {
   123  	if !e.EndpointState().closed() {
   124  		e.keepalive.timer.init(s.Clock(), maybeFailTimerHandler(e, e.keepaliveTimerExpired))
   125  	}
   126  	if snd := e.snd; snd != nil {
   127  		snd.resendTimer.init(s.Clock(), maybeFailTimerHandler(e, e.snd.retransmitTimerExpired))
   128  		snd.reorderTimer.init(s.Clock(), timerHandler(e, e.snd.rc.reorderTimerExpired))
   129  		snd.probeTimer.init(s.Clock(), timerHandler(e, e.snd.probeTimerExpired))
   130  	}
   131  	e.stack = s
   132  	e.protocol = protocolFromStack(s)
   133  	e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
   134  	e.segmentQueue.thaw()
   135  
   136  	bind := func() {
   137  		e.mu.Lock()
   138  		defer e.mu.Unlock()
   139  		addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort})
   140  		if err != nil {
   141  			panic("unable to parse BindAddr: " + err.String())
   142  		}
   143  		portRes := ports.Reservation{
   144  			Networks:     e.effectiveNetProtos,
   145  			Transport:    ProtocolNumber,
   146  			Addr:         addr.Addr,
   147  			Port:         addr.Port,
   148  			Flags:        e.boundPortFlags,
   149  			BindToDevice: e.boundBindToDevice,
   150  			Dest:         e.boundDest,
   151  		}
   152  		if ok := e.stack.ReserveTuple(portRes); !ok {
   153  			panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
   154  		}
   155  		e.isPortReserved = true
   156  
   157  		// Mark endpoint as bound.
   158  		e.setEndpointState(StateBound)
   159  	}
   160  
   161  	epState := EndpointState(e.origEndpointState)
   162  	switch {
   163  	case epState.connected():
   164  		bind()
   165  		if e.connectingAddress.BitLen() == 0 {
   166  			e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress
   167  			// This endpoint is accepted by netstack but not yet by
   168  			// the app. If the endpoint is IPv6 but the remote
   169  			// address is IPv4, we need to connect as IPv6 so that
   170  			// dual-stack mode can be properly activated.
   171  			if e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.RemoteAddress.BitLen() != header.IPv6AddressSizeBits {
   172  				e.connectingAddress = tcpip.AddrFrom16Slice(append(
   173  					[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff},
   174  					e.TransportEndpointInfo.ID.RemoteAddress.AsSlice()...,
   175  				))
   176  			}
   177  		}
   178  		// Reset the scoreboard to reinitialize the sack information as
   179  		// we do not restore SACK information.
   180  		e.scoreboard.Reset()
   181  		e.mu.Lock()
   182  		err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false /* handshake */)
   183  		if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
   184  			panic("endpoint connecting failed: " + err.String())
   185  		}
   186  		e.state.Store(e.origEndpointState)
   187  		// For FIN-WAIT-2 and TIME-WAIT we need to start the appropriate timers so
   188  		// that the socket is closed correctly.
   189  		switch epState {
   190  		case StateFinWait2:
   191  			e.finWait2Timer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, e.finWait2TimerExpired)
   192  		case StateTimeWait:
   193  			e.timeWaitTimer = e.stack.Clock().AfterFunc(e.getTimeWaitDuration(), e.timeWaitTimerExpired)
   194  		}
   195  
   196  		e.mu.Unlock()
   197  		connectedLoading.Done()
   198  	case epState == StateListen:
   199  		tcpip.AsyncLoading.Add(1)
   200  		go func() {
   201  			connectedLoading.Wait()
   202  			bind()
   203  			e.acceptMu.Lock()
   204  			backlog := e.acceptQueue.capacity
   205  			e.acceptMu.Unlock()
   206  			if err := e.Listen(backlog); err != nil {
   207  				panic("endpoint listening failed: " + err.String())
   208  			}
   209  			e.LockUser()
   210  			if e.shutdownFlags != 0 {
   211  				e.shutdownLocked(e.shutdownFlags)
   212  			}
   213  			e.UnlockUser()
   214  			listenLoading.Done()
   215  			tcpip.AsyncLoading.Done()
   216  		}()
   217  	case epState == StateConnecting:
   218  		// Initial SYN hasn't been sent yet so initiate a connect.
   219  		tcpip.AsyncLoading.Add(1)
   220  		go func() {
   221  			connectedLoading.Wait()
   222  			listenLoading.Wait()
   223  			bind()
   224  			err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort})
   225  			if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
   226  				panic("endpoint connecting failed: " + err.String())
   227  			}
   228  			connectingLoading.Done()
   229  			tcpip.AsyncLoading.Done()
   230  		}()
   231  	case epState == StateSynSent || epState == StateSynRecv:
   232  		connectedLoading.Wait()
   233  		listenLoading.Wait()
   234  		// Initial SYN has been sent/received so we should bind the
   235  		// ports start the retransmit timer for the SYNs and let it
   236  		// naturally complete the connection.
   237  		bind()
   238  		e.mu.Lock()
   239  		defer e.mu.Unlock()
   240  		e.setEndpointState(epState)
   241  		r, err := e.stack.FindRoute(e.boundNICID, e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.RemoteAddress, e.effectiveNetProtos[0], false /* multicastLoop */)
   242  		if err != nil {
   243  			panic(fmt.Sprintf("FindRoute failed when restoring endpoint w/ ID: %+v", e.ID))
   244  		}
   245  		e.route = r
   246  		timer, err := newBackoffTimer(e.stack.Clock(), InitialRTO, MaxRTO, maybeFailTimerHandler(e, e.h.retransmitHandlerLocked))
   247  		if err != nil {
   248  			panic(fmt.Sprintf("newBackOffTimer(_, %s, %s, _) failed: %s", InitialRTO, MaxRTO, err))
   249  		}
   250  		e.h.retransmitTimer = timer
   251  		connectingLoading.Done()
   252  	case epState == StateBound:
   253  		tcpip.AsyncLoading.Add(1)
   254  		go func() {
   255  			connectedLoading.Wait()
   256  			listenLoading.Wait()
   257  			connectingLoading.Wait()
   258  			bind()
   259  			tcpip.AsyncLoading.Done()
   260  		}()
   261  	case epState == StateClose:
   262  		e.isPortReserved = false
   263  		e.state.Store(uint32(StateClose))
   264  		e.stack.CompleteTransportEndpointCleanup(e)
   265  		tcpip.DeleteDanglingEndpoint(e)
   266  	case epState == StateError:
   267  		e.state.Store(uint32(StateError))
   268  		e.stack.CompleteTransportEndpointCleanup(e)
   269  		tcpip.DeleteDanglingEndpoint(e)
   270  	}
   271  }