github.com/ttpreport/gvisor-ligolo@v0.0.0-20240123134145-a858404967ba/pkg/tcpip/transport/udp/forwarder.go (about) 1 // Copyright 2019 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package udp 16 17 import ( 18 "github.com/ttpreport/gvisor-ligolo/pkg/tcpip" 19 "github.com/ttpreport/gvisor-ligolo/pkg/tcpip/stack" 20 "github.com/ttpreport/gvisor-ligolo/pkg/waiter" 21 ) 22 23 // Forwarder is a session request forwarder, which allows clients to decide 24 // what to do with a session request, for example: ignore it, or process it. 25 // 26 // The canonical way of using it is to pass the Forwarder.HandlePacket function 27 // to stack.SetTransportProtocolHandler. 28 type Forwarder struct { 29 handler func(*ForwarderRequest) 30 31 stack *stack.Stack 32 } 33 34 // NewForwarder allocates and initializes a new forwarder. 35 func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { 36 return &Forwarder{ 37 stack: s, 38 handler: handler, 39 } 40 } 41 42 // HandlePacket handles all packets. 43 // 44 // This function is expected to be passed as an argument to the 45 // stack.SetTransportProtocolHandler function. 46 func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { 47 f.handler(&ForwarderRequest{ 48 stack: f.stack, 49 id: id, 50 pkt: pkt.IncRef(), 51 }) 52 53 return true 54 } 55 56 // ForwarderRequest represents a session request received by the forwarder and 57 // passed to the client. Clients may optionally create an endpoint to represent 58 // it via CreateEndpoint. 59 type ForwarderRequest struct { 60 stack *stack.Stack 61 id stack.TransportEndpointID 62 pkt stack.PacketBufferPtr 63 } 64 65 // ID returns the 4-tuple (src address, src port, dst address, dst port) that 66 // represents the session request. 67 func (r *ForwarderRequest) ID() stack.TransportEndpointID { 68 return r.id 69 } 70 71 // CreateEndpoint creates a connected UDP endpoint for the session request. 72 func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { 73 ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) 74 ep.mu.Lock() 75 defer ep.mu.Unlock() 76 77 netHdr := r.pkt.Network() 78 if err := ep.net.Bind(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.DestinationAddress(), Port: r.id.LocalPort}); err != nil { 79 return nil, err 80 } 81 82 if err := ep.net.Connect(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.SourceAddress(), Port: r.id.RemotePort}); err != nil { 83 return nil, err 84 } 85 86 if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { 87 ep.Close() 88 return nil, err 89 } 90 91 ep.localPort = r.id.LocalPort 92 ep.remotePort = r.id.RemotePort 93 ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber} 94 ep.boundPortFlags = ep.portFlags 95 96 ep.rcvMu.Lock() 97 ep.rcvReady = true 98 ep.rcvMu.Unlock() 99 100 ep.HandlePacket(r.id, r.pkt) 101 102 return ep, nil 103 }