inet.af/netstack@v0.0.0-20220214151720-7585b01ddccf/tcpip/transport/tcp/endpoint_state.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp
    16  
    17  import (
    18  	"fmt"
    19  	"sync/atomic"
    20  	"time"
    21  
    22  	"inet.af/netstack/sync"
    23  	"inet.af/netstack/tcpip"
    24  	"inet.af/netstack/tcpip/header"
    25  	"inet.af/netstack/tcpip/ports"
    26  	"inet.af/netstack/tcpip/stack"
    27  )
    28  
    29  // +checklocks:e.mu
    30  func (e *endpoint) drainSegmentLocked() {
    31  	// Drain only up to once.
    32  	if e.drainDone != nil {
    33  		return
    34  	}
    35  
    36  	e.drainDone = make(chan struct{})
    37  	e.undrain = make(chan struct{})
    38  	e.mu.Unlock()
    39  
    40  	e.notifyProtocolGoroutine(notifyDrain)
    41  	<-e.drainDone
    42  
    43  	e.mu.Lock()
    44  }
    45  
    46  // beforeSave is invoked by stateify.
    47  func (e *endpoint) beforeSave() {
    48  	// Stop incoming packets.
    49  	e.segmentQueue.freeze()
    50  
    51  	e.mu.Lock()
    52  	defer e.mu.Unlock()
    53  
    54  	epState := e.EndpointState()
    55  	switch {
    56  	case epState == StateInitial || epState == StateBound:
    57  	case epState.connected() || epState.handshake():
    58  		if !e.route.HasSaveRestoreCapability() {
    59  			if !e.route.HasDisconncetOkCapability() {
    60  				panic(&tcpip.ErrSaveRejection{
    61  					Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort),
    62  				})
    63  			}
    64  			e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
    65  			e.mu.Unlock()
    66  			e.Close()
    67  			e.mu.Lock()
    68  		}
    69  		if !e.workerRunning {
    70  			// The endpoint must be in the accepted queue or has been just
    71  			// disconnected and closed.
    72  			break
    73  		}
    74  		fallthrough
    75  	case epState == StateListen || epState == StateConnecting:
    76  		e.drainSegmentLocked()
    77  		// Refresh epState, since drainSegmentLocked may have changed it.
    78  		epState = e.EndpointState()
    79  		if !epState.closed() {
    80  			if !e.workerRunning {
    81  				panic("endpoint has no worker running in listen, connecting, or connected state")
    82  			}
    83  		}
    84  	case epState.closed():
    85  		for e.workerRunning {
    86  			e.mu.Unlock()
    87  			time.Sleep(100 * time.Millisecond)
    88  			e.mu.Lock()
    89  		}
    90  		if e.workerRunning {
    91  			panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.TransportEndpointInfo.ID))
    92  		}
    93  	default:
    94  		panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
    95  	}
    96  
    97  	if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
    98  		panic("endpoint still has waiters upon save")
    99  	}
   100  }
   101  
   102  // saveEndpoints is invoked by stateify.
   103  func (a *acceptQueue) saveEndpoints() []*endpoint {
   104  	acceptedEndpoints := make([]*endpoint, a.endpoints.Len())
   105  	for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
   106  		acceptedEndpoints[i] = e.Value.(*endpoint)
   107  	}
   108  	return acceptedEndpoints
   109  }
   110  
   111  // loadEndpoints is invoked by stateify.
   112  func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) {
   113  	for _, ep := range acceptedEndpoints {
   114  		a.endpoints.PushBack(ep)
   115  	}
   116  }
   117  
   118  // saveState is invoked by stateify.
   119  func (e *endpoint) saveState() EndpointState {
   120  	return e.EndpointState()
   121  }
   122  
   123  // Endpoint loading must be done in the following ordering by their state, to
   124  // avoid dangling connecting w/o listening peer, and to avoid conflicts in port
   125  // reservation.
   126  var connectedLoading sync.WaitGroup
   127  var listenLoading sync.WaitGroup
   128  var connectingLoading sync.WaitGroup
   129  
   130  // Bound endpoint loading happens last.
   131  
   132  // loadState is invoked by stateify.
   133  func (e *endpoint) loadState(epState EndpointState) {
   134  	// This is to ensure that the loading wait groups include all applicable
   135  	// endpoints before any asynchronous calls to the Wait() methods.
   136  	// For restore purposes we treat TimeWait like a connected endpoint.
   137  	if epState.connected() || epState == StateTimeWait {
   138  		connectedLoading.Add(1)
   139  	}
   140  	switch {
   141  	case epState == StateListen:
   142  		listenLoading.Add(1)
   143  	case epState.connecting():
   144  		connectingLoading.Add(1)
   145  	}
   146  	// Directly update the state here rather than using e.setEndpointState
   147  	// as the endpoint is still being loaded and the stack reference is not
   148  	// yet initialized.
   149  	atomic.StoreUint32((*uint32)(&e.state), uint32(epState))
   150  }
   151  
   152  // afterLoad is invoked by stateify.
   153  func (e *endpoint) afterLoad() {
   154  	e.origEndpointState = e.state
   155  	// Restore the endpoint to InitialState as it will be moved to
   156  	// its origEndpointState during Resume.
   157  	e.state = uint32(StateInitial)
   158  	// Condition variables and mutexs are not S/R'ed so reinitialize
   159  	// acceptCond with e.acceptMu.
   160  	e.acceptCond = sync.NewCond(&e.acceptMu)
   161  	stack.StackFromEnv.RegisterRestoredEndpoint(e)
   162  }
   163  
   164  // Resume implements tcpip.ResumableEndpoint.Resume.
   165  func (e *endpoint) Resume(s *stack.Stack) {
   166  	e.keepalive.timer.init(s.Clock(), &e.keepalive.waker)
   167  	if snd := e.snd; snd != nil {
   168  		snd.resendTimer.init(s.Clock(), &snd.resendWaker)
   169  		snd.reorderTimer.init(s.Clock(), &snd.reorderWaker)
   170  		snd.probeTimer.init(s.Clock(), &snd.probeWaker)
   171  	}
   172  	e.stack = s
   173  	e.protocol = protocolFromStack(s)
   174  	e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
   175  	e.segmentQueue.thaw()
   176  	epState := EndpointState(e.origEndpointState)
   177  	switch epState {
   178  	case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
   179  		var ss tcpip.TCPSendBufferSizeRangeOption
   180  		if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
   181  			sendBufferSize := e.getSendBufferSize()
   182  			if sendBufferSize < ss.Min || sendBufferSize > ss.Max {
   183  				panic(fmt.Sprintf("endpoint sendBufferSize %d is outside the min and max allowed [%d, %d]", sendBufferSize, ss.Min, ss.Max))
   184  			}
   185  		}
   186  
   187  		var rs tcpip.TCPReceiveBufferSizeRangeOption
   188  		if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
   189  			if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) {
   190  				panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max))
   191  			}
   192  		}
   193  	}
   194  
   195  	bind := func() {
   196  		addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort})
   197  		if err != nil {
   198  			panic("unable to parse BindAddr: " + err.String())
   199  		}
   200  		portRes := ports.Reservation{
   201  			Networks:     e.effectiveNetProtos,
   202  			Transport:    ProtocolNumber,
   203  			Addr:         addr.Addr,
   204  			Port:         addr.Port,
   205  			Flags:        e.boundPortFlags,
   206  			BindToDevice: e.boundBindToDevice,
   207  			Dest:         e.boundDest,
   208  		}
   209  		if ok := e.stack.ReserveTuple(portRes); !ok {
   210  			panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
   211  		}
   212  		e.isPortReserved = true
   213  
   214  		// Mark endpoint as bound.
   215  		e.setEndpointState(StateBound)
   216  	}
   217  
   218  	switch {
   219  	case epState.connected():
   220  		bind()
   221  		if len(e.connectingAddress) == 0 {
   222  			e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress
   223  			// This endpoint is accepted by netstack but not yet by
   224  			// the app. If the endpoint is IPv6 but the remote
   225  			// address is IPv4, we need to connect as IPv6 so that
   226  			// dual-stack mode can be properly activated.
   227  			if e.NetProto == header.IPv6ProtocolNumber && len(e.TransportEndpointInfo.ID.RemoteAddress) != header.IPv6AddressSize {
   228  				e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.TransportEndpointInfo.ID.RemoteAddress
   229  			}
   230  		}
   231  		// Reset the scoreboard to reinitialize the sack information as
   232  		// we do not restore SACK information.
   233  		e.scoreboard.Reset()
   234  		err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false, e.workerRunning)
   235  		if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
   236  			panic("endpoint connecting failed: " + err.String())
   237  		}
   238  		e.mu.Lock()
   239  		e.state = e.origEndpointState
   240  		closed := e.closed
   241  		e.mu.Unlock()
   242  		e.notifyProtocolGoroutine(notifyTickleWorker)
   243  		if epState == StateFinWait2 && closed {
   244  			// If the endpoint has been closed then make sure we notify so
   245  			// that the FIN_WAIT2 timer is started after a restore.
   246  			e.notifyProtocolGoroutine(notifyClose)
   247  		}
   248  		connectedLoading.Done()
   249  	case epState == StateListen:
   250  		tcpip.AsyncLoading.Add(1)
   251  		go func() {
   252  			connectedLoading.Wait()
   253  			bind()
   254  			e.acceptMu.Lock()
   255  			backlog := e.acceptQueue.capacity
   256  			e.acceptMu.Unlock()
   257  			if err := e.Listen(backlog); err != nil {
   258  				panic("endpoint listening failed: " + err.String())
   259  			}
   260  			e.LockUser()
   261  			if e.shutdownFlags != 0 {
   262  				e.shutdownLocked(e.shutdownFlags)
   263  			}
   264  			e.UnlockUser()
   265  			listenLoading.Done()
   266  			tcpip.AsyncLoading.Done()
   267  		}()
   268  	case epState.connecting():
   269  		tcpip.AsyncLoading.Add(1)
   270  		go func() {
   271  			connectedLoading.Wait()
   272  			listenLoading.Wait()
   273  			bind()
   274  			err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort})
   275  			if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
   276  				panic("endpoint connecting failed: " + err.String())
   277  			}
   278  			connectingLoading.Done()
   279  			tcpip.AsyncLoading.Done()
   280  		}()
   281  	case epState == StateBound:
   282  		tcpip.AsyncLoading.Add(1)
   283  		go func() {
   284  			connectedLoading.Wait()
   285  			listenLoading.Wait()
   286  			connectingLoading.Wait()
   287  			bind()
   288  			tcpip.AsyncLoading.Done()
   289  		}()
   290  	case epState == StateClose:
   291  		e.isPortReserved = false
   292  		e.state = uint32(StateClose)
   293  		e.stack.CompleteTransportEndpointCleanup(e)
   294  		tcpip.DeleteDanglingEndpoint(e)
   295  	case epState == StateError:
   296  		e.state = uint32(StateError)
   297  		e.stack.CompleteTransportEndpointCleanup(e)
   298  		tcpip.DeleteDanglingEndpoint(e)
   299  	}
   300  }