github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/link/tun/device.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tun
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/SagerNet/gvisor/pkg/context"
    21  	"github.com/SagerNet/gvisor/pkg/errors/linuxerr"
    22  	"github.com/SagerNet/gvisor/pkg/sync"
    23  	"github.com/SagerNet/gvisor/pkg/syserror"
    24  	"github.com/SagerNet/gvisor/pkg/tcpip"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    29  	"github.com/SagerNet/gvisor/pkg/waiter"
    30  )
    31  
    32  const (
    33  	// drivers/net/tun.c:tun_net_init()
    34  	defaultDevMtu = 1500
    35  
    36  	// Queue length for outbound packet, arriving at fd side for read. Overflow
    37  	// causes packet drops. gVisor implementation-specific.
    38  	defaultDevOutQueueLen = 1024
    39  )
    40  
    41  var zeroMAC [6]byte
    42  
// Device is an opened /dev/net/tun device.
//
// A Device starts unattached. SetIff binds it to a NIC endpoint, and Release
// drops that binding. Waiters interested in outbound packets are notified
// through the embedded waiter.Queue (see WriteNotify and Readiness).
//
// +stateify savable
type Device struct {
	waiter.Queue

	// mu serializes updates to the fields below (SetIff and Release take it
	// exclusively) and is excluded from save. endpoint is the attached NIC
	// endpoint, or nil when unattached. notifyHandle is the registration
	// returned by endpoint.AddNotify and handed back to RemoveNotify on
	// Release. flags holds the Flags accepted by the last successful SetIff.
	mu           sync.RWMutex `state:"nosave"`
	endpoint     *tunEndpoint
	notifyHandle *channel.NotificationHandle
	flags        Flags
}
    54  
    55  // Flags set properties of a Device
    56  type Flags struct {
    57  	TUN          bool
    58  	TAP          bool
    59  	NoPacketInfo bool
    60  }
    61  
    62  // beforeSave is invoked by stateify.
    63  func (d *Device) beforeSave() {
    64  	d.mu.Lock()
    65  	defer d.mu.Unlock()
    66  	// TODO(b/110961832): Restore the device to stack. At this moment, the stack
    67  	// is not savable.
    68  	if d.endpoint != nil {
    69  		panic("/dev/net/tun does not support save/restore when a device is associated with it.")
    70  	}
    71  }
    72  
    73  // Release implements fs.FileOperations.Release.
    74  func (d *Device) Release(ctx context.Context) {
    75  	d.mu.Lock()
    76  	defer d.mu.Unlock()
    77  
    78  	// Decrease refcount if there is an endpoint associated with this file.
    79  	if d.endpoint != nil {
    80  		d.endpoint.RemoveNotify(d.notifyHandle)
    81  		d.endpoint.DecRef(ctx)
    82  		d.endpoint = nil
    83  	}
    84  }
    85  
    86  // SetIff services TUNSETIFF ioctl(2) request.
    87  func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error {
    88  	d.mu.Lock()
    89  	defer d.mu.Unlock()
    90  
    91  	if d.endpoint != nil {
    92  		return linuxerr.EINVAL
    93  	}
    94  
    95  	// Input validation.
    96  	if flags.TAP && flags.TUN || !flags.TAP && !flags.TUN {
    97  		return linuxerr.EINVAL
    98  	}
    99  
   100  	prefix := "tun"
   101  	if flags.TAP {
   102  		prefix = "tap"
   103  	}
   104  
   105  	linkCaps := stack.CapabilityNone
   106  	if flags.TAP {
   107  		linkCaps |= stack.CapabilityResolutionRequired
   108  	}
   109  
   110  	endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
   111  	if err != nil {
   112  		return linuxerr.EINVAL
   113  	}
   114  
   115  	d.endpoint = endpoint
   116  	d.notifyHandle = d.endpoint.AddNotify(d)
   117  	d.flags = flags
   118  	return nil
   119  }
   120  
   121  func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
   122  	for {
   123  		// 1. Try to attach to an existing NIC.
   124  		if name != "" {
   125  			if linkEP := s.GetLinkEndpointByName(name); linkEP != nil {
   126  				endpoint, ok := linkEP.(*tunEndpoint)
   127  				if !ok {
   128  					// Not a NIC created by tun device.
   129  					return nil, syserror.EOPNOTSUPP
   130  				}
   131  				if !endpoint.TryIncRef() {
   132  					// Race detected: NIC got deleted in between.
   133  					continue
   134  				}
   135  				return endpoint, nil
   136  			}
   137  		}
   138  
   139  		// 2. Creating a new NIC.
   140  		id := tcpip.NICID(s.UniqueID())
   141  		endpoint := &tunEndpoint{
   142  			Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
   143  			stack:    s,
   144  			nicID:    id,
   145  			name:     name,
   146  			isTap:    prefix == "tap",
   147  		}
   148  		endpoint.InitRefs()
   149  		endpoint.Endpoint.LinkEPCapabilities = linkCaps
   150  		if endpoint.name == "" {
   151  			endpoint.name = fmt.Sprintf("%s%d", prefix, id)
   152  		}
   153  		err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{
   154  			Name: endpoint.name,
   155  		})
   156  		switch err.(type) {
   157  		case nil:
   158  			return endpoint, nil
   159  		case *tcpip.ErrDuplicateNICID:
   160  			// Race detected: A NIC has been created in between.
   161  			continue
   162  		default:
   163  			return nil, linuxerr.EINVAL
   164  		}
   165  	}
   166  }
   167  
   168  // Write inject one inbound packet to the network interface.
   169  func (d *Device) Write(data []byte) (int64, error) {
   170  	d.mu.RLock()
   171  	endpoint := d.endpoint
   172  	d.mu.RUnlock()
   173  	if endpoint == nil {
   174  		return 0, linuxerr.EBADFD
   175  	}
   176  	if !endpoint.IsAttached() {
   177  		return 0, syserror.EIO
   178  	}
   179  
   180  	dataLen := int64(len(data))
   181  
   182  	// Packet information.
   183  	var pktInfoHdr PacketInfoHeader
   184  	if !d.flags.NoPacketInfo {
   185  		if len(data) < PacketInfoHeaderSize {
   186  			// Ignore bad packet.
   187  			return dataLen, nil
   188  		}
   189  		pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize])
   190  		data = data[PacketInfoHeaderSize:]
   191  	}
   192  
   193  	// Ethernet header (TAP only).
   194  	var ethHdr header.Ethernet
   195  	if d.flags.TAP {
   196  		if len(data) < header.EthernetMinimumSize {
   197  			// Ignore bad packet.
   198  			return dataLen, nil
   199  		}
   200  		ethHdr = header.Ethernet(data[:header.EthernetMinimumSize])
   201  		data = data[header.EthernetMinimumSize:]
   202  	}
   203  
   204  	// Try to determine network protocol number, default zero.
   205  	var protocol tcpip.NetworkProtocolNumber
   206  	switch {
   207  	case pktInfoHdr != nil:
   208  		protocol = pktInfoHdr.Protocol()
   209  	case ethHdr != nil:
   210  		protocol = ethHdr.Type()
   211  	case d.flags.TUN:
   212  		// TUN interface with IFF_NO_PI enabled, thus
   213  		// we need to determine protocol from version field
   214  		version := data[0] >> 4
   215  		if version == 4 {
   216  			protocol = header.IPv4ProtocolNumber
   217  		} else if version == 6 {
   218  			protocol = header.IPv6ProtocolNumber
   219  		}
   220  	}
   221  
   222  	// Try to determine remote link address, default zero.
   223  	var remote tcpip.LinkAddress
   224  	switch {
   225  	case ethHdr != nil:
   226  		remote = ethHdr.SourceAddress()
   227  	default:
   228  		remote = tcpip.LinkAddress(zeroMAC[:])
   229  	}
   230  
   231  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   232  		ReserveHeaderBytes: len(ethHdr),
   233  		Data:               buffer.View(data).ToVectorisedView(),
   234  	})
   235  	copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr)
   236  	endpoint.InjectLinkAddr(protocol, remote, pkt)
   237  	return dataLen, nil
   238  }
   239  
   240  // Read reads one outgoing packet from the network interface.
   241  func (d *Device) Read() ([]byte, error) {
   242  	d.mu.RLock()
   243  	endpoint := d.endpoint
   244  	d.mu.RUnlock()
   245  	if endpoint == nil {
   246  		return nil, linuxerr.EBADFD
   247  	}
   248  
   249  	for {
   250  		info, ok := endpoint.Read()
   251  		if !ok {
   252  			return nil, syserror.ErrWouldBlock
   253  		}
   254  
   255  		v, ok := d.encodePkt(&info)
   256  		if !ok {
   257  			// Ignore unsupported packet.
   258  			continue
   259  		}
   260  		return v, nil
   261  	}
   262  }
   263  
   264  // encodePkt encodes packet for fd side.
   265  func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
   266  	var vv buffer.VectorisedView
   267  
   268  	// Packet information.
   269  	if !d.flags.NoPacketInfo {
   270  		hdr := make(PacketInfoHeader, PacketInfoHeaderSize)
   271  		hdr.Encode(&PacketInfoFields{
   272  			Protocol: info.Proto,
   273  		})
   274  		vv.AppendView(buffer.View(hdr))
   275  	}
   276  
   277  	// Ethernet header (TAP only).
   278  	if d.flags.TAP {
   279  		// Add ethernet header if not provided.
   280  		if info.Pkt.LinkHeader().View().IsEmpty() {
   281  			d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
   282  		}
   283  		vv.AppendView(info.Pkt.LinkHeader().View())
   284  	}
   285  
   286  	// Append upper headers.
   287  	vv.AppendView(info.Pkt.NetworkHeader().View())
   288  	vv.AppendView(info.Pkt.TransportHeader().View())
   289  	// Append data payload.
   290  	vv.Append(info.Pkt.Data().ExtractVV())
   291  
   292  	return vv.ToView(), true
   293  }
   294  
   295  // Name returns the name of the attached network interface. Empty string if
   296  // unattached.
   297  func (d *Device) Name() string {
   298  	d.mu.RLock()
   299  	defer d.mu.RUnlock()
   300  	if d.endpoint != nil {
   301  		return d.endpoint.name
   302  	}
   303  	return ""
   304  }
   305  
   306  // Flags returns the flags set for d. Zero value if unset.
   307  func (d *Device) Flags() Flags {
   308  	d.mu.RLock()
   309  	defer d.mu.RUnlock()
   310  	return d.flags
   311  }
   312  
   313  // Readiness implements watier.Waitable.Readiness.
   314  func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask {
   315  	if mask&waiter.ReadableEvents != 0 {
   316  		d.mu.RLock()
   317  		endpoint := d.endpoint
   318  		d.mu.RUnlock()
   319  		if endpoint != nil && endpoint.NumQueued() == 0 {
   320  			mask &= ^waiter.ReadableEvents
   321  		}
   322  	}
   323  	return mask & (waiter.ReadableEvents | waiter.WritableEvents)
   324  }
   325  
