github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/tun/tun_linux.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  /* Implementation of the TUN device interface for linux
     9   */
    10  
    11  import (
    12  	"bytes"
    13  	"errors"
    14  	"fmt"
    15  	"os"
    16  	"sync"
    17  	"syscall"
    18  	"time"
    19  	"unsafe"
    20  
    21  	"github.com/tailscale/wireguard-go/rwcancel"
    22  	"golang.org/x/net/ipv6"
    23  	"golang.org/x/sys/unix"
    24  )
    25  
const (
	// cloneDevicePath is the device node that CreateTUN opens to obtain a
	// file descriptor for a new TUN interface.
	cloneDevicePath = "/dev/net/tun"
	// ifReqSize is the size in bytes of the request buffers handed to the
	// interface ioctls in this file. The interface name is copied to the
	// start of the buffer and the request-specific value is read or written
	// at offset unix.IFNAMSIZ.
	ifReqSize       = unix.IFNAMSIZ + 64
)
    30  
// NativeTun is the Linux TUN device. It wraps the TUN file descriptor and,
// when created through CreateTUNFromFile, two background listeners that
// report link state and MTU changes on the events channel.
type NativeTun struct {
	tunFile                 *os.File
	index                   int32      // interface index, matched against RTM_NEWLINK messages
	errors                  chan error // async error handling; drained by Read
	events                  chan Event // device related events
	nopi                    bool       // the device was passed IFF_NO_PI
	netlinkSock             int        // netlink socket read by routineNetlinkListener
	netlinkCancel           *rwcancel.RWCancel
	// hackListenerClosed is locked before the listeners start and unlocked
	// when routineHackListener returns; routineNetlinkListener locks it
	// again during cleanup so events is only closed after the hack listener
	// has stopped sending.
	hackListenerClosed      sync.Mutex
	// statusListenersShutdown is closed by Close to tell both listeners to
	// stop. It is nil for devices without listeners.
	statusListenersShutdown chan struct{}

	closeOnce sync.Once // makes Close safe to call more than once

	nameOnce  sync.Once // guards calling initNameCache, which sets following fields
	nameCache string    // name of interface
	nameErr   error     // error from the one-time name lookup, if any
}
    48  
