inet.af/netstack@v0.0.0-20220214151720-7585b01ddccf/tcpip/transport/tcp/endpoint_state.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package tcp 16 17 import ( 18 "fmt" 19 "sync/atomic" 20 "time" 21 22 "inet.af/netstack/sync" 23 "inet.af/netstack/tcpip" 24 "inet.af/netstack/tcpip/header" 25 "inet.af/netstack/tcpip/ports" 26 "inet.af/netstack/tcpip/stack" 27 ) 28 29 // +checklocks:e.mu 30 func (e *endpoint) drainSegmentLocked() { 31 // Drain only up to once. 32 if e.drainDone != nil { 33 return 34 } 35 36 e.drainDone = make(chan struct{}) 37 e.undrain = make(chan struct{}) 38 e.mu.Unlock() 39 40 e.notifyProtocolGoroutine(notifyDrain) 41 <-e.drainDone 42 43 e.mu.Lock() 44 } 45 46 // beforeSave is invoked by stateify. 47 func (e *endpoint) beforeSave() { 48 // Stop incoming packets. 49 e.segmentQueue.freeze() 50 51 e.mu.Lock() 52 defer e.mu.Unlock() 53 54 epState := e.EndpointState() 55 switch { 56 case epState == StateInitial || epState == StateBound: 57 case epState.connected() || epState.handshake(): 58 if !e.route.HasSaveRestoreCapability() { 59 if !e.route.HasDisconncetOkCapability() { 60 panic(&tcpip.ErrSaveRejection{ 61 Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort), 62 }) 63 } 64 e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) 65 e.mu.Unlock() 66 e.Close() 67 e.mu.Lock() 68 } 69 if !e.workerRunning { 70 // The endpoint must be in the accepted queue or has been just 71 // disconnected and closed. 72 break 73 } 74 fallthrough 75 case epState == StateListen || epState == StateConnecting: 76 e.drainSegmentLocked() 77 // Refresh epState, since drainSegmentLocked may have changed it. 78 epState = e.EndpointState() 79 if !epState.closed() { 80 if !e.workerRunning { 81 panic("endpoint has no worker running in listen, connecting, or connected state") 82 } 83 } 84 case epState.closed(): 85 for e.workerRunning { 86 e.mu.Unlock() 87 time.Sleep(100 * time.Millisecond) 88 e.mu.Lock() 89 } 90 if e.workerRunning { 91 panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.TransportEndpointInfo.ID)) 92 } 93 default: 94 panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) 95 } 96 97 if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { 98 panic("endpoint still has waiters upon save") 99 } 100 } 101 102 // saveEndpoints is invoked by stateify. 103 func (a *acceptQueue) saveEndpoints() []*endpoint { 104 acceptedEndpoints := make([]*endpoint, a.endpoints.Len()) 105 for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { 106 acceptedEndpoints[i] = e.Value.(*endpoint) 107 } 108 return acceptedEndpoints 109 } 110 111 // loadEndpoints is invoked by stateify. 112 func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) { 113 for _, ep := range acceptedEndpoints { 114 a.endpoints.PushBack(ep) 115 } 116 } 117 118 // saveState is invoked by stateify. 119 func (e *endpoint) saveState() EndpointState { 120 return e.EndpointState() 121 } 122 123 // Endpoint loading must be done in the following ordering by their state, to 124 // avoid dangling connecting w/o listening peer, and to avoid conflicts in port 125 // reservation. 126 var connectedLoading sync.WaitGroup 127 var listenLoading sync.WaitGroup 128 var connectingLoading sync.WaitGroup 129 130 // Bound endpoint loading happens last. 131 132 // loadState is invoked by stateify. 133 func (e *endpoint) loadState(epState EndpointState) { 134 // This is to ensure that the loading wait groups include all applicable 135 // endpoints before any asynchronous calls to the Wait() methods. 136 // For restore purposes we treat TimeWait like a connected endpoint. 137 if epState.connected() || epState == StateTimeWait { 138 connectedLoading.Add(1) 139 } 140 switch { 141 case epState == StateListen: 142 listenLoading.Add(1) 143 case epState.connecting(): 144 connectingLoading.Add(1) 145 } 146 // Directly update the state here rather than using e.setEndpointState 147 // as the endpoint is still being loaded and the stack reference is not 148 // yet initialized. 149 atomic.StoreUint32((*uint32)(&e.state), uint32(epState)) 150 } 151 152 // afterLoad is invoked by stateify. 153 func (e *endpoint) afterLoad() { 154 e.origEndpointState = e.state 155 // Restore the endpoint to InitialState as it will be moved to 156 // its origEndpointState during Resume. 157 e.state = uint32(StateInitial) 158 // Condition variables and mutexs are not S/R'ed so reinitialize 159 // acceptCond with e.acceptMu. 160 e.acceptCond = sync.NewCond(&e.acceptMu) 161 stack.StackFromEnv.RegisterRestoredEndpoint(e) 162 } 163 164 // Resume implements tcpip.ResumableEndpoint.Resume. 165 func (e *endpoint) Resume(s *stack.Stack) { 166 e.keepalive.timer.init(s.Clock(), &e.keepalive.waker) 167 if snd := e.snd; snd != nil { 168 snd.resendTimer.init(s.Clock(), &snd.resendWaker) 169 snd.reorderTimer.init(s.Clock(), &snd.reorderWaker) 170 snd.probeTimer.init(s.Clock(), &snd.probeWaker) 171 } 172 e.stack = s 173 e.protocol = protocolFromStack(s) 174 e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) 175 e.segmentQueue.thaw() 176 epState := EndpointState(e.origEndpointState) 177 switch epState { 178 case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: 179 var ss tcpip.TCPSendBufferSizeRangeOption 180 if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { 181 sendBufferSize := e.getSendBufferSize() 182 if sendBufferSize < ss.Min || sendBufferSize > ss.Max { 183 panic(fmt.Sprintf("endpoint sendBufferSize %d is outside the min and max allowed [%d, %d]", sendBufferSize, ss.Min, ss.Max)) 184 } 185 } 186 187 var rs tcpip.TCPReceiveBufferSizeRangeOption 188 if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { 189 if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) { 190 panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max)) 191 } 192 } 193 } 194 195 bind := func() { 196 addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}) 197 if err != nil { 198 panic("unable to parse BindAddr: " + err.String()) 199 } 200 portRes := ports.Reservation{ 201 Networks: e.effectiveNetProtos, 202 Transport: ProtocolNumber, 203 Addr: addr.Addr, 204 Port: addr.Port, 205 Flags: e.boundPortFlags, 206 BindToDevice: e.boundBindToDevice, 207 Dest: e.boundDest, 208 } 209 if ok := e.stack.ReserveTuple(portRes); !ok { 210 panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest)) 211 } 212 e.isPortReserved = true 213 214 // Mark endpoint as bound. 215 e.setEndpointState(StateBound) 216 } 217 218 switch { 219 case epState.connected(): 220 bind() 221 if len(e.connectingAddress) == 0 { 222 e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress 223 // This endpoint is accepted by netstack but not yet by 224 // the app. If the endpoint is IPv6 but the remote 225 // address is IPv4, we need to connect as IPv6 so that 226 // dual-stack mode can be properly activated. 227 if e.NetProto == header.IPv6ProtocolNumber && len(e.TransportEndpointInfo.ID.RemoteAddress) != header.IPv6AddressSize { 228 e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.TransportEndpointInfo.ID.RemoteAddress 229 } 230 } 231 // Reset the scoreboard to reinitialize the sack information as 232 // we do not restore SACK information. 233 e.scoreboard.Reset() 234 err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false, e.workerRunning) 235 if _, ok := err.(*tcpip.ErrConnectStarted); !ok { 236 panic("endpoint connecting failed: " + err.String()) 237 } 238 e.mu.Lock() 239 e.state = e.origEndpointState 240 closed := e.closed 241 e.mu.Unlock() 242 e.notifyProtocolGoroutine(notifyTickleWorker) 243 if epState == StateFinWait2 && closed { 244 // If the endpoint has been closed then make sure we notify so 245 // that the FIN_WAIT2 timer is started after a restore. 246 e.notifyProtocolGoroutine(notifyClose) 247 } 248 connectedLoading.Done() 249 case epState == StateListen: 250 tcpip.AsyncLoading.Add(1) 251 go func() { 252 connectedLoading.Wait() 253 bind() 254 e.acceptMu.Lock() 255 backlog := e.acceptQueue.capacity 256 e.acceptMu.Unlock() 257 if err := e.Listen(backlog); err != nil { 258 panic("endpoint listening failed: " + err.String()) 259 } 260 e.LockUser() 261 if e.shutdownFlags != 0 { 262 e.shutdownLocked(e.shutdownFlags) 263 } 264 e.UnlockUser() 265 listenLoading.Done() 266 tcpip.AsyncLoading.Done() 267 }() 268 case epState.connecting(): 269 tcpip.AsyncLoading.Add(1) 270 go func() { 271 connectedLoading.Wait() 272 listenLoading.Wait() 273 bind() 274 err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}) 275 if _, ok := err.(*tcpip.ErrConnectStarted); !ok { 276 panic("endpoint connecting failed: " + err.String()) 277 } 278 connectingLoading.Done() 279 tcpip.AsyncLoading.Done() 280 }() 281 case epState == StateBound: 282 tcpip.AsyncLoading.Add(1) 283 go func() { 284 connectedLoading.Wait() 285 listenLoading.Wait() 286 connectingLoading.Wait() 287 bind() 288 tcpip.AsyncLoading.Done() 289 }() 290 case epState == StateClose: 291 e.isPortReserved = false 292 e.state = uint32(StateClose) 293 e.stack.CompleteTransportEndpointCleanup(e) 294 tcpip.DeleteDanglingEndpoint(e) 295 case epState == StateError: 296 e.state = uint32(StateError) 297 e.stack.CompleteTransportEndpointCleanup(e) 298 tcpip.DeleteDanglingEndpoint(e) 299 } 300 }