github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/tcpip/link/tun/device.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tun
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/sagernet/gvisor/pkg/buffer"
    21  	"github.com/sagernet/gvisor/pkg/context"
    22  	"github.com/sagernet/gvisor/pkg/errors/linuxerr"
    23  	"github.com/sagernet/gvisor/pkg/sync"
    24  	"github.com/sagernet/gvisor/pkg/tcpip"
    25  	"github.com/sagernet/gvisor/pkg/tcpip/header"
    26  	"github.com/sagernet/gvisor/pkg/tcpip/link/channel"
    27  	"github.com/sagernet/gvisor/pkg/tcpip/link/packetsocket"
    28  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    29  	"github.com/sagernet/gvisor/pkg/waiter"
    30  )
    31  
    32  const (
    33  	// drivers/net/tun.c:tun_net_init()
    34  	defaultDevMtu = 1500
    35  
    36  	// Queue length for outbound packet, arriving at fd side for read. Overflow
    37  	// causes packet drops. gVisor implementation-specific.
    38  	defaultDevOutQueueLen = 1024
    39  )
    40  
    41  var zeroMAC [6]byte
    42  
// Device is an opened /dev/net/tun device.
//
// A Device starts out unattached. SetIff associates it with a NIC, and
// Release detaches it again.
//
// +stateify savable
type Device struct {
	// Queue holds waiters interested in d's readiness; WriteNotify signals
	// ReadableEvents on it.
	waiter.Queue

	// mu protects endpoint, notifyHandle and flags (SetIff and Release
	// modify them under the write lock).
	mu           sync.RWMutex `state:"nosave"`
	// endpoint is the link endpoint of the attached NIC, or nil while d is
	// unattached.
	endpoint     *tunEndpoint
	// notifyHandle is the registration returned by endpoint.AddNotify(d) in
	// SetIff; Release passes it to RemoveNotify.
	notifyHandle *channel.NotificationHandle
	// flags are the flags accepted by the successful SetIff call.
	flags        Flags
}
    54  
