github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/transport/tcp/endpoint_state.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package tcp 16 17 import ( 18 "fmt" 19 20 "github.com/nicocha30/gvisor-ligolo/pkg/atomicbitops" 21 "github.com/nicocha30/gvisor-ligolo/pkg/sync" 22 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip" 23 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header" 24 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/ports" 25 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/stack" 26 ) 27 28 // beforeSave is invoked by stateify. 29 func (e *endpoint) beforeSave() { 30 // Stop incoming packets. 31 e.segmentQueue.freeze() 32 33 e.mu.Lock() 34 defer e.mu.Unlock() 35 36 epState := e.EndpointState() 37 switch { 38 case epState == StateInitial || epState == StateBound: 39 case epState.connected() || epState.handshake(): 40 if !e.route.HasSaveRestoreCapability() { 41 if !e.route.HasDisconnectOkCapability() { 42 panic(&tcpip.ErrSaveRejection{ 43 Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort), 44 }) 45 } 46 e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) 47 e.mu.Unlock() 48 e.Close() 49 e.mu.Lock() 50 } 51 fallthrough 52 case epState == StateListen: 53 // Nothing to do. 54 case epState.closed(): 55 // Nothing to do. 56 default: 57 panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) 58 } 59 } 60 61 // saveEndpoints is invoked by stateify. 62 func (a *acceptQueue) saveEndpoints() []*endpoint { 63 acceptedEndpoints := make([]*endpoint, a.endpoints.Len()) 64 for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { 65 acceptedEndpoints[i] = e.Value.(*endpoint) 66 } 67 return acceptedEndpoints 68 } 69 70 // loadEndpoints is invoked by stateify. 71 func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) { 72 for _, ep := range acceptedEndpoints { 73 a.endpoints.PushBack(ep) 74 } 75 } 76 77 // saveState is invoked by stateify. 78 func (e *endpoint) saveState() EndpointState { 79 return e.EndpointState() 80 } 81 82 // Endpoint loading must be done in the following ordering by their state, to 83 // avoid dangling connecting w/o listening peer, and to avoid conflicts in port 84 // reservation. 85 var connectedLoading sync.WaitGroup 86 var listenLoading sync.WaitGroup 87 var connectingLoading sync.WaitGroup 88 89 // Bound endpoint loading happens last. 90 91 // loadState is invoked by stateify. 92 func (e *endpoint) loadState(epState EndpointState) { 93 // This is to ensure that the loading wait groups include all applicable 94 // endpoints before any asynchronous calls to the Wait() methods. 95 // For restore purposes we treat TimeWait like a connected endpoint. 96 if epState.connected() || epState == StateTimeWait { 97 connectedLoading.Add(1) 98 } 99 switch { 100 case epState == StateListen: 101 listenLoading.Add(1) 102 case epState.connecting(): 103 connectingLoading.Add(1) 104 } 105 // Directly update the state here rather than using e.setEndpointState 106 // as the endpoint is still being loaded and the stack reference is not 107 // yet initialized. 108 e.state.Store(uint32(epState)) 109 } 110 111 // afterLoad is invoked by stateify. 112 func (e *endpoint) afterLoad() { 113 // RacyLoad() can be used because we are initializing e. 114 e.origEndpointState = e.state.RacyLoad() 115 // Restore the endpoint to InitialState as it will be moved to 116 // its origEndpointState during Resume. 117 e.state = atomicbitops.FromUint32(uint32(StateInitial)) 118 stack.StackFromEnv.RegisterRestoredEndpoint(e) 119 } 120 121 // Resume implements tcpip.ResumableEndpoint.Resume. 122 func (e *endpoint) Resume(s *stack.Stack) { 123 if !e.EndpointState().closed() { 124 e.keepalive.timer.init(s.Clock(), maybeFailTimerHandler(e, e.keepaliveTimerExpired)) 125 } 126 if snd := e.snd; snd != nil { 127 snd.resendTimer.init(s.Clock(), maybeFailTimerHandler(e, e.snd.retransmitTimerExpired)) 128 snd.reorderTimer.init(s.Clock(), timerHandler(e, e.snd.rc.reorderTimerExpired)) 129 snd.probeTimer.init(s.Clock(), timerHandler(e, e.snd.probeTimerExpired)) 130 } 131 e.stack = s 132 e.protocol = protocolFromStack(s) 133 e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) 134 e.segmentQueue.thaw() 135 136 bind := func() { 137 e.mu.Lock() 138 defer e.mu.Unlock() 139 addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}) 140 if err != nil { 141 panic("unable to parse BindAddr: " + err.String()) 142 } 143 portRes := ports.Reservation{ 144 Networks: e.effectiveNetProtos, 145 Transport: ProtocolNumber, 146 Addr: addr.Addr, 147 Port: addr.Port, 148 Flags: e.boundPortFlags, 149 BindToDevice: e.boundBindToDevice, 150 Dest: e.boundDest, 151 } 152 if ok := e.stack.ReserveTuple(portRes); !ok { 153 panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest)) 154 } 155 e.isPortReserved = true 156 157 // Mark endpoint as bound. 158 e.setEndpointState(StateBound) 159 } 160 161 epState := EndpointState(e.origEndpointState) 162 switch { 163 case epState.connected(): 164 bind() 165 if e.connectingAddress.BitLen() == 0 { 166 e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress 167 // This endpoint is accepted by netstack but not yet by 168 // the app. If the endpoint is IPv6 but the remote 169 // address is IPv4, we need to connect as IPv6 so that 170 // dual-stack mode can be properly activated. 171 if e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.RemoteAddress.BitLen() != header.IPv6AddressSizeBits { 172 e.connectingAddress = tcpip.AddrFrom16Slice(append( 173 []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff}, 174 e.TransportEndpointInfo.ID.RemoteAddress.AsSlice()..., 175 )) 176 } 177 } 178 // Reset the scoreboard to reinitialize the sack information as 179 // we do not restore SACK information. 180 e.scoreboard.Reset() 181 e.mu.Lock() 182 err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false /* handshake */) 183 if _, ok := err.(*tcpip.ErrConnectStarted); !ok { 184 panic("endpoint connecting failed: " + err.String()) 185 } 186 e.state.Store(e.origEndpointState) 187 // For FIN-WAIT-2 and TIME-WAIT we need to start the appropriate timers so 188 // that the socket is closed correctly. 189 switch epState { 190 case StateFinWait2: 191 e.finWait2Timer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, e.finWait2TimerExpired) 192 case StateTimeWait: 193 e.timeWaitTimer = e.stack.Clock().AfterFunc(e.getTimeWaitDuration(), e.timeWaitTimerExpired) 194 } 195 196 e.mu.Unlock() 197 connectedLoading.Done() 198 case epState == StateListen: 199 tcpip.AsyncLoading.Add(1) 200 go func() { 201 connectedLoading.Wait() 202 bind() 203 e.acceptMu.Lock() 204 backlog := e.acceptQueue.capacity 205 e.acceptMu.Unlock() 206 if err := e.Listen(backlog); err != nil { 207 panic("endpoint listening failed: " + err.String()) 208 } 209 e.LockUser() 210 if e.shutdownFlags != 0 { 211 e.shutdownLocked(e.shutdownFlags) 212 } 213 e.UnlockUser() 214 listenLoading.Done() 215 tcpip.AsyncLoading.Done() 216 }() 217 case epState == StateConnecting: 218 // Initial SYN hasn't been sent yet so initiate a connect. 219 tcpip.AsyncLoading.Add(1) 220 go func() { 221 connectedLoading.Wait() 222 listenLoading.Wait() 223 bind() 224 err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}) 225 if _, ok := err.(*tcpip.ErrConnectStarted); !ok { 226 panic("endpoint connecting failed: " + err.String()) 227 } 228 connectingLoading.Done() 229 tcpip.AsyncLoading.Done() 230 }() 231 case epState == StateSynSent || epState == StateSynRecv: 232 connectedLoading.Wait() 233 listenLoading.Wait() 234 // Initial SYN has been sent/received so we should bind the 235 // ports start the retransmit timer for the SYNs and let it 236 // naturally complete the connection. 237 bind() 238 e.mu.Lock() 239 defer e.mu.Unlock() 240 e.setEndpointState(epState) 241 r, err := e.stack.FindRoute(e.boundNICID, e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.RemoteAddress, e.effectiveNetProtos[0], false /* multicastLoop */) 242 if err != nil { 243 panic(fmt.Sprintf("FindRoute failed when restoring endpoint w/ ID: %+v", e.ID)) 244 } 245 e.route = r 246 timer, err := newBackoffTimer(e.stack.Clock(), InitialRTO, MaxRTO, maybeFailTimerHandler(e, e.h.retransmitHandlerLocked)) 247 if err != nil { 248 panic(fmt.Sprintf("newBackOffTimer(_, %s, %s, _) failed: %s", InitialRTO, MaxRTO, err)) 249 } 250 e.h.retransmitTimer = timer 251 connectingLoading.Done() 252 case epState == StateBound: 253 tcpip.AsyncLoading.Add(1) 254 go func() { 255 connectedLoading.Wait() 256 listenLoading.Wait() 257 connectingLoading.Wait() 258 bind() 259 tcpip.AsyncLoading.Done() 260 }() 261 case epState == StateClose: 262 e.isPortReserved = false 263 e.state.Store(uint32(StateClose)) 264 e.stack.CompleteTransportEndpointCleanup(e) 265 tcpip.DeleteDanglingEndpoint(e) 266 case epState == StateError: 267 e.state.Store(uint32(StateError)) 268 e.stack.CompleteTransportEndpointCleanup(e) 269 tcpip.DeleteDanglingEndpoint(e) 270 } 271 }