github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/tcpip/transport/tcp/forwarder.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package tcp 16 17 import ( 18 "github.com/sagernet/gvisor/pkg/sync" 19 "github.com/sagernet/gvisor/pkg/tcpip" 20 "github.com/sagernet/gvisor/pkg/tcpip/header" 21 "github.com/sagernet/gvisor/pkg/tcpip/seqnum" 22 "github.com/sagernet/gvisor/pkg/tcpip/stack" 23 "github.com/sagernet/gvisor/pkg/waiter" 24 ) 25 26 // Forwarder is a connection request forwarder, which allows clients to decide 27 // what to do with a connection request, for example: ignore it, send a RST, or 28 // attempt to complete the 3-way handshake. 29 // 30 // The canonical way of using it is to pass the Forwarder.HandlePacket function 31 // to stack.SetTransportProtocolHandler. 32 type Forwarder struct { 33 stack *stack.Stack 34 35 maxInFlight int 36 handler func(*ForwarderRequest) 37 38 mu sync.Mutex 39 inFlight map[stack.TransportEndpointID]struct{} 40 listen *listenContext 41 } 42 43 // NewForwarder allocates and initializes a new forwarder with the given 44 // maximum number of in-flight connection attempts. Once the maximum is reached 45 // new incoming connection requests will be ignored. 46 // 47 // If rcvWnd is set to zero, the default buffer size is used instead. 48 func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder { 49 if rcvWnd == 0 { 50 rcvWnd = DefaultReceiveBufferSize 51 } 52 return &Forwarder{ 53 stack: s, 54 maxInFlight: maxInFlight, 55 handler: handler, 56 inFlight: make(map[stack.TransportEndpointID]struct{}), 57 listen: newListenContext(s, protocolFromStack(s), nil /* listenEP */, seqnum.Size(rcvWnd), true, 0), 58 } 59 } 60 61 // HandlePacket handles a packet if it is of interest to the forwarder (i.e., if 62 // it's a SYN packet), returning true if it's the case. Otherwise the packet 63 // is not handled and false is returned. 64 // 65 // This function is expected to be passed as an argument to the 66 // stack.SetTransportProtocolHandler function. 67 func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { 68 s, err := newIncomingSegment(id, f.stack.Clock(), pkt) 69 if err != nil { 70 return false 71 } 72 defer s.DecRef() 73 74 // We only care about well-formed SYN packets (not SYN-ACK) packets. 75 if !s.csumValid || !s.flags.Contains(header.TCPFlagSyn) || s.flags.Contains(header.TCPFlagAck) { 76 return false 77 } 78 79 opts := parseSynSegmentOptions(s) 80 81 f.mu.Lock() 82 defer f.mu.Unlock() 83 84 // We have an inflight request for this id, ignore this one for now. 85 if _, ok := f.inFlight[id]; ok { 86 return true 87 } 88 89 // Ignore the segment if we're beyond the limit. 90 if len(f.inFlight) >= f.maxInFlight { 91 f.stack.Stats().TCP.ForwardMaxInFlightDrop.Increment() 92 return true 93 } 94 95 // Launch a new goroutine to handle the request. 96 f.inFlight[id] = struct{}{} 97 s.IncRef() 98 go f.handler(&ForwarderRequest{ // S/R-SAFE: not used by Sentry. 99 forwarder: f, 100 segment: s, 101 synOptions: opts, 102 }) 103 104 return true 105 } 106 107 // ForwarderRequest represents a connection request received by the forwarder 108 // and passed to the client. Clients must eventually call Complete() on it, and 109 // may optionally create an endpoint to represent it via CreateEndpoint. 110 type ForwarderRequest struct { 111 mu sync.Mutex 112 forwarder *Forwarder 113 segment *segment 114 synOptions header.TCPSynOptions 115 } 116 117 // ID returns the 4-tuple (src address, src port, dst address, dst port) that 118 // represents the connection request. 119 func (r *ForwarderRequest) ID() stack.TransportEndpointID { 120 return r.segment.id 121 } 122 123 // Complete completes the request, and optionally sends a RST segment back to the 124 // sender. 125 func (r *ForwarderRequest) Complete(sendReset bool) { 126 r.mu.Lock() 127 defer r.mu.Unlock() 128 129 if r.segment == nil { 130 panic("Completing already completed forwarder request") 131 } 132 133 // Remove request from the forwarder. 134 r.forwarder.mu.Lock() 135 delete(r.forwarder.inFlight, r.segment.id) 136 r.forwarder.mu.Unlock() 137 138 if sendReset { 139 replyWithReset(r.forwarder.stack, r.segment, stack.DefaultTOS, tcpip.UseDefaultIPv4TTL, tcpip.UseDefaultIPv6HopLimit) 140 } 141 142 // Release all resources. 143 r.segment.DecRef() 144 r.segment = nil 145 r.forwarder = nil 146 } 147 148 // CreateEndpoint creates a TCP endpoint for the connection request, performing 149 // the 3-way handshake in the process. 150 func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { 151 r.mu.Lock() 152 defer r.mu.Unlock() 153 154 if r.segment == nil { 155 return nil, &tcpip.ErrInvalidEndpointState{} 156 } 157 158 f := r.forwarder 159 ep, err := f.listen.performHandshake(r.segment, header.TCPSynOptions{ 160 MSS: r.synOptions.MSS, 161 WS: r.synOptions.WS, 162 TS: r.synOptions.TS, 163 TSVal: r.synOptions.TSVal, 164 TSEcr: r.synOptions.TSEcr, 165 SACKPermitted: r.synOptions.SACKPermitted, 166 }, queue, nil) 167 if err != nil { 168 return nil, err 169 } 170 171 return ep, nil 172 }