github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/tcpip/transport/tcp/endpoint_state.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package tcp 16 17 import ( 18 "context" 19 "fmt" 20 21 "github.com/sagernet/gvisor/pkg/atomicbitops" 22 "github.com/sagernet/gvisor/pkg/sync" 23 "github.com/sagernet/gvisor/pkg/tcpip" 24 "github.com/sagernet/gvisor/pkg/tcpip/header" 25 "github.com/sagernet/gvisor/pkg/tcpip/ports" 26 "github.com/sagernet/gvisor/pkg/tcpip/stack" 27 ) 28 29 // beforeSave is invoked by stateify. 30 func (e *Endpoint) beforeSave() { 31 // Stop incoming packets. 32 e.segmentQueue.freeze() 33 34 e.mu.Lock() 35 defer e.mu.Unlock() 36 37 epState := e.EndpointState() 38 switch { 39 case epState == StateInitial || epState == StateBound: 40 case epState.connected() || epState.handshake(): 41 if !e.route.HasSaveRestoreCapability() { 42 if !e.route.HasDisconnectOkCapability() { 43 panic(&tcpip.ErrSaveRejection{ 44 Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort), 45 }) 46 } 47 e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) 48 e.mu.Unlock() 49 e.Close() 50 e.mu.Lock() 51 } 52 fallthrough 53 case epState == StateListen: 54 // Nothing to do. 55 case epState.closed(): 56 // Nothing to do. 57 default: 58 panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) 59 } 60 61 e.stack.RegisterResumableEndpoint(e) 62 } 63 64 // saveEndpoints is invoked by stateify. 65 func (a *acceptQueue) saveEndpoints() []*Endpoint { 66 acceptedEndpoints := make([]*Endpoint, a.endpoints.Len()) 67 for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { 68 acceptedEndpoints[i] = e.Value.(*Endpoint) 69 } 70 return acceptedEndpoints 71 } 72 73 // loadEndpoints is invoked by stateify. 74 func (a *acceptQueue) loadEndpoints(_ context.Context, acceptedEndpoints []*Endpoint) { 75 for _, ep := range acceptedEndpoints { 76 a.endpoints.PushBack(ep) 77 } 78 } 79 80 // saveState is invoked by stateify. 81 func (e *Endpoint) saveState() EndpointState { 82 return e.EndpointState() 83 } 84 85 // Endpoint loading must be done in the following ordering by their state, to 86 // avoid dangling connecting w/o listening peer, and to avoid conflicts in port 87 // reservation. 88 var connectedLoading sync.WaitGroup 89 var listenLoading sync.WaitGroup 90 var connectingLoading sync.WaitGroup 91 92 // Bound endpoint loading happens last. 93 94 // loadState is invoked by stateify. 95 func (e *Endpoint) loadState(_ context.Context, epState EndpointState) { 96 // This is to ensure that the loading wait groups include all applicable 97 // endpoints before any asynchronous calls to the Wait() methods. 98 // For restore purposes we treat TimeWait like a connected endpoint. 99 if epState.connected() || epState == StateTimeWait { 100 connectedLoading.Add(1) 101 } 102 switch { 103 case epState == StateListen: 104 listenLoading.Add(1) 105 case epState.connecting(): 106 connectingLoading.Add(1) 107 } 108 // Directly update the state here rather than using e.setEndpointState 109 // as the endpoint is still being loaded and the stack reference is not 110 // yet initialized. 111 e.state.Store(uint32(epState)) 112 } 113 114 // afterLoad is invoked by stateify. 115 func (e *Endpoint) afterLoad(ctx context.Context) { 116 // RacyLoad() can be used because we are initializing e. 117 e.origEndpointState = e.state.RacyLoad() 118 // Restore the endpoint to InitialState as it will be moved to 119 // its origEndpointState during Restore. 120 e.state = atomicbitops.FromUint32(uint32(StateInitial)) 121 stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e) 122 } 123 124 // Restore implements tcpip.RestoredEndpoint.Restore. 125 func (e *Endpoint) Restore(s *stack.Stack) { 126 if !e.EndpointState().closed() { 127 e.keepalive.timer.init(s.Clock(), timerHandler(e, e.keepaliveTimerExpired)) 128 } 129 if snd := e.snd; snd != nil { 130 snd.resendTimer.init(s.Clock(), timerHandler(e, e.snd.retransmitTimerExpired)) 131 snd.reorderTimer.init(s.Clock(), timerHandler(e, e.snd.rc.reorderTimerExpired)) 132 snd.probeTimer.init(s.Clock(), timerHandler(e, e.snd.probeTimerExpired)) 133 snd.corkTimer.init(s.Clock(), timerHandler(e, e.snd.corkTimerExpired)) 134 } 135 e.stack = s 136 e.protocol = protocolFromStack(s) 137 e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) 138 e.segmentQueue.thaw() 139 140 bind := func() { 141 e.mu.Lock() 142 defer e.mu.Unlock() 143 addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}) 144 if err != nil { 145 panic("unable to parse BindAddr: " + err.String()) 146 } 147 portRes := ports.Reservation{ 148 Networks: e.effectiveNetProtos, 149 Transport: ProtocolNumber, 150 Addr: addr.Addr, 151 Port: addr.Port, 152 Flags: e.boundPortFlags, 153 BindToDevice: e.boundBindToDevice, 154 Dest: e.boundDest, 155 } 156 if ok := e.stack.ReserveTuple(portRes); !ok { 157 panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest)) 158 } 159 e.isPortReserved = true 160 161 // Mark endpoint as bound. 162 e.setEndpointState(StateBound) 163 } 164 165 epState := EndpointState(e.origEndpointState) 166 switch { 167 case epState.connected(): 168 bind() 169 if e.connectingAddress.BitLen() == 0 { 170 e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress 171 // This endpoint is accepted by netstack but not yet by 172 // the app. If the endpoint is IPv6 but the remote 173 // address is IPv4, we need to connect as IPv6 so that 174 // dual-stack mode can be properly activated. 175 if e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.RemoteAddress.BitLen() != header.IPv6AddressSizeBits { 176 e.connectingAddress = tcpip.AddrFrom16Slice(append( 177 []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff}, 178 e.TransportEndpointInfo.ID.RemoteAddress.AsSlice()..., 179 )) 180 } 181 } 182 // Reset the scoreboard to reinitialize the sack information as 183 // we do not restore SACK information. 184 e.scoreboard.Reset() 185 e.mu.Lock() 186 err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false /* handshake */) 187 if _, ok := err.(*tcpip.ErrConnectStarted); !ok { 188 panic("endpoint connecting failed: " + err.String()) 189 } 190 e.state.Store(e.origEndpointState) 191 // For FIN-WAIT-2 and TIME-WAIT we need to start the appropriate timers so 192 // that the socket is closed correctly. 193 switch epState { 194 case StateFinWait2: 195 e.finWait2Timer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, e.finWait2TimerExpired) 196 case StateTimeWait: 197 e.timeWaitTimer = e.stack.Clock().AfterFunc(e.getTimeWaitDuration(), e.timeWaitTimerExpired) 198 } 199 200 if e.ops.GetCorkOption() { 201 // Rearm the timer if TCP_CORK is enabled which will 202 // drain all the segments in the queue after restore. 203 e.snd.corkTimer.enable(MinRTO) 204 } 205 e.mu.Unlock() 206 connectedLoading.Done() 207 case epState == StateListen: 208 tcpip.AsyncLoading.Add(1) 209 go func() { 210 connectedLoading.Wait() 211 bind() 212 e.acceptMu.Lock() 213 backlog := e.acceptQueue.capacity 214 e.acceptMu.Unlock() 215 if err := e.Listen(backlog); err != nil { 216 panic("endpoint listening failed: " + err.String()) 217 } 218 e.LockUser() 219 if e.shutdownFlags != 0 { 220 e.shutdownLocked(e.shutdownFlags) 221 } 222 e.UnlockUser() 223 listenLoading.Done() 224 tcpip.AsyncLoading.Done() 225 }() 226 case epState == StateConnecting: 227 // Initial SYN hasn't been sent yet so initiate a connect. 228 tcpip.AsyncLoading.Add(1) 229 go func() { 230 connectedLoading.Wait() 231 listenLoading.Wait() 232 bind() 233 err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}) 234 if _, ok := err.(*tcpip.ErrConnectStarted); !ok { 235 panic("endpoint connecting failed: " + err.String()) 236 } 237 connectingLoading.Done() 238 tcpip.AsyncLoading.Done() 239 }() 240 case epState == StateSynSent || epState == StateSynRecv: 241 connectedLoading.Wait() 242 listenLoading.Wait() 243 // Initial SYN has been sent/received so we should bind the 244 // ports start the retransmit timer for the SYNs and let it 245 // naturally complete the connection. 246 bind() 247 e.mu.Lock() 248 defer e.mu.Unlock() 249 e.setEndpointState(epState) 250 r, err := e.stack.FindRoute(e.boundNICID, e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.RemoteAddress, e.effectiveNetProtos[0], false /* multicastLoop */) 251 if err != nil { 252 panic(fmt.Sprintf("FindRoute failed when restoring endpoint w/ ID: %+v", e.ID)) 253 } 254 e.route = r 255 timer, err := newBackoffTimer(e.stack.Clock(), InitialRTO, MaxRTO, timerHandler(e, e.h.retransmitHandlerLocked)) 256 if err != nil { 257 panic(fmt.Sprintf("newBackOffTimer(_, %s, %s, _) failed: %s", InitialRTO, MaxRTO, err)) 258 } 259 e.h.retransmitTimer = timer 260 connectingLoading.Done() 261 case epState == StateBound: 262 tcpip.AsyncLoading.Add(1) 263 go func() { 264 connectedLoading.Wait() 265 listenLoading.Wait() 266 connectingLoading.Wait() 267 bind() 268 tcpip.AsyncLoading.Done() 269 }() 270 case epState == StateClose: 271 e.isPortReserved = false 272 e.state.Store(uint32(StateClose)) 273 e.stack.CompleteTransportEndpointCleanup(e) 274 tcpip.DeleteDanglingEndpoint(e) 275 case epState == StateError: 276 e.state.Store(uint32(StateError)) 277 e.stack.CompleteTransportEndpointCleanup(e) 278 tcpip.DeleteDanglingEndpoint(e) 279 } 280 } 281 282 // Resume implements tcpip.ResumableEndpoint.Resume. 283 func (e *Endpoint) Resume() { 284 e.segmentQueue.thaw() 285 }