github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/vif/stack.go (about) 1 package vif 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "time" 8 9 "go.opentelemetry.io/otel" 10 "go.opentelemetry.io/otel/attribute" 11 "go.opentelemetry.io/otel/codes" 12 "go.opentelemetry.io/otel/trace" 13 "gvisor.dev/gvisor/pkg/tcpip" 14 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 15 "gvisor.dev/gvisor/pkg/tcpip/header" 16 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 17 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 18 "gvisor.dev/gvisor/pkg/tcpip/stack" 19 "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" 20 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 21 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 22 "gvisor.dev/gvisor/pkg/waiter" 23 24 "github.com/datawire/dlib/dlog" 25 "github.com/telepresenceio/telepresence/v2/pkg/iputil" 26 "github.com/telepresenceio/telepresence/v2/pkg/tunnel" 27 ) 28 29 func NewStack(ctx context.Context, dev stack.LinkEndpoint, streamCreator tunnel.StreamCreator) (*stack.Stack, error) { 30 s := stack.New(stack.Options{ 31 NetworkProtocols: []stack.NetworkProtocolFactory{ 32 ipv4.NewProtocol, 33 ipv6.NewProtocol, 34 }, 35 TransportProtocols: []stack.TransportProtocolFactory{ 36 icmp.NewProtocol4, 37 icmp.NewProtocol6, 38 tcp.NewProtocol, 39 udp.NewProtocol, 40 }, 41 HandleLocal: false, 42 }) 43 if err := setDefaultOptions(s); err != nil { 44 return nil, err 45 } 46 if err := setNIC(ctx, s, dev); err != nil { 47 return nil, err 48 } 49 setTCPHandler(ctx, s, streamCreator) 50 setUDPHandler(ctx, s, streamCreator) 51 return s, nil 52 } 53 54 const ( 55 myWindowScale = 6 56 maxReceiveWindow = 1 << (myWindowScale + 14) // 1MiB 57 ) 58 59 // maxInFlight specifies the max number of in-flight connection attempts. 60 const maxInFlight = 512 61 62 // keepAliveIdle is used as the very first alive interval. Subsequent intervals 63 // use keepAliveInterval. 64 const keepAliveIdle = 60 * time.Second 65 66 // keepAliveInterval is the interval between sending keep-alive packets. 67 const keepAliveInterval = 30 * time.Second 68 69 // keepAliveCount is the max number of keep-alive probes that can be sent 70 // before the connection is killed due to lack of response. 71 const keepAliveCount = 9 72 73 type idStringer stack.TransportEndpointID 74 75 func (i idStringer) String() string { 76 return fmt.Sprintf("%s -> %s", 77 iputil.JoinIpPort(i.RemoteAddress.AsSlice(), i.RemotePort), 78 iputil.JoinIpPort(i.LocalAddress.AsSlice(), i.LocalPort)) 79 } 80 81 func setDefaultOptions(s *stack.Stack) error { 82 // Forwarding 83 if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { 84 return fmt.Errorf("SetForwardingDefaultAndAllNICs(ipv4, %t): %s", true, err) 85 } 86 if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { 87 return fmt.Errorf("SetForwardingDefaultAndAllNICs(ipv6, %t): %s", true, err) 88 } 89 ttl := tcpip.DefaultTTLOption(64) 90 if err := s.SetNetworkProtocolOption(ipv4.ProtocolNumber, &ttl); err != nil { 91 return fmt.Errorf("SetDefaultTTL(ipv4, %d): %s", ttl, err) 92 } 93 if err := s.SetNetworkProtocolOption(ipv6.ProtocolNumber, &ttl); err != nil { 94 return fmt.Errorf("SetDefaultTTL(ipv6, %d): %s", ttl, err) 95 } 96 return nil 97 } 98 99 func setNIC(ctx context.Context, s *stack.Stack, ep stack.LinkEndpoint) error { 100 nicID := tcpip.NICID(s.UniqueID()) 101 if err := s.CreateNICWithOptions(nicID, ep, stack.NICOptions{Name: "tel", Context: ctx}); err != nil { 102 return fmt.Errorf("create NIC failed: %s", err) 103 } 104 if err := s.SetPromiscuousMode(nicID, true); err != nil { 105 return fmt.Errorf("SetPromiscuousMode(%d, %t): %s", nicID, true, err) 106 } 107 if err := s.SetSpoofing(nicID, true); err != nil { 108 return fmt.Errorf("SetSpoofing(%d, %t): %s", nicID, true, err) 109 } 110 s.SetRouteTable([]tcpip.Route{ 111 { 112 Destination: header.IPv4EmptySubnet, 113 NIC: nicID, 114 }, 115 { 116 Destination: header.IPv6EmptySubnet, 117 NIC: nicID, 118 }, 119 }) 120 return nil 121 } 122 123 func forwardTCP(ctx context.Context, streamCreator tunnel.StreamCreator, fr *tcp.ForwarderRequest) { 124 var ep tcpip.Endpoint 125 var err tcpip.Error 126 id := fr.ID() 127 128 ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "TCPHandler", 129 trace.WithNewRoot(), 130 trace.WithAttributes( 131 attribute.String("tel2.remote-ip", id.RemoteAddress.String()), 132 attribute.String("tel2.local-ip", id.LocalAddress.String()), 133 attribute.Int("tel2.local-port", int(id.LocalPort)), 134 attribute.Int("tel2.remote-port", int(id.RemotePort)), 135 )) 136 defer func() { 137 if err != nil { 138 msg := fmt.Sprintf("forward TCP %s: %s", idStringer(id), err) 139 span.SetStatus(codes.Error, msg) 140 dlog.Errorf(ctx, msg) 141 } 142 span.End() 143 }() 144 145 wq := waiter.Queue{} 146 if ep, err = fr.CreateEndpoint(&wq); err != nil { 147 fr.Complete(true) 148 return 149 } 150 defer fr.Complete(false) 151 152 so := ep.SocketOptions() 153 so.SetKeepAlive(true) 154 155 idle := tcpip.KeepaliveIdleOption(keepAliveIdle) 156 if err = ep.SetSockOpt(&idle); err != nil { 157 return 158 } 159 160 ivl := tcpip.KeepaliveIntervalOption(keepAliveInterval) 161 if err = ep.SetSockOpt(&ivl); err != nil { 162 return 163 } 164 165 if err = ep.SetSockOptInt(tcpip.KeepaliveCountOption, keepAliveCount); err != nil { 166 return 167 } 168 dispatchToStream(ctx, newConnID(header.TCPProtocolNumber, id), gonet.NewTCPConn(&wq, ep), streamCreator) 169 } 170 171 func setTCPHandler(ctx context.Context, s *stack.Stack, streamCreator tunnel.StreamCreator) { 172 if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, 173 &tcpip.TCPSendBufferSizeRangeOption{ 174 Min: tcp.MinBufferSize, 175 Default: tcp.DefaultSendBufferSize, 176 Max: tcp.MaxBufferSize, 177 }); err != nil { 178 return 179 } 180 181 if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, 182 &tcpip.TCPReceiveBufferSizeRangeOption{ 183 Min: tcp.MinBufferSize, 184 Default: tcp.DefaultSendBufferSize, 185 Max: tcp.MaxBufferSize, 186 }); err != nil { 187 return 188 } 189 190 sa := tcpip.TCPSACKEnabled(true) 191 s.SetTransportProtocolOption(tcp.ProtocolNumber, &sa) 192 193 // Enable Receive Buffer Auto-Tuning, see: 194 // https://github.com/google/gvisor/issues/1666 195 mo := tcpip.TCPModerateReceiveBufferOption(true) 196 s.SetTransportProtocolOption(tcp.ProtocolNumber, &mo) 197 198 f := tcp.NewForwarder(s, maxReceiveWindow, maxInFlight, func(fr *tcp.ForwarderRequest) { 199 forwardTCP(ctx, streamCreator, fr) 200 }) 201 s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) 202 } 203 204 var blockedUDPPorts = map[uint16]bool{ //nolint:gochecknoglobals // constant 205 137: true, // NETBIOS Name Service 206 138: true, // NETBIOS Datagram Service 207 139: true, // NETBIOS 208 } 209 210 func forwardUDP(ctx context.Context, streamCreator tunnel.StreamCreator, fr *udp.ForwarderRequest) { 211 id := fr.ID() 212 ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "UDPHandler", 213 trace.WithNewRoot(), 214 trace.WithAttributes( 215 attribute.String("tel2.remote-ip", id.RemoteAddress.To4().String()), 216 attribute.String("tel2.local-ip", id.LocalAddress.To4().String()), 217 attribute.Int("tel2.local-port", int(id.LocalPort)), 218 attribute.Int("tel2.remote-port", int(id.RemotePort)), 219 attribute.Bool("tel2.port-blocked", false), 220 )) 221 defer span.End() 222 223 if _, ok := blockedUDPPorts[id.LocalPort]; ok { 224 span.SetAttributes(attribute.Bool("tel2.port-blocked", true)) 225 return 226 } 227 228 wq := waiter.Queue{} 229 ep, err := fr.CreateEndpoint(&wq) 230 if err != nil { 231 msg := fmt.Sprintf("forward UDP %s: %s", idStringer(id), err) 232 span.SetStatus(codes.Error, msg) 233 dlog.Errorf(ctx, msg) 234 return 235 } 236 dispatchToStream(ctx, newConnID(udp.ProtocolNumber, id), gonet.NewUDPConn(&wq, ep), streamCreator) 237 } 238 239 func setUDPHandler(ctx context.Context, s *stack.Stack, streamCreator tunnel.StreamCreator) { 240 f := udp.NewForwarder(s, func(fr *udp.ForwarderRequest) { 241 forwardUDP(ctx, streamCreator, fr) 242 }) 243 s.SetTransportProtocolHandler(udp.ProtocolNumber, f.HandlePacket) 244 } 245 246 func newConnID(proto tcpip.TransportProtocolNumber, id stack.TransportEndpointID) tunnel.ConnID { 247 return tunnel.NewConnID(int(proto), id.RemoteAddress.AsSlice(), id.LocalAddress.AsSlice(), id.RemotePort, id.LocalPort) 248 } 249 250 func dispatchToStream(ctx context.Context, id tunnel.ConnID, conn net.Conn, streamCreator tunnel.StreamCreator) { 251 ctx, cancel := context.WithCancel(ctx) 252 stream, err := streamCreator(ctx, id) 253 if err != nil { 254 dlog.Errorf(ctx, "forward %s: %s", id, err) 255 cancel() 256 return 257 } 258 ep := tunnel.NewConnEndpoint(stream, conn, cancel, nil, nil) 259 ep.Start(ctx) 260 }