// File returns the underlying TUN file handle.
func (tun *NativeTun) File() *os.File {
	return tun.tunFile
}
    52  
    53  func (tun *NativeTun) routineHackListener() {
    54  	defer tun.hackListenerClosed.Unlock()
    55  	/* This is needed for the detection to work across network namespaces
    56  	 * If you are reading this and know a better method, please get in touch.
    57  	 */
    58  	last := 0
    59  	const (
    60  		up   = 1
    61  		down = 2
    62  	)
    63  	for {
    64  		sysconn, err := tun.tunFile.SyscallConn()
    65  		if err != nil {
    66  			return
    67  		}
    68  		err2 := sysconn.Control(func(fd uintptr) {
    69  			_, err = unix.Write(int(fd), nil)
    70  		})
    71  		if err2 != nil {
    72  			return
    73  		}
    74  		switch err {
    75  		case unix.EINVAL:
    76  			if last != up {
    77  				// If the tunnel is up, it reports that write() is
    78  				// allowed but we provided invalid data.
    79  				tun.events <- EventUp
    80  				last = up
    81  			}
    82  		case unix.EIO:
    83  			if last != down {
    84  				// If the tunnel is down, it reports that no I/O
    85  				// is possible, without checking our provided data.
    86  				tun.events <- EventDown
    87  				last = down
    88  			}
    89  		default:
    90  			return
    91  		}
    92  		select {
    93  		case <-time.After(time.Second):
    94  			// nothing
    95  		case <-tun.statusListenersShutdown:
    96  			return
    97  		}
    98  	}
    99  }
   100  
   101  func createNetlinkSocket() (int, error) {
   102  	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
   103  	if err != nil {
   104  		return -1, err
   105  	}
   106  	saddr := &unix.SockaddrNetlink{
   107  		Family: unix.AF_NETLINK,
   108  		Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
   109  	}
   110  	err = unix.Bind(sock, saddr)
   111  	if err != nil {
   112  		return -1, err
   113  	}
   114  	return sock, nil
   115  }
   116  
   117  func (tun *NativeTun) routineNetlinkListener() {
   118  	defer func() {
   119  		unix.Close(tun.netlinkSock)
   120  		tun.hackListenerClosed.Lock()
   121  		close(tun.events)
   122  		tun.netlinkCancel.Close()
   123  	}()
   124  
   125  	for msg := make([]byte, 1<<16); ; {
   126  		var err error
   127  		var msgn int
   128  		for {
   129  			msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
   130  			if err == nil || !rwcancel.RetryAfterError(err) {
   131  				break
   132  			}
   133  			if !tun.netlinkCancel.ReadyRead() {
   134  				tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
   135  				return
   136  			}
   137  		}
   138  		if err != nil {
   139  			tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
   140  			return
   141  		}
   142  
   143  		select {
   144  		case <-tun.statusListenersShutdown:
   145  			return
   146  		default:
   147  		}
   148  
   149  		wasEverUp := false
   150  		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
   151  
   152  			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
   153  
   154  			if int(hdr.Len) > len(remain) {
   155  				break
   156  			}
   157  
   158  			switch hdr.Type {
   159  			case unix.NLMSG_DONE:
   160  				remain = []byte{}
   161  
   162  			case unix.RTM_NEWLINK:
   163  				info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
   164  				remain = remain[hdr.Len:]
   165  
   166  				if info.Index != tun.index {
   167  					// not our interface
   168  					continue
   169  				}
   170  
   171  				if info.Flags&unix.IFF_RUNNING != 0 {
   172  					tun.events <- EventUp
   173  					wasEverUp = true
   174  				}
   175  
   176  				if info.Flags&unix.IFF_RUNNING == 0 {
   177  					// Don't emit EventDown before we've ever emitted EventUp.
   178  					// This avoids a startup race with HackListener, which
   179  					// might detect Up before we have finished reporting Down.
   180  					if wasEverUp {
   181  						tun.events <- EventDown
   182  					}
   183  				}
   184  
   185  				tun.events <- EventMTUUpdate
   186  
   187  			default:
   188  				remain = remain[hdr.Len:]
   189  			}
   190  		}
   191  	}
   192  }
   193  
   194  func getIFIndex(name string) (int32, error) {
   195  	fd, err := unix.Socket(
   196  		unix.AF_INET,
   197  		unix.SOCK_DGRAM,
   198  		0,
   199  	)
   200  	if err != nil {
   201  		return 0, err
   202  	}
   203  
   204  	defer unix.Close(fd)
   205  
   206  	var ifr [ifReqSize]byte
   207  	copy(ifr[:], name)
   208  	_, _, errno := unix.Syscall(
   209  		unix.SYS_IOCTL,
   210  		uintptr(fd),
   211  		uintptr(unix.SIOCGIFINDEX),
   212  		uintptr(unsafe.Pointer(&ifr[0])),
   213  	)
   214  
   215  	if errno != 0 {
   216  		return 0, errno
   217  	}
   218  
   219  	return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
   220  }
   221  
   222  func (tun *NativeTun) setMTU(n int) error {
   223  	name, err := tun.Name()
   224  	if err != nil {
   225  		return err
   226  	}
   227  
   228  	// open datagram socket
   229  	fd, err := unix.Socket(
   230  		unix.AF_INET,
   231  		unix.SOCK_DGRAM,
   232  		0,
   233  	)
   234  
   235  	if err != nil {
   236  		return err
   237  	}
   238  
   239  	defer unix.Close(fd)
   240  
   241  	// do ioctl call
   242  	var ifr [ifReqSize]byte
   243  	copy(ifr[:], name)
   244  	*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
   245  	_, _, errno := unix.Syscall(
   246  		unix.SYS_IOCTL,
   247  		uintptr(fd),
   248  		uintptr(unix.SIOCSIFMTU),
   249  		uintptr(unsafe.Pointer(&ifr[0])),
   250  	)
   251  
   252  	if errno != 0 {
   253  		return fmt.Errorf("failed to set MTU of TUN device: %w", errno)
   254  	}
   255  
   256  	return nil
   257  }
   258  
   259  func (tun *NativeTun) MTU() (int, error) {
   260  	name, err := tun.Name()
   261  	if err != nil {
   262  		return 0, err
   263  	}
   264  
   265  	// open datagram socket
   266  	fd, err := unix.Socket(
   267  		unix.AF_INET,
   268  		unix.SOCK_DGRAM,
   269  		0,
   270  	)
   271  
   272  	if err != nil {
   273  		return 0, err
   274  	}
   275  
   276  	defer unix.Close(fd)
   277  
   278  	// do ioctl call
   279  
   280  	var ifr [ifReqSize]byte
   281  	copy(ifr[:], name)
   282  	_, _, errno := unix.Syscall(
   283  		unix.SYS_IOCTL,
   284  		uintptr(fd),
   285  		uintptr(unix.SIOCGIFMTU),
   286  		uintptr(unsafe.Pointer(&ifr[0])),
   287  	)
   288  	if errno != 0 {
   289  		return 0, fmt.Errorf("failed to get MTU of TUN device: %w", errno)
   290  	}
   291  
   292  	return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
   293  }
   294  
