github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/vif/stack.go (about)

     1  package vif
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"time"
     8  
     9  	"go.opentelemetry.io/otel"
    10  	"go.opentelemetry.io/otel/attribute"
    11  	"go.opentelemetry.io/otel/codes"
    12  	"go.opentelemetry.io/otel/trace"
    13  	"gvisor.dev/gvisor/pkg/tcpip"
    14  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    15  	"gvisor.dev/gvisor/pkg/tcpip/header"
    16  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    17  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    18  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    19  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    20  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    21  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    22  	"gvisor.dev/gvisor/pkg/waiter"
    23  
    24  	"github.com/datawire/dlib/dlog"
    25  	"github.com/telepresenceio/telepresence/v2/pkg/iputil"
    26  	"github.com/telepresenceio/telepresence/v2/pkg/tunnel"
    27  )
    28  
    29  func NewStack(ctx context.Context, dev stack.LinkEndpoint, streamCreator tunnel.StreamCreator) (*stack.Stack, error) {
    30  	s := stack.New(stack.Options{
    31  		NetworkProtocols: []stack.NetworkProtocolFactory{
    32  			ipv4.NewProtocol,
    33  			ipv6.NewProtocol,
    34  		},
    35  		TransportProtocols: []stack.TransportProtocolFactory{
    36  			icmp.NewProtocol4,
    37  			icmp.NewProtocol6,
    38  			tcp.NewProtocol,
    39  			udp.NewProtocol,
    40  		},
    41  		HandleLocal: false,
    42  	})
    43  	if err := setDefaultOptions(s); err != nil {
    44  		return nil, err
    45  	}
    46  	if err := setNIC(ctx, s, dev); err != nil {
    47  		return nil, err
    48  	}
    49  	setTCPHandler(ctx, s, streamCreator)
    50  	setUDPHandler(ctx, s, streamCreator)
    51  	return s, nil
    52  }
    53  
    54  const (
    55  	myWindowScale    = 6
    56  	maxReceiveWindow = 1 << (myWindowScale + 14) // 1MiB
    57  )
    58  
    59  // maxInFlight specifies the max number of in-flight connection attempts.
    60  const maxInFlight = 512
    61  
    62  // keepAliveIdle is used as the very first alive interval. Subsequent intervals
    63  // use keepAliveInterval.
    64  const keepAliveIdle = 60 * time.Second
    65  
    66  // keepAliveInterval is the interval between sending keep-alive packets.
    67  const keepAliveInterval = 30 * time.Second
    68  
    69  // keepAliveCount is the max number of keep-alive probes that can be sent
    70  // before the connection is killed due to lack of response.
    71  const keepAliveCount = 9
    72  
    73  type idStringer stack.TransportEndpointID
    74  
    75  func (i idStringer) String() string {
    76  	return fmt.Sprintf("%s -> %s",
    77  		iputil.JoinIpPort(i.RemoteAddress.AsSlice(), i.RemotePort),
    78  		iputil.JoinIpPort(i.LocalAddress.AsSlice(), i.LocalPort))
    79  }
    80  
    81  func setDefaultOptions(s *stack.Stack) error {
    82  	// Forwarding
    83  	if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
    84  		return fmt.Errorf("SetForwardingDefaultAndAllNICs(ipv4, %t): %s", true, err)
    85  	}
    86  	if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
    87  		return fmt.Errorf("SetForwardingDefaultAndAllNICs(ipv6, %t): %s", true, err)
    88  	}
    89  	ttl := tcpip.DefaultTTLOption(64)
    90  	if err := s.SetNetworkProtocolOption(ipv4.ProtocolNumber, &ttl); err != nil {
    91  		return fmt.Errorf("SetDefaultTTL(ipv4, %d): %s", ttl, err)
    92  	}
    93  	if err := s.SetNetworkProtocolOption(ipv6.ProtocolNumber, &ttl); err != nil {
    94  		return fmt.Errorf("SetDefaultTTL(ipv6, %d): %s", ttl, err)
    95  	}
    96  	return nil
    97  }
    98  
    99  func setNIC(ctx context.Context, s *stack.Stack, ep stack.LinkEndpoint) error {
   100  	nicID := tcpip.NICID(s.UniqueID())
   101  	if err := s.CreateNICWithOptions(nicID, ep, stack.NICOptions{Name: "tel", Context: ctx}); err != nil {
   102  		return fmt.Errorf("create NIC failed: %s", err)
   103  	}
   104  	if err := s.SetPromiscuousMode(nicID, true); err != nil {
   105  		return fmt.Errorf("SetPromiscuousMode(%d, %t): %s", nicID, true, err)
   106  	}
   107  	if err := s.SetSpoofing(nicID, true); err != nil {
   108  		return fmt.Errorf("SetSpoofing(%d, %t): %s", nicID, true, err)
   109  	}
   110  	s.SetRouteTable([]tcpip.Route{
   111  		{
   112  			Destination: header.IPv4EmptySubnet,
   113  			NIC:         nicID,
   114  		},
   115  		{
   116  			Destination: header.IPv6EmptySubnet,
   117  			NIC:         nicID,
   118  		},
   119  	})
   120  	return nil
   121  }
   122  
   123  func forwardTCP(ctx context.Context, streamCreator tunnel.StreamCreator, fr *tcp.ForwarderRequest) {
   124  	var ep tcpip.Endpoint
   125  	var err tcpip.Error
   126  	id := fr.ID()
   127  
   128  	ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "TCPHandler",
   129  		trace.WithNewRoot(),
   130  		trace.WithAttributes(
   131  			attribute.String("tel2.remote-ip", id.RemoteAddress.String()),
   132  			attribute.String("tel2.local-ip", id.LocalAddress.String()),
   133  			attribute.Int("tel2.local-port", int(id.LocalPort)),
   134  			attribute.Int("tel2.remote-port", int(id.RemotePort)),
   135  		))
   136  	defer func() {
   137  		if err != nil {
   138  			msg := fmt.Sprintf("forward TCP %s: %s", idStringer(id), err)
   139  			span.SetStatus(codes.Error, msg)
   140  			dlog.Errorf(ctx, msg)
   141  		}
   142  		span.End()
   143  	}()
   144  
   145  	wq := waiter.Queue{}
   146  	if ep, err = fr.CreateEndpoint(&wq); err != nil {
   147  		fr.Complete(true)
   148  		return
   149  	}
   150  	defer fr.Complete(false)
   151  
   152  	so := ep.SocketOptions()
   153  	so.SetKeepAlive(true)
   154  
   155  	idle := tcpip.KeepaliveIdleOption(keepAliveIdle)
   156  	if err = ep.SetSockOpt(&idle); err != nil {
   157  		return
   158  	}
   159  
   160  	ivl := tcpip.KeepaliveIntervalOption(keepAliveInterval)
   161  	if err = ep.SetSockOpt(&ivl); err != nil {
   162  		return
   163  	}
   164  
   165  	if err = ep.SetSockOptInt(tcpip.KeepaliveCountOption, keepAliveCount); err != nil {
   166  		return
   167  	}
   168  	dispatchToStream(ctx, newConnID(header.TCPProtocolNumber, id), gonet.NewTCPConn(&wq, ep), streamCreator)
   169  }
   170  
   171  func setTCPHandler(ctx context.Context, s *stack.Stack, streamCreator tunnel.StreamCreator) {
   172  	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber,
   173  		&tcpip.TCPSendBufferSizeRangeOption{
   174  			Min:     tcp.MinBufferSize,
   175  			Default: tcp.DefaultSendBufferSize,
   176  			Max:     tcp.MaxBufferSize,
   177  		}); err != nil {
   178  		return
   179  	}
   180  
   181  	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber,
   182  		&tcpip.TCPReceiveBufferSizeRangeOption{
   183  			Min:     tcp.MinBufferSize,
   184  			Default: tcp.DefaultSendBufferSize,
   185  			Max:     tcp.MaxBufferSize,
   186  		}); err != nil {
   187  		return
   188  	}
   189  
   190  	sa := tcpip.TCPSACKEnabled(true)
   191  	s.SetTransportProtocolOption(tcp.ProtocolNumber, &sa)
   192  
   193  	// Enable Receive Buffer Auto-Tuning, see:
   194  	// https://github.com/google/gvisor/issues/1666
   195  	mo := tcpip.TCPModerateReceiveBufferOption(true)
   196  	s.SetTransportProtocolOption(tcp.ProtocolNumber, &mo)
   197  
   198  	f := tcp.NewForwarder(s, maxReceiveWindow, maxInFlight, func(fr *tcp.ForwarderRequest) {
   199  		forwardTCP(ctx, streamCreator, fr)
   200  	})
   201  	s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
   202  }
   203  
   204  var blockedUDPPorts = map[uint16]bool{ //nolint:gochecknoglobals // constant
   205  	137: true, // NETBIOS Name Service
   206  	138: true, // NETBIOS Datagram Service
   207  	139: true, // NETBIOS
   208  }
   209  
   210  func forwardUDP(ctx context.Context, streamCreator tunnel.StreamCreator, fr *udp.ForwarderRequest) {
   211  	id := fr.ID()
   212  	ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "UDPHandler",
   213  		trace.WithNewRoot(),
   214  		trace.WithAttributes(
   215  			attribute.String("tel2.remote-ip", id.RemoteAddress.To4().String()),
   216  			attribute.String("tel2.local-ip", id.LocalAddress.To4().String()),
   217  			attribute.Int("tel2.local-port", int(id.LocalPort)),
   218  			attribute.Int("tel2.remote-port", int(id.RemotePort)),
   219  			attribute.Bool("tel2.port-blocked", false),
   220  		))
   221  	defer span.End()
   222  
   223  	if _, ok := blockedUDPPorts[id.LocalPort]; ok {
   224  		span.SetAttributes(attribute.Bool("tel2.port-blocked", true))
   225  		return
   226  	}
   227  
   228  	wq := waiter.Queue{}
   229  	ep, err := fr.CreateEndpoint(&wq)
   230  	if err != nil {
   231  		msg := fmt.Sprintf("forward UDP %s: %s", idStringer(id), err)
   232  		span.SetStatus(codes.Error, msg)
   233  		dlog.Errorf(ctx, msg)
   234  		return
   235  	}
   236  	dispatchToStream(ctx, newConnID(udp.ProtocolNumber, id), gonet.NewUDPConn(&wq, ep), streamCreator)
   237  }
   238  
   239  func setUDPHandler(ctx context.Context, s *stack.Stack, streamCreator tunnel.StreamCreator) {
   240  	f := udp.NewForwarder(s, func(fr *udp.ForwarderRequest) {
   241  		forwardUDP(ctx, streamCreator, fr)
   242  	})
   243  	s.SetTransportProtocolHandler(udp.ProtocolNumber, f.HandlePacket)
   244  }
   245  
   246  func newConnID(proto tcpip.TransportProtocolNumber, id stack.TransportEndpointID) tunnel.ConnID {
   247  	return tunnel.NewConnID(int(proto), id.RemoteAddress.AsSlice(), id.LocalAddress.AsSlice(), id.RemotePort, id.LocalPort)
   248  }
   249  
   250  func dispatchToStream(ctx context.Context, id tunnel.ConnID, conn net.Conn, streamCreator tunnel.StreamCreator) {
   251  	ctx, cancel := context.WithCancel(ctx)
   252  	stream, err := streamCreator(ctx, id)
   253  	if err != nil {
   254  		dlog.Errorf(ctx, "forward %s: %s", id, err)
   255  		cancel()
   256  		return
   257  	}
   258  	ep := tunnel.NewConnEndpoint(stream, conn, cancel, nil, nil)
   259  	ep.Start(ctx)
   260  }