// Flags set properties of a Device.
//
// +stateify savable
type Flags struct {
	// TUN and TAP select the device kind; SetIff requires exactly one of
	// them. For TAP, Write expects an Ethernet header in front of each
	// packet; for TUN it does not.
	TUN          bool
	TAP          bool
	// NoPacketInfo, when set, drops the PacketInfoHeader that otherwise
	// prefixes every packet passed to Write and returned by Read.
	NoPacketInfo bool
}
    63  
    64  // beforeSave is invoked by stateify.
    65  func (d *Device) beforeSave() {
    66  	d.mu.Lock()
    67  	defer d.mu.Unlock()
    68  	// TODO(b/110961832): Restore the device to stack. At this moment, the stack
    69  	// is not savable.
    70  	if d.endpoint != nil {
    71  		panic("/dev/net/tun does not support save/restore when a device is associated with it.")
    72  	}
    73  }
    74  
    75  // Release implements fs.FileOperations.Release.
    76  func (d *Device) Release(ctx context.Context) {
    77  	d.mu.Lock()
    78  	defer d.mu.Unlock()
    79  
    80  	// Decrease refcount if there is an endpoint associated with this file.
    81  	if d.endpoint != nil {
    82  		d.endpoint.Drain()
    83  		d.endpoint.RemoveNotify(d.notifyHandle)
    84  		d.endpoint.DecRef(ctx)
    85  		d.endpoint = nil
    86  	}
    87  }
    88  
    89  // SetIff services TUNSETIFF ioctl(2) request.
    90  func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error {
    91  	d.mu.Lock()
    92  	defer d.mu.Unlock()
    93  
    94  	if d.endpoint != nil {
    95  		return linuxerr.EINVAL
    96  	}
    97  
    98  	// Input validation.
    99  	if flags.TAP && flags.TUN || !flags.TAP && !flags.TUN {
   100  		return linuxerr.EINVAL
   101  	}
   102  
   103  	prefix := "tun"
   104  	if flags.TAP {
   105  		prefix = "tap"
   106  	}
   107  
   108  	linkCaps := stack.CapabilityNone
   109  	if flags.TAP {
   110  		linkCaps |= stack.CapabilityResolutionRequired
   111  	}
   112  
   113  	endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
   114  	if err != nil {
   115  		return linuxerr.EINVAL
   116  	}
   117  
   118  	d.endpoint = endpoint
   119  	d.notifyHandle = d.endpoint.AddNotify(d)
   120  	d.flags = flags
   121  	return nil
   122  }
   123  
   124  func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
   125  	for {
   126  		// 1. Try to attach to an existing NIC.
   127  		if name != "" {
   128  			if linkEP := s.GetLinkEndpointByName(name); linkEP != nil {
   129  				endpoint, ok := linkEP.(*tunEndpoint)
   130  				if !ok {
   131  					// Not a NIC created by tun device.
   132  					return nil, linuxerr.EOPNOTSUPP
   133  				}
   134  				if !endpoint.TryIncRef() {
   135  					// Race detected: NIC got deleted in between.
   136  					continue
   137  				}
   138  				return endpoint, nil
   139  			}
   140  		}
   141  
   142  		// 2. Creating a new NIC.
   143  		id := tcpip.NICID(s.UniqueID())
   144  		endpoint := &tunEndpoint{
   145  			Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
   146  			stack:    s,
   147  			nicID:    id,
   148  			name:     name,
   149  			isTap:    prefix == "tap",
   150  		}
   151  		endpoint.InitRefs()
   152  		endpoint.Endpoint.LinkEPCapabilities = linkCaps
   153  		if endpoint.name == "" {
   154  			endpoint.name = fmt.Sprintf("%s%d", prefix, id)
   155  		}
   156  		err := s.CreateNICWithOptions(endpoint.nicID, packetsocket.New(endpoint), stack.NICOptions{
   157  			Name: endpoint.name,
   158  		})
   159  		switch err.(type) {
   160  		case nil:
   161  			return endpoint, nil
   162  		case *tcpip.ErrDuplicateNICID:
   163  			// Race detected: A NIC has been created in between.
   164  			continue
   165  		default:
   166  			return nil, linuxerr.EINVAL
   167  		}
   168  	}
   169  }
   170  
   171  // MTU returns the tun endpoint MTU (maximum transmission unit).
   172  func (d *Device) MTU() (uint32, error) {
   173  	d.mu.RLock()
   174  	endpoint := d.endpoint
   175  	d.mu.RUnlock()
   176  	if endpoint == nil {
   177  		return 0, linuxerr.EBADFD
   178  	}
   179  	if !endpoint.IsAttached() {
   180  		return 0, linuxerr.EIO
   181  	}
   182  	return endpoint.MTU(), nil
   183  }
   184  
   185  // Write inject one inbound packet to the network interface.
   186  func (d *Device) Write(data *buffer.View) (int64, error) {
   187  	d.mu.RLock()
   188  	endpoint := d.endpoint
   189  	d.mu.RUnlock()
   190  	if endpoint == nil {
   191  		return 0, linuxerr.EBADFD
   192  	}
   193  	if !endpoint.IsAttached() {
   194  		return 0, linuxerr.EIO
   195  	}
   196  
   197  	dataLen := int64(data.Size())
   198  
   199  	// Packet information.
   200  	var pktInfoHdr PacketInfoHeader
   201  	if !d.flags.NoPacketInfo {
   202  		if dataLen < PacketInfoHeaderSize {
   203  			// Ignore bad packet.
   204  			return dataLen, nil
   205  		}
   206  		pktInfoHdrView := data.Clone()
   207  		defer pktInfoHdrView.Release()
   208  		pktInfoHdrView.CapLength(PacketInfoHeaderSize)
   209  		pktInfoHdr = PacketInfoHeader(pktInfoHdrView.AsSlice())
   210  		data.TrimFront(PacketInfoHeaderSize)
   211  	}
   212  
   213  	// Ethernet header (TAP only).
   214  	var ethHdr header.Ethernet
   215  	if d.flags.TAP {
   216  		if data.Size() < header.EthernetMinimumSize {
   217  			// Ignore bad packet.
   218  			return dataLen, nil
   219  		}
   220  		ethHdrView := data.Clone()
   221  		defer ethHdrView.Release()
   222  		ethHdrView.CapLength(header.EthernetMinimumSize)
   223  		ethHdr = header.Ethernet(ethHdrView.AsSlice())
   224  		data.TrimFront(header.EthernetMinimumSize)
   225  	}
   226  
   227  	// Try to determine network protocol number, default zero.
   228  	var protocol tcpip.NetworkProtocolNumber
   229  	switch {
   230  	case pktInfoHdr != nil:
   231  		protocol = pktInfoHdr.Protocol()
   232  	case ethHdr != nil:
   233  		protocol = ethHdr.Type()
   234  	case d.flags.TUN:
   235  		// TUN interface with IFF_NO_PI enabled, thus
   236  		// we need to determine protocol from version field
   237  		version := data.AsSlice()[0] >> 4
   238  		if version == 4 {
   239  			protocol = header.IPv4ProtocolNumber
   240  		} else if version == 6 {
   241  			protocol = header.IPv6ProtocolNumber
   242  		}
   243  	}
   244  
   245  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   246  		ReserveHeaderBytes: len(ethHdr),
   247  		Payload:            buffer.MakeWithView(data.Clone()),
   248  	})
   249  	defer pkt.DecRef()
   250  	copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr)
   251  	endpoint.InjectInbound(protocol, pkt)
   252  	return dataLen, nil
   253  }
   254  
   255  // Read reads one outgoing packet from the network interface.
   256  func (d *Device) Read() (*buffer.View, error) {
   257  	d.mu.RLock()
   258  	endpoint := d.endpoint
   259  	d.mu.RUnlock()
   260  	if endpoint == nil {
   261  		return nil, linuxerr.EBADFD
   262  	}
   263  
   264  	pkt := endpoint.Read()
   265  	if pkt == nil {
   266  		return nil, linuxerr.ErrWouldBlock
   267  	}
   268  	v := d.encodePkt(pkt)
   269  	pkt.DecRef()
   270  	return v, nil
   271  }
   272  
   273  // encodePkt encodes packet for fd side.
   274  func (d *Device) encodePkt(pkt *stack.PacketBuffer) *buffer.View {
   275  	var view *buffer.View
   276  
   277  	// Packet information.
   278  	if !d.flags.NoPacketInfo {
   279  		view = buffer.NewView(PacketInfoHeaderSize + pkt.Size())
   280  		view.Grow(PacketInfoHeaderSize)
   281  		hdr := PacketInfoHeader(view.AsSlice())
   282  		hdr.Encode(&PacketInfoFields{
   283  			Protocol: pkt.NetworkProtocolNumber,
   284  		})
   285  		pktView := pkt.ToView()
   286  		view.Write(pktView.AsSlice())
   287  		pktView.Release()
   288  	} else {
   289  		view = pkt.ToView()
   290  	}
   291  
   292  	return view
   293  }
   294  
   295  // Name returns the name of the attached network interface. Empty string if
   296  // unattached.
   297  func (d *Device) Name() string {
   298  	d.mu.RLock()
   299  	defer d.mu.RUnlock()
   300  	if d.endpoint != nil {
   301  		return d.endpoint.name
   302  	}
   303  	return ""
   304  }
   305  
   306  // Flags returns the flags set for d. Zero value if unset.
   307  func (d *Device) Flags() Flags {
   308  	d.mu.RLock()
   309  	defer d.mu.RUnlock()
   310  	return d.flags
   311  }
   312  
   313  // Readiness implements watier.Waitable.Readiness.
   314  func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask {
   315  	if mask&waiter.ReadableEvents != 0 {
   316  		d.mu.RLock()
   317  		endpoint := d.endpoint
   318  		d.mu.RUnlock()
   319  		if endpoint != nil && endpoint.NumQueued() == 0 {
   320  			mask &= ^waiter.ReadableEvents
   321  		}
   322  	}
   323  	return mask & (waiter.ReadableEvents | waiter.WritableEvents)
   324  }
   325  
   326  // WriteNotify implements channel.Notification.WriteNotify.
   327  func (d *Device) WriteNotify() {
   328  	d.Notify(waiter.ReadableEvents)
   329  }
   330  