// Name returns the interface name of the device. The name is looked up
// once on first use; the cached name and lookup error are returned on
// every call thereafter.
func (tun *NativeTun) Name() (string, error) {
	tun.nameOnce.Do(tun.initNameCache)
	return tun.nameCache, tun.nameErr
}
   299  
// initNameCache stores the result of nameSlow in nameCache and nameErr.
// It is run through nameOnce by Name.
func (tun *NativeTun) initNameCache() {
	tun.nameCache, tun.nameErr = tun.nameSlow()
}
   303  
   304  func (tun *NativeTun) nameSlow() (string, error) {
   305  	sysconn, err := tun.tunFile.SyscallConn()
   306  	if err != nil {
   307  		return "", err
   308  	}
   309  	var ifr [ifReqSize]byte
   310  	var errno syscall.Errno
   311  	err = sysconn.Control(func(fd uintptr) {
   312  		_, _, errno = unix.Syscall(
   313  			unix.SYS_IOCTL,
   314  			fd,
   315  			uintptr(unix.TUNGETIFF),
   316  			uintptr(unsafe.Pointer(&ifr[0])),
   317  		)
   318  	})
   319  	if err != nil {
   320  		return "", fmt.Errorf("failed to get name of TUN device: %w", err)
   321  	}
   322  	if errno != 0 {
   323  		return "", fmt.Errorf("failed to get name of TUN device: %w", errno)
   324  	}
   325  	name := ifr[:]
   326  	if i := bytes.IndexByte(name, 0); i != -1 {
   327  		name = name[:i]
   328  	}
   329  	return string(name), nil
   330  }
   331  
   332  func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
   333  	if tun.nopi {
   334  		buf = buf[offset:]
   335  	} else {
   336  		// reserve space for header
   337  		buf = buf[offset-4:]
   338  
   339  		// add packet information header
   340  		buf[0] = 0x00
   341  		buf[1] = 0x00
   342  		if buf[4]>>4 == ipv6.Version {
   343  			buf[2] = 0x86
   344  			buf[3] = 0xdd
   345  		} else {
   346  			buf[2] = 0x08
   347  			buf[3] = 0x00
   348  		}
   349  	}
   350  
   351  	n, err := tun.tunFile.Write(buf)
   352  	if errors.Is(err, syscall.EBADFD) {
   353  		err = os.ErrClosed
   354  	}
   355  	return n, err
   356  }
   357  
// Flush is a no-op on Linux: Write hands every packet to the device
// immediately, so there is nothing buffered. It always returns nil.
func (tun *NativeTun) Flush() error {
	// TODO: can flushing be implemented by buffering and using sendmmsg?
	return nil
}
   362  
   363  func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
   364  	select {
   365  	case err = <-tun.errors:
   366  	default:
   367  		if tun.nopi {
   368  			n, err = tun.tunFile.Read(buf[offset:])
   369  		} else {
   370  			buff := buf[offset-4:]
   371  			n, err = tun.tunFile.Read(buff[:])
   372  			if errors.Is(err, syscall.EBADFD) {
   373  				err = os.ErrClosed
   374  			}
   375  			if n < 4 {
   376  				n = 0
   377  			} else {
   378  				n -= 4
   379  			}
   380  		}
   381  	}
   382  	return
   383  }
   384  
