github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/transport/tcp/endpoint_state.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp
    16  
    17  import (
    18  	"fmt"
    19  	"sync/atomic"
    20  	"time"
    21  
    22  	"github.com/SagerNet/gvisor/pkg/sync"
    23  	"github.com/SagerNet/gvisor/pkg/tcpip"
    24  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip/ports"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    27  )
    28  
    29  // +checklocks:e.mu
    30  func (e *endpoint) drainSegmentLocked() {
    31  	// Drain only up to once.
    32  	if e.drainDone != nil {
    33  		return
    34  	}
    35  
    36  	e.drainDone = make(chan struct{})
    37  	e.undrain = make(chan struct{})
    38  	e.mu.Unlock()
    39  
    40  	e.notifyProtocolGoroutine(notifyDrain)
    41  	<-e.drainDone
    42  
    43  	e.mu.Lock()
    44  }
    45  
    46  // beforeSave is invoked by stateify.
    47  func (e *endpoint) beforeSave() {
    48  	// Stop incoming packets.
    49  	e.segmentQueue.freeze()
    50  
    51  	e.mu.Lock()
    52  	defer e.mu.Unlock()
    53  
    54  	epState := e.EndpointState()
    55  	switch {
    56  	case epState == StateInitial || epState == StateBound:
    57  	case epState.connected() || epState.handshake():
    58  		if !e.route.HasSaveRestoreCapability() {
    59  			if !e.route.HasDisconncetOkCapability() {
    60  				panic(&tcpip.ErrSaveRejection{
    61  					Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort),
    62  				})
    63  			}
    64  			e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
    65  			e.mu.Unlock()
    66  			e.Close()
    67  			e.mu.Lock()
    68  		}
    69  		if !e.workerRunning {
    70  			// The endpoint must be in the accepted queue or has been just
    71  			// disconnected and closed.
    72  			break
    73  		}
    74  		fallthrough
    75  	case epState == StateListen || epState == StateConnecting:
    76  		e.drainSegmentLocked()
    77  		// Refresh epState, since drainSegmentLocked may have changed it.
    78  		epState = e.EndpointState()
    79  		if !epState.closed() {
    80  			if !e.workerRunning {
    81  				panic("endpoint has no worker running in listen, connecting, or connected state")
    82  			}
    83  		}
    84  	case epState.closed():
    85  		for e.workerRunning {
    86  			e.mu.Unlock()
    87  			time.Sleep(100 * time.Millisecond)
    88  			e.mu.Lock()
    89  		}
    90  		if e.workerRunning {
    91  			panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.TransportEndpointInfo.ID))
    92  		}
    93  	default:
    94  		panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
    95  	}
    96  
    97  	if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
    98  		panic("endpoint still has waiters upon save")
    99  	}
   100  }
   101  
   102  // saveEndpoints is invoked by stateify.
   103  func (a *accepted) saveEndpoints() []*endpoint {
   104  	acceptedEndpoints := make([]*endpoint, a.endpoints.Len())
   105  	for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
   106  		acceptedEndpoints[i] = e.Value.(*endpoint)
   107  	}
   108  	return acceptedEndpoints
   109  }
   110  
   111  // loadEndpoints is invoked by stateify.
   112  func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) {
   113  	for _, ep := range acceptedEndpoints {
   114  		a.endpoints.PushBack(ep)
   115  	}
   116  }
   117  
   118  // saveState is invoked by stateify.
   119  func (e *endpoint) saveState() EndpointState {
   120  	return e.EndpointState()
   121  }
   122  
   123  // Endpoint loading must be done in the following ordering by their state, to
   124  // avoid dangling connecting w/o listening peer, and to avoid conflicts in port
   125  // reservation.
   126  var connectedLoading sync.WaitGroup
   127  var listenLoading sync.WaitGroup
   128  var connectingLoading sync.WaitGroup
   129  
   130  // Bound endpoint loading happens last.
   131  
   132  // loadState is invoked by stateify.
   133  func (e *endpoint) loadState(epState EndpointState) {
   134  	// This is to ensure that the loading wait groups include all applicable
   135  	// endpoints before any asynchronous calls to the Wait() methods.
   136  	// For restore purposes we treat TimeWait like a connected endpoint.
   137  	if epState.connected() || epState == StateTimeWait {
   138  		connectedLoading.Add(1)
   139  	}
   140  	switch {
   141  	case epState == StateListen:
   142  		listenLoading.Add(1)
   143  	case epState.connecting():
   144  		connectingLoading.Add(1)
   145  	}
   146  	// Directly update the state here rather than using e.setEndpointState
   147  	// as the endpoint is still being loaded and the stack reference is not
   148  	// yet initialized.
   149  	atomic.StoreUint32((*uint32)(&e.state), uint32(epState))
   150  }
   151  
   152  // afterLoad is invoked by stateify.
   153  func (e *endpoint) afterLoad() {
   154  	e.origEndpointState = e.state
   155  	// Restore the endpoint to InitialState as it will be moved to
   156  	// its origEndpointState during Resume.
   157  	e.state = uint32(StateInitial)
   158  	// Condition variables and mutexs are not S/R'ed so reinitialize
   159  	// acceptCond with e.acceptMu.
   160  	e.acceptCond = sync.NewCond(&e.acceptMu)
   161  	stack.StackFromEnv.RegisterRestoredEndpoint(e)
   162  }
   163  
   164  // Resume implements tcpip.ResumableEndpoint.Resume.
   165  func (e *endpoint) Resume(s *stack.Stack) {
   166  	e.keepalive.timer.init(s.Clock(), &e.keepalive.waker)
   167  	if snd := e.snd; snd != nil {
   168  		snd.resendTimer.init(s.Clock(), &snd.resendWaker)
   169  		snd.reorderTimer.init(s.Clock(), &snd.reorderWaker)
   170  		snd.probeTimer.init(s.Clock(), &snd.probeWaker)
   171  	}
   172  	e.stack = s
   173  	e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
   174  	e.segmentQueue.thaw()
   175  	epState := EndpointState(e.origEndpointState)
   176  	switch epState {
   177  	case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
   178  		var ss tcpip.TCPSendBufferSizeRangeOption
   179  		if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
   180  			sendBufferSize := e.getSendBufferSize()
   181  			if sendBufferSize < ss.Min || sendBufferSize > ss.Max {
   182  				panic(fmt.Sprintf("endpoint sendBufferSize %d is outside the min and max allowed [%d, %d]", sendBufferSize, ss.Min, ss.Max))
   183  			}
   184  		}
   185  
   186  		var rs tcpip.TCPReceiveBufferSizeRangeOption
   187  		if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
   188  			if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) {
   189  				panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max))
   190  			}
   191  		}
   192  	}
   193  
   194  	bind := func() {
   195  		addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort})
   196  		if err != nil {
   197  			panic("unable to parse BindAddr: " + err.String())
   198  		}
   199  		portRes := ports.Reservation{
   200  			Networks:     e.effectiveNetProtos,
   201  			Transport:    ProtocolNumber,
   202  			Addr:         addr.Addr,
   203  			Port:         addr.Port,
   204  			Flags:        e.boundPortFlags,
   205  			BindToDevice: e.boundBindToDevice,
   206  			Dest:         e.boundDest,
   207  		}
   208  		if ok := e.stack.ReserveTuple(portRes); !ok {
   209  			panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
   210  		}
   211  		e.isPortReserved = true
   212  
   213  		// Mark endpoint as bound.
   214  		e.setEndpointState(StateBound)
   215  	}
   216  
   217  	switch {
   218  	case epState.connected():
   219  		bind()
   220  		if len(e.connectingAddress) == 0 {
   221  			e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress
   222  			// This endpoint is accepted by netstack but not yet by
   223  			// the app. If the endpoint is IPv6 but the remote
   224  			// address is IPv4, we need to connect as IPv6 so that
   225  			// dual-stack mode can be properly activated.
   226  			if e.NetProto == header.IPv6ProtocolNumber && len(e.TransportEndpointInfo.ID.RemoteAddress) != header.IPv6AddressSize {
   227  				e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.TransportEndpointInfo.ID.RemoteAddress
   228  			}
   229  		}
   230  		// Reset the scoreboard to reinitialize the sack information as
   231  		// we do not restore SACK information.
   232  		e.scoreboard.Reset()
   233  		err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false, e.workerRunning)
   234  		if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
   235  			panic("endpoint connecting failed: " + err.String())
   236  		}
   237  		e.mu.Lock()
   238  		e.state = e.origEndpointState
   239  		closed := e.closed
   240  		e.mu.Unlock()
   241  		e.notifyProtocolGoroutine(notifyTickleWorker)
   242  		if epState == StateFinWait2 && closed {
   243  			// If the endpoint has been closed then make sure we notify so
   244  			// that the FIN_WAIT2 timer is started after a restore.
   245  			e.notifyProtocolGoroutine(notifyClose)
   246  		}
   247  		connectedLoading.Done()
   248  	case epState == StateListen:
   249  		tcpip.AsyncLoading.Add(1)
   250  		go func() {
   251  			connectedLoading.Wait()
   252  			bind()
   253  			backlog := e.accepted.cap
   254  			if err := e.Listen(backlog); err != nil {
   255  				panic("endpoint listening failed: " + err.String())
   256  			}
   257  			e.LockUser()
   258  			if e.shutdownFlags != 0 {
   259  				e.shutdownLocked(e.shutdownFlags)
   260  			}
   261  			e.UnlockUser()
   262  			listenLoading.Done()
   263  			tcpip.AsyncLoading.Done()
   264  		}()
   265  	case epState.connecting():
   266  		tcpip.AsyncLoading.Add(1)
   267  		go func() {
   268  			connectedLoading.Wait()
   269  			listenLoading.Wait()
   270  			bind()
   271  			err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort})
   272  			if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
   273  				panic("endpoint connecting failed: " + err.String())
   274  			}
   275  			connectingLoading.Done()
   276  			tcpip.AsyncLoading.Done()
   277  		}()
   278  	case epState == StateBound:
   279  		tcpip.AsyncLoading.Add(1)
   280  		go func() {
   281  			connectedLoading.Wait()
   282  			listenLoading.Wait()
   283  			connectingLoading.Wait()
   284  			bind()
   285  			tcpip.AsyncLoading.Done()
   286  		}()
   287  	case epState == StateClose:
   288  		e.isPortReserved = false
   289  		e.state = uint32(StateClose)
   290  		e.stack.CompleteTransportEndpointCleanup(e)
   291  		tcpip.DeleteDanglingEndpoint(e)
   292  	case epState == StateError:
   293  		e.state = uint32(StateError)
   294  		e.stack.CompleteTransportEndpointCleanup(e)
   295  		tcpip.DeleteDanglingEndpoint(e)
   296  	}
   297  }