// WriteNotify implements channel.Notification.WriteNotify.
//
// SetIff registers d with its endpoint via AddNotify. This callback wakes
// waiters on d's queue with ReadableEvents; presumably it fires when the
// endpoint queues an outbound packet that Read can return.
func (d *Device) WriteNotify() {
	d.Notify(waiter.ReadableEvents)
}
   330  
   331  // tunEndpoint is the link endpoint for the NIC created by the tun device.
   332  //
   333  // It is ref-counted as multiple opening files can attach to the same NIC.
   334  // The last owner is responsible for deleting the NIC.
   335  type tunEndpoint struct {
   336  	tunEndpointRefs
   337  	*channel.Endpoint
   338  
   339  	stack *stack.Stack
   340  	nicID tcpip.NICID
   341  	name  string
   342  	isTap bool
   343  }
   344  
   345  // DecRef decrements refcount of e, removing NIC if it reaches 0.
   346  func (e *tunEndpoint) DecRef(ctx context.Context) {
   347  	e.tunEndpointRefs.DecRef(func() {
   348  		e.stack.RemoveNIC(e.nicID)
   349  	})
   350  }
   351  
   352  // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
   353  func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType {
   354  	if e.isTap {
   355  		return header.ARPHardwareEther
   356  	}
   357  	return header.ARPHardwareNone
   358  }
   359  
   360  // AddHeader implements stack.LinkEndpoint.AddHeader.
   361  func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
   362  	if !e.isTap {
   363  		return
   364  	}
   365  	eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
   366  	hdr := &header.EthernetFields{
   367  		SrcAddr: local,
   368  		DstAddr: remote,
   369  		Type:    protocol,
   370  	}
   371  	if hdr.SrcAddr == "" {
   372  		hdr.SrcAddr = e.LinkAddress()
   373  	}
   374  
   375  	eth.Encode(hdr)
   376  }
   377  
   378  // MaxHeaderLength returns the maximum size of the link layer header.
   379  func (e *tunEndpoint) MaxHeaderLength() uint16 {
   380  	if e.isTap {
   381  		return header.EthernetMinimumSize
   382  	}
   383  	return 0
   384  }