gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/udp/forwarder.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package udp
    16  
    17  import (
    18  	"gvisor.dev/gvisor/pkg/tcpip"
    19  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    20  	"gvisor.dev/gvisor/pkg/waiter"
    21  )
    22  
    23  // Forwarder is a session request forwarder, which allows clients to decide
    24  // what to do with a session request, for example: ignore it, or process it.
    25  //
    26  // The canonical way of using it is to pass the Forwarder.HandlePacket function
    27  // to stack.SetTransportProtocolHandler.
    28  type Forwarder struct {
    29  	handler func(*ForwarderRequest)
    30  
    31  	stack *stack.Stack
    32  }
    33  
    34  // NewForwarder allocates and initializes a new forwarder.
    35  func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
    36  	return &Forwarder{
    37  		stack:   s,
    38  		handler: handler,
    39  	}
    40  }
    41  
    42  // HandlePacket handles all packets.
    43  //
    44  // This function is expected to be passed as an argument to the
    45  // stack.SetTransportProtocolHandler function.
    46  func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
    47  	f.handler(&ForwarderRequest{
    48  		stack: f.stack,
    49  		id:    id,
    50  		pkt:   pkt.IncRef(),
    51  	})
    52  
    53  	return true
    54  }
    55  
    56  // ForwarderRequest represents a session request received by the forwarder and
    57  // passed to the client. Clients may optionally create an endpoint to represent
    58  // it via CreateEndpoint.
    59  type ForwarderRequest struct {
    60  	stack *stack.Stack
    61  	id    stack.TransportEndpointID
    62  	pkt   *stack.PacketBuffer
    63  }
    64  
    65  // ID returns the 4-tuple (src address, src port, dst address, dst port) that
    66  // represents the session request.
    67  func (r *ForwarderRequest) ID() stack.TransportEndpointID {
    68  	return r.id
    69  }
    70  
    71  // CreateEndpoint creates a connected UDP endpoint for the session request.
    72  func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
    73  	ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
    74  	ep.mu.Lock()
    75  	defer ep.mu.Unlock()
    76  
    77  	netHdr := r.pkt.Network()
    78  	if err := ep.net.Bind(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.DestinationAddress(), Port: r.id.LocalPort}); err != nil {
    79  		return nil, err
    80  	}
    81  
    82  	if err := ep.net.Connect(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.SourceAddress(), Port: r.id.RemotePort}); err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
    87  		ep.Close()
    88  		return nil, err
    89  	}
    90  
    91  	ep.localPort = r.id.LocalPort
    92  	ep.remotePort = r.id.RemotePort
    93  	ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}
    94  	ep.boundPortFlags = ep.portFlags
    95  
    96  	ep.rcvMu.Lock()
    97  	ep.rcvReady = true
    98  	ep.rcvMu.Unlock()
    99  
   100  	ep.HandlePacket(r.id, r.pkt)
   101  
   102  	return ep, nil
   103  }