github.com/flowerwrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/transport/tcp/forwarder.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp
    16  
    17  import (
    18  	"sync"
    19  
    20  	"github.com/FlowerWrong/netstack/tcpip"
    21  	"github.com/FlowerWrong/netstack/tcpip/buffer"
    22  	"github.com/FlowerWrong/netstack/tcpip/header"
    23  	"github.com/FlowerWrong/netstack/tcpip/seqnum"
    24  	"github.com/FlowerWrong/netstack/tcpip/stack"
    25  	"github.com/FlowerWrong/netstack/waiter"
    26  )
    27  
    28  // Forwarder is a connection request forwarder, which allows clients to decide
    29  // what to do with a connection request, for example: ignore it, send a RST, or
    30  // attempt to complete the 3-way handshake.
    31  //
    32  // The canonical way of using it is to pass the Forwarder.HandlePacket function
    33  // to stack.SetTransportProtocolHandler.
    34  type Forwarder struct {
    35  	maxInFlight int
    36  	handler     func(*ForwarderRequest)
    37  
    38  	mu       sync.Mutex
    39  	inFlight map[stack.TransportEndpointID]struct{}
    40  	listen   *listenContext
    41  }
    42  
    43  // NewForwarder allocates and initializes a new forwarder with the given
    44  // maximum number of in-flight connection attempts. Once the maximum is reached
    45  // new incoming connection requests will be ignored.
    46  //
    47  // If rcvWnd is set to zero, the default buffer size is used instead.
    48  func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder {
    49  	if rcvWnd == 0 {
    50  		rcvWnd = DefaultReceiveBufferSize
    51  	}
    52  	return &Forwarder{
    53  		maxInFlight: maxInFlight,
    54  		handler:     handler,
    55  		inFlight:    make(map[stack.TransportEndpointID]struct{}),
    56  		listen:      newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
    57  	}
    58  }
    59  
    60  // HandlePacket handles a packet if it is of interest to the forwarder (i.e., if
    61  // it's a SYN packet), returning true if it's the case. Otherwise the packet
    62  // is not handled and false is returned.
    63  //
    64  // This function is expected to be passed as an argument to the
    65  // stack.SetTransportProtocolHandler function.
    66  func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
    67  	s := newSegment(r, id, vv)
    68  	defer s.decRef()
    69  
    70  	// We only care about well-formed SYN packets.
    71  	if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn {
    72  		return false
    73  	}
    74  
    75  	opts := parseSynSegmentOptions(s)
    76  
    77  	f.mu.Lock()
    78  	defer f.mu.Unlock()
    79  
    80  	// We have an inflight request for this id, ignore this one for now.
    81  	if _, ok := f.inFlight[id]; ok {
    82  		return true
    83  	}
    84  
    85  	// Ignore the segment if we're beyond the limit.
    86  	if len(f.inFlight) >= f.maxInFlight {
    87  		return true
    88  	}
    89  
    90  	// Launch a new goroutine to handle the request.
    91  	f.inFlight[id] = struct{}{}
    92  	s.incRef()
    93  	go f.handler(&ForwarderRequest{
    94  		forwarder:  f,
    95  		segment:    s,
    96  		synOptions: opts,
    97  	})
    98  
    99  	return true
   100  }
   101  
   102  // ForwarderRequest represents a connection request received by the forwarder
   103  // and passed to the client. Clients must eventually call Complete() on it, and
   104  // may optionally create an endpoint to represent it via CreateEndpoint.
   105  type ForwarderRequest struct {
   106  	mu         sync.Mutex
   107  	forwarder  *Forwarder
   108  	segment    *segment
   109  	synOptions header.TCPSynOptions
   110  }
   111  
   112  // ID returns the 4-tuple (src address, src port, dst address, dst port) that
   113  // represents the connection request.
   114  func (r *ForwarderRequest) ID() stack.TransportEndpointID {
   115  	return r.segment.id
   116  }
   117  
   118  // Complete completes the request, and optionally sends a RST segment back to the
   119  // sender.
   120  func (r *ForwarderRequest) Complete(sendReset bool) {
   121  	r.mu.Lock()
   122  	defer r.mu.Unlock()
   123  
   124  	if r.segment == nil {
   125  		panic("Completing already completed forwarder request")
   126  	}
   127  
   128  	// Remove request from the forwarder.
   129  	r.forwarder.mu.Lock()
   130  	delete(r.forwarder.inFlight, r.segment.id)
   131  	r.forwarder.mu.Unlock()
   132  
   133  	// If the caller requested, send a reset.
   134  	if sendReset {
   135  		replyWithReset(r.segment)
   136  	}
   137  
   138  	// Release all resources.
   139  	r.segment.decRef()
   140  	r.segment = nil
   141  	r.forwarder = nil
   142  }
   143  
   144  // CreateEndpoint creates a TCP endpoint for the connection request, performing
   145  // the 3-way handshake in the process.
   146  func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
   147  	r.mu.Lock()
   148  	defer r.mu.Unlock()
   149  
   150  	if r.segment == nil {
   151  		return nil, tcpip.ErrInvalidEndpointState
   152  	}
   153  
   154  	f := r.forwarder
   155  	ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
   156  		MSS:           r.synOptions.MSS,
   157  		WS:            r.synOptions.WS,
   158  		TS:            r.synOptions.TS,
   159  		TSVal:         r.synOptions.TSVal,
   160  		TSEcr:         r.synOptions.TSEcr,
   161  		SACKPermitted: r.synOptions.SACKPermitted,
   162  	})
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  
   167  	// Start the protocol goroutine.
   168  	ep.startAcceptedLoop(queue)
   169  
   170  	return ep, nil
   171  }