github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/tun/tun_linux.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  /* Implementation of the TUN device interface for linux
     9   */
    10  
    11  import (
    12  	"bytes"
    13  	"errors"
    14  	"fmt"
    15  	"os"
    16  	"sync"
    17  	"syscall"
    18  	"time"
    19  	"unsafe"
    20  
    21  	"golang.org/x/net/ipv6"
    22  	"golang.org/x/sys/unix"
    23  
    24  	"github.com/liloew/wireguard-go/rwcancel"
    25  )
    26  
    27  const (
    28  	cloneDevicePath = "/dev/net/tun"
    29  	ifReqSize       = unix.IFNAMSIZ + 64
    30  )
    31  
    32  type NativeTun struct {
    33  	tunFile                 *os.File
    34  	index                   int32      // if index
    35  	errors                  chan error // async error handling
    36  	events                  chan Event // device related events
    37  	nopi                    bool       // the device was passed IFF_NO_PI
    38  	netlinkSock             int
    39  	netlinkCancel           *rwcancel.RWCancel
    40  	hackListenerClosed      sync.Mutex
    41  	statusListenersShutdown chan struct{}
    42  
    43  	closeOnce sync.Once
    44  
    45  	nameOnce  sync.Once // guards calling initNameCache, which sets following fields
    46  	nameCache string    // name of interface
    47  	nameErr   error
    48  }
    49  
    50  func (tun *NativeTun) File() *os.File {
    51  	return tun.tunFile
    52  }
    53  
    54  func (tun *NativeTun) routineHackListener() {
    55  	defer tun.hackListenerClosed.Unlock()
    56  	/* This is needed for the detection to work across network namespaces
    57  	 * If you are reading this and know a better method, please get in touch.
    58  	 */
    59  	last := 0
    60  	const (
    61  		up   = 1
    62  		down = 2
    63  	)
    64  	for {
    65  		sysconn, err := tun.tunFile.SyscallConn()
    66  		if err != nil {
    67  			return
    68  		}
    69  		err2 := sysconn.Control(func(fd uintptr) {
    70  			_, err = unix.Write(int(fd), nil)
    71  		})
    72  		if err2 != nil {
    73  			return
    74  		}
    75  		switch err {
    76  		case unix.EINVAL:
    77  			if last != up {
    78  				// If the tunnel is up, it reports that write() is
    79  				// allowed but we provided invalid data.
    80  				tun.events <- EventUp
    81  				last = up
    82  			}
    83  		case unix.EIO:
    84  			if last != down {
    85  				// If the tunnel is down, it reports that no I/O
    86  				// is possible, without checking our provided data.
    87  				tun.events <- EventDown
    88  				last = down
    89  			}
    90  		default:
    91  			return
    92  		}
    93  		select {
    94  		case <-time.After(time.Second):
    95  			// nothing
    96  		case <-tun.statusListenersShutdown:
    97  			return
    98  		}
    99  	}
   100  }
   101  
   102  func createNetlinkSocket() (int, error) {
   103  	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
   104  	if err != nil {
   105  		return -1, err
   106  	}
   107  	saddr := &unix.SockaddrNetlink{
   108  		Family: unix.AF_NETLINK,
   109  		Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
   110  	}
   111  	err = unix.Bind(sock, saddr)
   112  	if err != nil {
   113  		return -1, err
   114  	}
   115  	return sock, nil
   116  }
   117  
   118  func (tun *NativeTun) routineNetlinkListener() {
   119  	defer func() {
   120  		unix.Close(tun.netlinkSock)
   121  		tun.hackListenerClosed.Lock()
   122  		close(tun.events)
   123  		tun.netlinkCancel.Close()
   124  	}()
   125  
   126  	for msg := make([]byte, 1<<16); ; {
   127  		var err error
   128  		var msgn int
   129  		for {
   130  			msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
   131  			if err == nil || !rwcancel.RetryAfterError(err) {
   132  				break
   133  			}
   134  			if !tun.netlinkCancel.ReadyRead() {
   135  				tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
   136  				return
   137  			}
   138  		}
   139  		if err != nil {
   140  			tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
   141  			return
   142  		}
   143  
   144  		select {
   145  		case <-tun.statusListenersShutdown:
   146  			return
   147  		default:
   148  		}
   149  
   150  		wasEverUp := false
   151  		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
   152  
   153  			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
   154  
   155  			if int(hdr.Len) > len(remain) {
   156  				break
   157  			}
   158  
   159  			switch hdr.Type {
   160  			case unix.NLMSG_DONE:
   161  				remain = []byte{}
   162  
   163  			case unix.RTM_NEWLINK:
   164  				info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
   165  				remain = remain[hdr.Len:]
   166  
   167  				if info.Index != tun.index {
   168  					// not our interface
   169  					continue
   170  				}
   171  
   172  				if info.Flags&unix.IFF_RUNNING != 0 {
   173  					tun.events <- EventUp
   174  					wasEverUp = true
   175  				}
   176  
   177  				if info.Flags&unix.IFF_RUNNING == 0 {
   178  					// Don't emit EventDown before we've ever emitted EventUp.
   179  					// This avoids a startup race with HackListener, which
   180  					// might detect Up before we have finished reporting Down.
   181  					if wasEverUp {
   182  						tun.events <- EventDown
   183  					}
   184  				}
   185  
   186  				tun.events <- EventMTUUpdate
   187  
   188  			default:
   189  				remain = remain[hdr.Len:]
   190  			}
   191  		}
   192  	}
   193  }
   194  
   195  func getIFIndex(name string) (int32, error) {
   196  	fd, err := unix.Socket(
   197  		unix.AF_INET,
   198  		unix.SOCK_DGRAM,
   199  		0,
   200  	)
   201  	if err != nil {
   202  		return 0, err
   203  	}
   204  
   205  	defer unix.Close(fd)
   206  
   207  	var ifr [ifReqSize]byte
   208  	copy(ifr[:], name)
   209  	_, _, errno := unix.Syscall(
   210  		unix.SYS_IOCTL,
   211  		uintptr(fd),
   212  		uintptr(unix.SIOCGIFINDEX),
   213  		uintptr(unsafe.Pointer(&ifr[0])),
   214  	)
   215  
   216  	if errno != 0 {
   217  		return 0, errno
   218  	}
   219  
   220  	return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
   221  }
   222  
   223  func (tun *NativeTun) setMTU(n int) error {
   224  	name, err := tun.Name()
   225  	if err != nil {
   226  		return err
   227  	}
   228  
   229  	// open datagram socket
   230  	fd, err := unix.Socket(
   231  		unix.AF_INET,
   232  		unix.SOCK_DGRAM,
   233  		0,
   234  	)
   235  	if err != nil {
   236  		return err
   237  	}
   238  
   239  	defer unix.Close(fd)
   240  
   241  	// do ioctl call
   242  	var ifr [ifReqSize]byte
   243  	copy(ifr[:], name)
   244  	*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
   245  	_, _, errno := unix.Syscall(
   246  		unix.SYS_IOCTL,
   247  		uintptr(fd),
   248  		uintptr(unix.SIOCSIFMTU),
   249  		uintptr(unsafe.Pointer(&ifr[0])),
   250  	)
   251  
   252  	if errno != 0 {
   253  		return fmt.Errorf("failed to set MTU of TUN device: %w", errno)
   254  	}
   255  
   256  	return nil
   257  }
   258  
   259  func (tun *NativeTun) MTU() (int, error) {
   260  	name, err := tun.Name()
   261  	if err != nil {
   262  		return 0, err
   263  	}
   264  
   265  	// open datagram socket
   266  	fd, err := unix.Socket(
   267  		unix.AF_INET,
   268  		unix.SOCK_DGRAM,
   269  		0,
   270  	)
   271  	if err != nil {
   272  		return 0, err
   273  	}
   274  
   275  	defer unix.Close(fd)
   276  
   277  	// do ioctl call
   278  
   279  	var ifr [ifReqSize]byte
   280  	copy(ifr[:], name)
   281  	_, _, errno := unix.Syscall(
   282  		unix.SYS_IOCTL,
   283  		uintptr(fd),
   284  		uintptr(unix.SIOCGIFMTU),
   285  		uintptr(unsafe.Pointer(&ifr[0])),
   286  	)
   287  	if errno != 0 {
   288  		return 0, fmt.Errorf("failed to get MTU of TUN device: %w", errno)
   289  	}
   290  
   291  	return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
   292  }
   293  
   294  func (tun *NativeTun) Name() (string, error) {
   295  	tun.nameOnce.Do(tun.initNameCache)
   296  	return tun.nameCache, tun.nameErr
   297  }
   298  
   299  func (tun *NativeTun) initNameCache() {
   300  	tun.nameCache, tun.nameErr = tun.nameSlow()
   301  }
   302  
   303  func (tun *NativeTun) nameSlow() (string, error) {
   304  	sysconn, err := tun.tunFile.SyscallConn()
   305  	if err != nil {
   306  		return "", err
   307  	}
   308  	var ifr [ifReqSize]byte
   309  	var errno syscall.Errno
   310  	err = sysconn.Control(func(fd uintptr) {
   311  		_, _, errno = unix.Syscall(
   312  			unix.SYS_IOCTL,
   313  			fd,
   314  			uintptr(unix.TUNGETIFF),
   315  			uintptr(unsafe.Pointer(&ifr[0])),
   316  		)
   317  	})
   318  	if err != nil {
   319  		return "", fmt.Errorf("failed to get name of TUN device: %w", err)
   320  	}
   321  	if errno != 0 {
   322  		return "", fmt.Errorf("failed to get name of TUN device: %w", errno)
   323  	}
   324  	name := ifr[:]
   325  	if i := bytes.IndexByte(name, 0); i != -1 {
   326  		name = name[:i]
   327  	}
   328  	return string(name), nil
   329  }
   330  
   331  func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
   332  	if tun.nopi {
   333  		buf = buf[offset:]
   334  	} else {
   335  		// reserve space for header
   336  		buf = buf[offset-4:]
   337  
   338  		// add packet information header
   339  		buf[0] = 0x00
   340  		buf[1] = 0x00
   341  		if buf[4]>>4 == ipv6.Version {
   342  			buf[2] = 0x86
   343  			buf[3] = 0xdd
   344  		} else {
   345  			buf[2] = 0x08
   346  			buf[3] = 0x00
   347  		}
   348  	}
   349  
   350  	n, err := tun.tunFile.Write(buf)
   351  	if errors.Is(err, syscall.EBADFD) {
   352  		err = os.ErrClosed
   353  	}
   354  	return n, err
   355  }
   356  
   357  func (tun *NativeTun) Flush() error {
   358  	// TODO: can flushing be implemented by buffering and using sendmmsg?
   359  	return nil
   360  }
   361  
   362  func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
   363  	select {
   364  	case err = <-tun.errors:
   365  	default:
   366  		if tun.nopi {
   367  			n, err = tun.tunFile.Read(buf[offset:])
   368  		} else {
   369  			buff := buf[offset-4:]
   370  			n, err = tun.tunFile.Read(buff[:])
   371  			if errors.Is(err, syscall.EBADFD) {
   372  				err = os.ErrClosed
   373  			}
   374  			if n < 4 {
   375  				n = 0
   376  			} else {
   377  				n -= 4
   378  			}
   379  		}
   380  	}
   381  	return
   382  }
   383  
   384  func (tun *NativeTun) Events() chan Event {
   385  	return tun.events
   386  }
   387  
   388  func (tun *NativeTun) Close() error {
   389  	var err1, err2 error
   390  	tun.closeOnce.Do(func() {
   391  		if tun.statusListenersShutdown != nil {
   392  			close(tun.statusListenersShutdown)
   393  			if tun.netlinkCancel != nil {
   394  				err1 = tun.netlinkCancel.Cancel()
   395  			}
   396  		} else if tun.events != nil {
   397  			close(tun.events)
   398  		}
   399  		err2 = tun.tunFile.Close()
   400  	})
   401  	if err1 != nil {
   402  		return err1
   403  	}
   404  	return err2
   405  }
   406  
   407  func CreateTUN(name string, mtu int, nopi bool) (Device, error) {
   408  	nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
   409  	if err != nil {
   410  		if os.IsNotExist(err) {
   411  			return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
   412  		}
   413  		return nil, err
   414  	}
   415  
   416  	var ifr [ifReqSize]byte
   417  	var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
   418  	if nopi {
   419  		flags |= unix.IFF_NO_PI
   420  	}
   421  	nameBytes := []byte(name)
   422  	if len(nameBytes) >= unix.IFNAMSIZ {
   423  		unix.Close(nfd)
   424  		return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
   425  	}
   426  	copy(ifr[:], nameBytes)
   427  	*(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
   428  
   429  	_, _, errno := unix.Syscall(
   430  		unix.SYS_IOCTL,
   431  		uintptr(nfd),
   432  		uintptr(unix.TUNSETIFF),
   433  		uintptr(unsafe.Pointer(&ifr[0])),
   434  	)
   435  	if errno != 0 {
   436  		unix.Close(nfd)
   437  		return nil, errno
   438  	}
   439  
   440  	err = unix.SetNonblock(nfd, true)
   441  	if err != nil {
   442  		unix.Close(nfd)
   443  		return nil, err
   444  	}
   445  
   446  	// Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
   447  
   448  	fd := os.NewFile(uintptr(nfd), cloneDevicePath)
   449  	return CreateTUNFromFile(fd, mtu, nopi)
   450  }
   451  
   452  func CreateTUNFromFile(file *os.File, mtu int, nopi bool) (Device, error) {
   453  	tun := &NativeTun{
   454  		tunFile:                 file,
   455  		events:                  make(chan Event, 5),
   456  		errors:                  make(chan error, 5),
   457  		statusListenersShutdown: make(chan struct{}),
   458  		nopi:                    nopi,
   459  	}
   460  
   461  	name, err := tun.Name()
   462  	if err != nil {
   463  		return nil, err
   464  	}
   465  
   466  	// start event listener
   467  
   468  	tun.index, err = getIFIndex(name)
   469  	if err != nil {
   470  		return nil, err
   471  	}
   472  
   473  	tun.netlinkSock, err = createNetlinkSocket()
   474  	if err != nil {
   475  		return nil, err
   476  	}
   477  	tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
   478  	if err != nil {
   479  		unix.Close(tun.netlinkSock)
   480  		return nil, err
   481  	}
   482  
   483  	tun.hackListenerClosed.Lock()
   484  	go tun.routineNetlinkListener()
   485  	go tun.routineHackListener() // cross namespace
   486  
   487  	err = tun.setMTU(mtu)
   488  	if err != nil {
   489  		unix.Close(tun.netlinkSock)
   490  		return nil, err
   491  	}
   492  
   493  	return tun, nil
   494  }
   495  
   496  func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
   497  	err := unix.SetNonblock(fd, true)
   498  	if err != nil {
   499  		return nil, "", err
   500  	}
   501  	file := os.NewFile(uintptr(fd), "/dev/tun")
   502  	tun := &NativeTun{
   503  		tunFile: file,
   504  		events:  make(chan Event, 5),
   505  		errors:  make(chan error, 5),
   506  		nopi:    true,
   507  	}
   508  	name, err := tun.Name()
   509  	if err != nil {
   510  		return nil, "", err
   511  	}
   512  	return tun, name, nil
   513  }