github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/vif/device.go (about)

     1  package vif
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net"
     7  	"sync"
     8  
     9  	"go.opentelemetry.io/otel"
    10  	"go.opentelemetry.io/otel/attribute"
    11  	"go.opentelemetry.io/otel/trace"
    12  	"gvisor.dev/gvisor/pkg/buffer"
    13  	"gvisor.dev/gvisor/pkg/tcpip"
    14  	"gvisor.dev/gvisor/pkg/tcpip/header"
    15  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    16  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    17  
    18  	"github.com/datawire/dlib/dlog"
    19  	"github.com/telepresenceio/telepresence/v2/pkg/tracing"
    20  	vifBuffer "github.com/telepresenceio/telepresence/v2/pkg/vif/buffer"
    21  )
    22  
// device is the concrete Device implementation returned by OpenTun. It pairs the
// native TUN device opened by openTun with a gVisor channel endpoint, and the
// tunToDispatch/dispatchToTun loops copy packets between the two.
type device struct {
	// Endpoint is the gVisor channel endpoint that the stack attaches to. Packets
	// read from the TUN device are injected into it, and packets queued on it are
	// written back to the TUN device.
	*channel.Endpoint
	// ctx is the context passed to OpenTun. It is used for logging and as the
	// parent of the context that Attach creates for the packet loops.
	ctx context.Context
	// wg tracks the two packet-copying goroutines started by Attach
	// (tunToDispatch and dispatchToTun); WaitForDevice waits on it.
	wg  sync.WaitGroup
	// dev is the underlying native TUN device returned by openTun.
	dev *nativeDevice
}
    29  
    30  type Device interface {
    31  	stack.LinkEndpoint
    32  	io.Closer
    33  	Index() int32
    34  	Name() string
    35  	AddSubnet(context.Context, *net.IPNet) error
    36  	RemoveSubnet(context.Context, *net.IPNet) error
    37  	SetDNS(context.Context, string, net.IP, []string) (err error)
    38  	WaitForDevice()
    39  }
    40  
    41  const defaultDevMtu = 1500
    42  
    43  // Queue length for outbound packet, arriving at fd side for read. Overflow
    44  // causes packet drops. gVisor implementation-specific.
    45  const defaultDevOutQueueLen = 1024
    46  
// Compile-time assertion that *device implements the Device interface.
var _ Device = (*device)(nil)
    48  
    49  // OpenTun creates a new TUN device and ensures that it is up and running.
    50  func OpenTun(ctx context.Context) (Device, error) {
    51  	dev, err := openTun(ctx)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	return &device{
    57  		Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
    58  		ctx:      ctx,
    59  		dev:      dev,
    60  	}, nil
    61  }
    62  
    63  func (d *device) Attach(dp stack.NetworkDispatcher) {
    64  	go func() {
    65  		d.Endpoint.Attach(dp)
    66  		if dp == nil {
    67  			// Stack is closing
    68  			return
    69  		}
    70  		dlog.Info(d.ctx, "Starting Endpoint")
    71  		ctx, cancel := context.WithCancel(d.ctx)
    72  		d.wg.Add(2)
    73  		go d.tunToDispatch(cancel)
    74  		d.dispatchToTun(ctx)
    75  	}()
    76  }
    77  
    78  // AddSubnet adds a subnet to this TUN device and creates a route for that subnet which
    79  // is associated with the device (removing the device will automatically remove the route).
    80  func (d *device) AddSubnet(ctx context.Context, subnet *net.IPNet) (err error) {
    81  	ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "AddSubnet", trace.WithAttributes(attribute.Stringer("tel2.subnet", subnet)))
    82  	defer tracing.EndAndRecord(span, err)
    83  	return d.dev.addSubnet(ctx, subnet)
    84  }
    85  
    86  func (d *device) Close() error {
    87  	return d.dev.Close()
    88  }
    89  
    90  // Index returns the index of this device.
    91  func (d *device) Index() int32 {
    92  	return d.dev.index()
    93  }
    94  
    95  // Name returns the name of this device, e.g. "tun0".
    96  func (d *device) Name() string {
    97  	return d.dev.name
    98  }
    99  
   100  // SetDNS sets the DNS configuration for the device on the windows platform.
   101  func (d *device) SetDNS(ctx context.Context, clusterDomain string, server net.IP, domains []string) (err error) {
   102  	return d.dev.setDNS(ctx, clusterDomain, server, domains)
   103  }
   104  
   105  func (d *device) SetMTU(mtu int) error {
   106  	return d.dev.setMTU(mtu)
   107  }
   108  
   109  // RemoveSubnet removes a subnet from this TUN device and also removes the route for that subnet which
   110  // is associated with the device.
   111  func (d *device) RemoveSubnet(ctx context.Context, subnet *net.IPNet) (err error) {
   112  	// Staticcheck screams if this is ctx, span := because it thinks the context argument is being overwritten before being used.
   113  	sCtx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "RemoveSubnet", trace.WithAttributes(attribute.Stringer("tel2.subnet", subnet)))
   114  	defer tracing.EndAndRecord(span, err)
   115  	return d.dev.removeSubnet(sCtx, subnet)
   116  }
   117  
   118  func (d *device) WaitForDevice() {
   119  	d.wg.Wait()
   120  	dlog.Info(d.ctx, "Endpoint done")
   121  }
   122  
   123  func (d *device) tunToDispatch(cancel context.CancelFunc) {
   124  	defer func() {
   125  		cancel()
   126  		d.wg.Done()
   127  	}()
   128  	buf := vifBuffer.NewData(0x10000)
   129  	data := buf.Buf()
   130  	for ok := true; ok; {
   131  		n, err := d.dev.readPacket(buf)
   132  		if err != nil {
   133  			ok = d.IsAttached()
   134  			if ok && d.ctx.Err() == nil {
   135  				dlog.Errorf(d.ctx, "read packet error: %v", err)
   136  			}
   137  			return
   138  		}
   139  		if n == 0 {
   140  			continue
   141  		}
   142  
   143  		var ipv tcpip.NetworkProtocolNumber
   144  		switch header.IPVersion(data) {
   145  		case header.IPv4Version:
   146  			ipv = header.IPv4ProtocolNumber
   147  		case header.IPv6Version:
   148  			ipv = header.IPv6ProtocolNumber
   149  		default:
   150  			continue
   151  		}
   152  
   153  		pb := stack.NewPacketBuffer(stack.PacketBufferOptions{
   154  			Payload: buffer.MakeWithData(data[:n]),
   155  		})
   156  
   157  		d.InjectInbound(ipv, pb)
   158  		pb.DecRef()
   159  	}
   160  }
   161  
   162  func (d *device) dispatchToTun(ctx context.Context) {
   163  	defer d.wg.Done()
   164  	buf := vifBuffer.NewData(0x10000)
   165  	for {
   166  		pb := d.ReadContext(ctx)
   167  		if pb == nil {
   168  			break
   169  		}
   170  		buf.Resize(pb.Size())
   171  		b := buf.Buf()
   172  		for _, s := range pb.AsSlices() {
   173  			copy(b, s)
   174  			b = b[len(s):]
   175  		}
   176  		pb.DecRef()
   177  		if _, err := d.dev.writePacket(buf, 0); err != nil {
   178  			dlog.Errorf(ctx, "WritePacket failed: %v", err)
   179  		}
   180  	}
   181  }