// tunEndpoint is the link endpoint for the NIC created by the tun device.
//
// It is ref-counted because multiple open files can attach to the same NIC.
// The last owner is responsible for deleting the NIC.
type tunEndpoint struct {
	// tunEndpointRefs holds the reference count; DecRef removes the NIC once
	// it drops to zero.
	tunEndpointRefs
	// Endpoint is the channel endpoint that Device.Read reads outgoing
	// packets from and Device.Write injects inbound packets into.
	*channel.Endpoint

	// stack is the stack the NIC was created in.
	stack *stack.Stack
	// nicID identifies the NIC within stack.
	nicID tcpip.NICID
	// name is the NIC name: the one requested via SetIff, or generated from
	// the "tun"/"tap" prefix and nicID.
	name  string
	// isTap reports whether this is a TAP NIC. A TAP NIC adds and expects
	// Ethernet headers; a TUN NIC does not.
	isTap bool
}
   344  
   345  // DecRef decrements refcount of e, removing NIC if it reaches 0.
   346  func (e *tunEndpoint) DecRef(ctx context.Context) {
   347  	e.tunEndpointRefs.DecRef(func() {
   348  		e.Close()
   349  		e.stack.RemoveNIC(e.nicID)
   350  	})
   351  }
   352  
   353  // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
   354  func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType {
   355  	if e.isTap {
   356  		return header.ARPHardwareEther
   357  	}
   358  	return header.ARPHardwareNone
   359  }
   360  
   361  // AddHeader implements stack.LinkEndpoint.AddHeader.
   362  func (e *tunEndpoint) AddHeader(pkt *stack.PacketBuffer) {
   363  	if !e.isTap {
   364  		return
   365  	}
   366  	eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
   367  	eth.Encode(&header.EthernetFields{
   368  		SrcAddr: pkt.EgressRoute.LocalLinkAddress,
   369  		DstAddr: pkt.EgressRoute.RemoteLinkAddress,
   370  		Type:    pkt.NetworkProtocolNumber,
   371  	})
   372  }
   373  
   374  // MaxHeaderLength returns the maximum size of the link layer header.
   375  func (e *tunEndpoint) MaxHeaderLength() uint16 {
   376  	if e.isTap {
   377  		return header.EthernetMinimumSize
   378  	}
   379  	return 0
   380  }