github.com/google/netstack@v0.0.0-20191123085552-55fcc16cd0eb/tcpip/transport/tcp/forwarder.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package tcp 16 17 import ( 18 "sync" 19 20 "github.com/google/netstack/tcpip" 21 "github.com/google/netstack/tcpip/header" 22 "github.com/google/netstack/tcpip/seqnum" 23 "github.com/google/netstack/tcpip/stack" 24 "github.com/google/netstack/waiter" 25 ) 26 27 // Forwarder is a connection request forwarder, which allows clients to decide 28 // what to do with a connection request, for example: ignore it, send a RST, or 29 // attempt to complete the 3-way handshake. 30 // 31 // The canonical way of using it is to pass the Forwarder.HandlePacket function 32 // to stack.SetTransportProtocolHandler. 33 type Forwarder struct { 34 maxInFlight int 35 handler func(*ForwarderRequest) 36 37 mu sync.Mutex 38 inFlight map[stack.TransportEndpointID]struct{} 39 listen *listenContext 40 } 41 42 // NewForwarder allocates and initializes a new forwarder with the given 43 // maximum number of in-flight connection attempts. Once the maximum is reached 44 // new incoming connection requests will be ignored. 45 // 46 // If rcvWnd is set to zero, the default buffer size is used instead. 47 func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder { 48 if rcvWnd == 0 { 49 rcvWnd = DefaultReceiveBufferSize 50 } 51 return &Forwarder{ 52 maxInFlight: maxInFlight, 53 handler: handler, 54 inFlight: make(map[stack.TransportEndpointID]struct{}), 55 listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0), 56 } 57 } 58 59 // HandlePacket handles a packet if it is of interest to the forwarder (i.e., if 60 // it's a SYN packet), returning true if it's the case. Otherwise the packet 61 // is not handled and false is returned. 62 // 63 // This function is expected to be passed as an argument to the 64 // stack.SetTransportProtocolHandler function. 65 func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { 66 s := newSegment(r, id, pkt) 67 defer s.decRef() 68 69 // We only care about well-formed SYN packets. 70 if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn { 71 return false 72 } 73 74 opts := parseSynSegmentOptions(s) 75 76 f.mu.Lock() 77 defer f.mu.Unlock() 78 79 // We have an inflight request for this id, ignore this one for now. 80 if _, ok := f.inFlight[id]; ok { 81 return true 82 } 83 84 // Ignore the segment if we're beyond the limit. 85 if len(f.inFlight) >= f.maxInFlight { 86 return true 87 } 88 89 // Launch a new goroutine to handle the request. 90 f.inFlight[id] = struct{}{} 91 s.incRef() 92 go f.handler(&ForwarderRequest{ 93 forwarder: f, 94 segment: s, 95 synOptions: opts, 96 }) 97 98 return true 99 } 100 101 // ForwarderRequest represents a connection request received by the forwarder 102 // and passed to the client. Clients must eventually call Complete() on it, and 103 // may optionally create an endpoint to represent it via CreateEndpoint. 104 type ForwarderRequest struct { 105 mu sync.Mutex 106 forwarder *Forwarder 107 segment *segment 108 synOptions header.TCPSynOptions 109 } 110 111 // ID returns the 4-tuple (src address, src port, dst address, dst port) that 112 // represents the connection request. 113 func (r *ForwarderRequest) ID() stack.TransportEndpointID { 114 return r.segment.id 115 } 116 117 // Complete completes the request, and optionally sends a RST segment back to the 118 // sender. 119 func (r *ForwarderRequest) Complete(sendReset bool) { 120 r.mu.Lock() 121 defer r.mu.Unlock() 122 123 if r.segment == nil { 124 panic("Completing already completed forwarder request") 125 } 126 127 // Remove request from the forwarder. 128 r.forwarder.mu.Lock() 129 delete(r.forwarder.inFlight, r.segment.id) 130 r.forwarder.mu.Unlock() 131 132 // If the caller requested, send a reset. 133 if sendReset { 134 replyWithReset(r.segment) 135 } 136 137 // Release all resources. 138 r.segment.decRef() 139 r.segment = nil 140 r.forwarder = nil 141 } 142 143 // CreateEndpoint creates a TCP endpoint for the connection request, performing 144 // the 3-way handshake in the process. 145 func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { 146 r.mu.Lock() 147 defer r.mu.Unlock() 148 149 if r.segment == nil { 150 return nil, tcpip.ErrInvalidEndpointState 151 } 152 153 f := r.forwarder 154 ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{ 155 MSS: r.synOptions.MSS, 156 WS: r.synOptions.WS, 157 TS: r.synOptions.TS, 158 TSVal: r.synOptions.TSVal, 159 TSEcr: r.synOptions.TSEcr, 160 SACKPermitted: r.synOptions.SACKPermitted, 161 }) 162 if err != nil { 163 return nil, err 164 } 165 166 // Start the protocol goroutine. 167 ep.startAcceptedLoop(queue) 168 169 return ep, nil 170 }