// Events returns the channel on which device events (such as EventUp,
// EventDown and EventMTUUpdate) are delivered.
func (tun *NativeTun) Events() chan Event {
	return tun.events
}
   388  
   389  func (tun *NativeTun) Close() error {
   390  	var err1, err2 error
   391  	tun.closeOnce.Do(func() {
   392  		if tun.statusListenersShutdown != nil {
   393  			close(tun.statusListenersShutdown)
   394  			if tun.netlinkCancel != nil {
   395  				err1 = tun.netlinkCancel.Cancel()
   396  			}
   397  		} else if tun.events != nil {
   398  			close(tun.events)
   399  		}
   400  		err2 = tun.tunFile.Close()
   401  	})
   402  	if err1 != nil {
   403  		return err1
   404  	}
   405  	return err2
   406  }
   407  
   408  func CreateTUN(name string, mtu int) (Device, error) {
   409  	nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
   410  	if err != nil {
   411  		if os.IsNotExist(err) {
   412  			return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
   413  		}
   414  		return nil, err
   415  	}
   416  
   417  	var ifr [ifReqSize]byte
   418  	var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
   419  	nameBytes := []byte(name)
   420  	if len(nameBytes) >= unix.IFNAMSIZ {
   421  		return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
   422  	}
   423  	copy(ifr[:], nameBytes)
   424  	*(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
   425  
   426  	_, _, errno := unix.Syscall(
   427  		unix.SYS_IOCTL,
   428  		uintptr(nfd),
   429  		uintptr(unix.TUNSETIFF),
   430  		uintptr(unsafe.Pointer(&ifr[0])),
   431  	)
   432  	if errno != 0 {
   433  		return nil, errno
   434  	}
   435  	err = unix.SetNonblock(nfd, true)
   436  
   437  	// Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
   438  
   439  	fd := os.NewFile(uintptr(nfd), cloneDevicePath)
   440  	if err != nil {
   441  		return nil, err
   442  	}
   443  
   444  	return CreateTUNFromFile(fd, mtu)
   445  }
   446  
   447  func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
   448  	tun := &NativeTun{
   449  		tunFile:                 file,
   450  		events:                  make(chan Event, 5),
   451  		errors:                  make(chan error, 5),
   452  		statusListenersShutdown: make(chan struct{}),
   453  		nopi:                    false,
   454  	}
   455  
   456  	name, err := tun.Name()
   457  	if err != nil {
   458  		return nil, err
   459  	}
   460  
   461  	// start event listener
   462  
   463  	tun.index, err = getIFIndex(name)
   464  	if err != nil {
   465  		return nil, err
   466  	}
   467  
   468  	tun.netlinkSock, err = createNetlinkSocket()
   469  	if err != nil {
   470  		return nil, err
   471  	}
   472  	tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
   473  	if err != nil {
   474  		unix.Close(tun.netlinkSock)
   475  		return nil, err
   476  	}
   477  
   478  	tun.hackListenerClosed.Lock()
   479  	go tun.routineNetlinkListener()
   480  	go tun.routineHackListener() // cross namespace
   481  
   482  	err = tun.setMTU(mtu)
   483  	if err != nil {
   484  		unix.Close(tun.netlinkSock)
   485  		return nil, err
   486  	}
   487  
   488  	return tun, nil
   489  }
   490  
   491  func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
   492  	err := unix.SetNonblock(fd, true)
   493  	if err != nil {
   494  		return nil, "", err
   495  	}
   496  	file := os.NewFile(uintptr(fd), "/dev/tun")
   497  	tun := &NativeTun{
   498  		tunFile: file,
   499  		events:  make(chan Event, 5),
   500  		errors:  make(chan error, 5),
   501  		nopi:    true,
   502  	}
   503  	name, err := tun.Name()
   504  	if err != nil {
   505  		return nil, "", err
   506  	}
   507  	return tun, name, nil
   508  }