github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/transport/tcp/endpoint_state.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package tcp 16 17 import ( 18 "fmt" 19 "sync/atomic" 20 "time" 21 22 "github.com/SagerNet/gvisor/pkg/sync" 23 "github.com/SagerNet/gvisor/pkg/tcpip" 24 "github.com/SagerNet/gvisor/pkg/tcpip/header" 25 "github.com/SagerNet/gvisor/pkg/tcpip/ports" 26 "github.com/SagerNet/gvisor/pkg/tcpip/stack" 27 ) 28 29 // +checklocks:e.mu 30 func (e *endpoint) drainSegmentLocked() { 31 // Drain only up to once. 32 if e.drainDone != nil { 33 return 34 } 35 36 e.drainDone = make(chan struct{}) 37 e.undrain = make(chan struct{}) 38 e.mu.Unlock() 39 40 e.notifyProtocolGoroutine(notifyDrain) 41 <-e.drainDone 42 43 e.mu.Lock() 44 } 45 46 // beforeSave is invoked by stateify. 47 func (e *endpoint) beforeSave() { 48 // Stop incoming packets. 49 e.segmentQueue.freeze() 50 51 e.mu.Lock() 52 defer e.mu.Unlock() 53 54 epState := e.EndpointState() 55 switch { 56 case epState == StateInitial || epState == StateBound: 57 case epState.connected() || epState.handshake(): 58 if !e.route.HasSaveRestoreCapability() { 59 if !e.route.HasDisconncetOkCapability() { 60 panic(&tcpip.ErrSaveRejection{ 61 Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort), 62 }) 63 } 64 e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) 65 e.mu.Unlock() 66 e.Close() 67 e.mu.Lock() 68 } 69 if !e.workerRunning { 70 // The endpoint must be in the accepted queue or has been just 71 // disconnected and closed. 72 break 73 } 74 fallthrough 75 case epState == StateListen || epState == StateConnecting: 76 e.drainSegmentLocked() 77 // Refresh epState, since drainSegmentLocked may have changed it. 78 epState = e.EndpointState() 79 if !epState.closed() { 80 if !e.workerRunning { 81 panic("endpoint has no worker running in listen, connecting, or connected state") 82 } 83 } 84 case epState.closed(): 85 for e.workerRunning { 86 e.mu.Unlock() 87 time.Sleep(100 * time.Millisecond) 88 e.mu.Lock() 89 } 90 if e.workerRunning { 91 panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.TransportEndpointInfo.ID)) 92 } 93 default: 94 panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) 95 } 96 97 if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { 98 panic("endpoint still has waiters upon save") 99 } 100 } 101 102 // saveEndpoints is invoked by stateify. 103 func (a *accepted) saveEndpoints() []*endpoint { 104 acceptedEndpoints := make([]*endpoint, a.endpoints.Len()) 105 for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { 106 acceptedEndpoints[i] = e.Value.(*endpoint) 107 } 108 return acceptedEndpoints 109 } 110 111 // loadEndpoints is invoked by stateify. 112 func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) { 113 for _, ep := range acceptedEndpoints { 114 a.endpoints.PushBack(ep) 115 } 116 } 117 118 // saveState is invoked by stateify. 119 func (e *endpoint) saveState() EndpointState { 120 return e.EndpointState() 121 } 122 123 // Endpoint loading must be done in the following ordering by their state, to 124 // avoid dangling connecting w/o listening peer, and to avoid conflicts in port 125 // reservation. 126 var connectedLoading sync.WaitGroup 127 var listenLoading sync.WaitGroup 128 var connectingLoading sync.WaitGroup 129 130 // Bound endpoint loading happens last. 131 132 // loadState is invoked by stateify. 133 func (e *endpoint) loadState(epState EndpointState) { 134 // This is to ensure that the loading wait groups include all applicable 135 // endpoints before any asynchronous calls to the Wait() methods. 136 // For restore purposes we treat TimeWait like a connected endpoint. 137 if epState.connected() || epState == StateTimeWait { 138 connectedLoading.Add(1) 139 } 140 switch { 141 case epState == StateListen: 142 listenLoading.Add(1) 143 case epState.connecting(): 144 connectingLoading.Add(1) 145 } 146 // Directly update the state here rather than using e.setEndpointState 147 // as the endpoint is still being loaded and the stack reference is not 148 // yet initialized. 149 atomic.StoreUint32((*uint32)(&e.state), uint32(epState)) 150 } 151 152 // afterLoad is invoked by stateify. 153 func (e *endpoint) afterLoad() { 154 e.origEndpointState = e.state 155 // Restore the endpoint to InitialState as it will be moved to 156 // its origEndpointState during Resume. 157 e.state = uint32(StateInitial) 158 // Condition variables and mutexs are not S/R'ed so reinitialize 159 // acceptCond with e.acceptMu. 160 e.acceptCond = sync.NewCond(&e.acceptMu) 161 stack.StackFromEnv.RegisterRestoredEndpoint(e) 162 } 163 164 // Resume implements tcpip.ResumableEndpoint.Resume. 165 func (e *endpoint) Resume(s *stack.Stack) { 166 e.keepalive.timer.init(s.Clock(), &e.keepalive.waker) 167 if snd := e.snd; snd != nil { 168 snd.resendTimer.init(s.Clock(), &snd.resendWaker) 169 snd.reorderTimer.init(s.Clock(), &snd.reorderWaker) 170 snd.probeTimer.init(s.Clock(), &snd.probeWaker) 171 } 172 e.stack = s 173 e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) 174 e.segmentQueue.thaw() 175 epState := EndpointState(e.origEndpointState) 176 switch epState { 177 case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: 178 var ss tcpip.TCPSendBufferSizeRangeOption 179 if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { 180 sendBufferSize := e.getSendBufferSize() 181 if sendBufferSize < ss.Min || sendBufferSize > ss.Max { 182 panic(fmt.Sprintf("endpoint sendBufferSize %d is outside the min and max allowed [%d, %d]", sendBufferSize, ss.Min, ss.Max)) 183 } 184 } 185 186 var rs tcpip.TCPReceiveBufferSizeRangeOption 187 if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { 188 if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) { 189 panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max)) 190 } 191 } 192 } 193 194 bind := func() { 195 addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}) 196 if err != nil { 197 panic("unable to parse BindAddr: " + err.String()) 198 } 199 portRes := ports.Reservation{ 200 Networks: e.effectiveNetProtos, 201 Transport: ProtocolNumber, 202 Addr: addr.Addr, 203 Port: addr.Port, 204 Flags: e.boundPortFlags, 205 BindToDevice: e.boundBindToDevice, 206 Dest: e.boundDest, 207 } 208 if ok := e.stack.ReserveTuple(portRes); !ok { 209 panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest)) 210 } 211 e.isPortReserved = true 212 213 // Mark endpoint as bound. 214 e.setEndpointState(StateBound) 215 } 216 217 switch { 218 case epState.connected(): 219 bind() 220 if len(e.connectingAddress) == 0 { 221 e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress 222 // This endpoint is accepted by netstack but not yet by 223 // the app. If the endpoint is IPv6 but the remote 224 // address is IPv4, we need to connect as IPv6 so that 225 // dual-stack mode can be properly activated. 226 if e.NetProto == header.IPv6ProtocolNumber && len(e.TransportEndpointInfo.ID.RemoteAddress) != header.IPv6AddressSize { 227 e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.TransportEndpointInfo.ID.RemoteAddress 228 } 229 } 230 // Reset the scoreboard to reinitialize the sack information as 231 // we do not restore SACK information. 232 e.scoreboard.Reset() 233 err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false, e.workerRunning) 234 if _, ok := err.(*tcpip.ErrConnectStarted); !ok { 235 panic("endpoint connecting failed: " + err.String()) 236 } 237 e.mu.Lock() 238 e.state = e.origEndpointState 239 closed := e.closed 240 e.mu.Unlock() 241 e.notifyProtocolGoroutine(notifyTickleWorker) 242 if epState == StateFinWait2 && closed { 243 // If the endpoint has been closed then make sure we notify so 244 // that the FIN_WAIT2 timer is started after a restore. 245 e.notifyProtocolGoroutine(notifyClose) 246 } 247 connectedLoading.Done() 248 case epState == StateListen: 249 tcpip.AsyncLoading.Add(1) 250 go func() { 251 connectedLoading.Wait() 252 bind() 253 backlog := e.accepted.cap 254 if err := e.Listen(backlog); err != nil { 255 panic("endpoint listening failed: " + err.String()) 256 } 257 e.LockUser() 258 if e.shutdownFlags != 0 { 259 e.shutdownLocked(e.shutdownFlags) 260 } 261 e.UnlockUser() 262 listenLoading.Done() 263 tcpip.AsyncLoading.Done() 264 }() 265 case epState.connecting(): 266 tcpip.AsyncLoading.Add(1) 267 go func() { 268 connectedLoading.Wait() 269 listenLoading.Wait() 270 bind() 271 err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}) 272 if _, ok := err.(*tcpip.ErrConnectStarted); !ok { 273 panic("endpoint connecting failed: " + err.String()) 274 } 275 connectingLoading.Done() 276 tcpip.AsyncLoading.Done() 277 }() 278 case epState == StateBound: 279 tcpip.AsyncLoading.Add(1) 280 go func() { 281 connectedLoading.Wait() 282 listenLoading.Wait() 283 connectingLoading.Wait() 284 bind() 285 tcpip.AsyncLoading.Done() 286 }() 287 case epState == StateClose: 288 e.isPortReserved = false 289 e.state = uint32(StateClose) 290 e.stack.CompleteTransportEndpointCleanup(e) 291 tcpip.DeleteDanglingEndpoint(e) 292 case epState == StateError: 293 e.state = uint32(StateError) 294 e.stack.CompleteTransportEndpointCleanup(e) 295 tcpip.DeleteDanglingEndpoint(e) 296 } 297 }