github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/tun/tun_openbsd.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"net"
    12  	"os"
    13  	"sync"
    14  	"syscall"
    15  	"unsafe"
    16  
    17  	"golang.org/x/net/ipv6"
    18  	"golang.org/x/sys/unix"
    19  )
    20  
// ifreq_mtu is the request structure passed by pointer to the SIOCGIFMTU and
// SIOCSIFMTU ioctls in MTU and setMTU: the interface name followed by the MTU.
// The field order and sizes are load-bearing because the kernel reads and
// writes this memory directly; do not reorder or resize fields.
// NOTE(review): presumably mirrors struct ifreq from <net/if.h>, with Pad0
// filling out the request union — confirm against the OpenBSD headers.
type ifreq_mtu struct {
	Name [unix.IFNAMSIZ]byte
	MTU  uint32
	Pad0 [12]byte
}
    27  
// _TUNSIFMODE is an ioctl request number. Decoded with the BSD _IOC bit layout,
// 0x8004745d = 0x80000000 (IOC_IN) | 4<<16 (4-byte argument) | 't'<<8 | 93.
// NOTE(review): not referenced anywhere in this file; presumably kept for
// setting the tun interface mode — confirm before removing.
const _TUNSIFMODE = 0x8004745d
    29  
// NativeTun is the OpenBSD tun Device, backed by an open /dev/tunN
// character device plus an AF_ROUTE socket used to watch interface state.
type NativeTun struct {
	name        string     // interface name ("tunN"), filled in by Name()
	tunFile     *os.File   // open tun character device used for packet I/O
	events      chan Event // state changes; closed by the route listener (or by Close if there is no route socket)
	errors      chan error // errors from the route listener, surfaced by Read
	routeSocket int        // AF_ROUTE socket descriptor; set to -1 by Close
	closeOnce   sync.Once  // makes Close run its body only once
}
    38  
    39  func (tun *NativeTun) routineRouteListener(tunIfindex int) {
    40  	var (
    41  		statusUp  bool
    42  		statusMTU int
    43  	)
    44  
    45  	defer close(tun.events)
    46  
    47  	check := func() bool {
    48  		iface, err := net.InterfaceByIndex(tunIfindex)
    49  		if err != nil {
    50  			tun.errors <- err
    51  			return true
    52  		}
    53  
    54  		// Up / Down event
    55  		up := (iface.Flags & net.FlagUp) != 0
    56  		if up != statusUp && up {
    57  			tun.events <- EventUp
    58  		}
    59  		if up != statusUp && !up {
    60  			tun.events <- EventDown
    61  		}
    62  		statusUp = up
    63  
    64  		// MTU changes
    65  		if iface.MTU != statusMTU {
    66  			tun.events <- EventMTUUpdate
    67  		}
    68  		statusMTU = iface.MTU
    69  		return false
    70  	}
    71  
    72  	if check() {
    73  		return
    74  	}
    75  
    76  	data := make([]byte, os.Getpagesize())
    77  	for {
    78  		n, err := unix.Read(tun.routeSocket, data)
    79  		if err != nil {
    80  			if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
    81  				continue
    82  			}
    83  			tun.errors <- err
    84  			return
    85  		}
    86  
    87  		if n < 8 {
    88  			continue
    89  		}
    90  
    91  		if data[3 /* type */] != unix.RTM_IFINFO {
    92  			continue
    93  		}
    94  		ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
    95  		if ifindex != tunIfindex {
    96  			continue
    97  		}
    98  		if check() {
    99  			return
   100  		}
   101  	}
   102  }
   103  
   104  func CreateTUN(name string, mtu int, nopi bool) (Device, error) {
   105  	ifIndex := -1
   106  	if name != "tun" {
   107  		_, err := fmt.Sscanf(name, "tun%d", &ifIndex)
   108  		if err != nil || ifIndex < 0 {
   109  			return nil, fmt.Errorf("Interface name must be tun[0-9]*")
   110  		}
   111  	}
   112  
   113  	var tunfile *os.File
   114  	var err error
   115  
   116  	if ifIndex != -1 {
   117  		tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0)
   118  	} else {
   119  		for ifIndex = 0; ifIndex < 256; ifIndex++ {
   120  			tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0)
   121  			if err == nil || !errors.Is(err, syscall.EBUSY) {
   122  				break
   123  			}
   124  		}
   125  	}
   126  
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	tun, err := CreateTUNFromFile(tunfile, mtu, nopi)
   132  
   133  	if err == nil && name == "tun" {
   134  		fname := os.Getenv("WG_TUN_NAME_FILE")
   135  		if fname != "" {
   136  			os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
   137  		}
   138  	}
   139  
   140  	return tun, err
   141  }
   142  
   143  func CreateTUNFromFile(file *os.File, mtu int, nopi bool) (Device, error) {
   144  	tun := &NativeTun{
   145  		tunFile: file,
   146  		events:  make(chan Event, 10),
   147  		errors:  make(chan error, 1),
   148  	}
   149  
   150  	name, err := tun.Name()
   151  	if err != nil {
   152  		tun.tunFile.Close()
   153  		return nil, err
   154  	}
   155  
   156  	tunIfindex, err := func() (int, error) {
   157  		iface, err := net.InterfaceByName(name)
   158  		if err != nil {
   159  			return -1, err
   160  		}
   161  		return iface.Index, nil
   162  	}()
   163  	if err != nil {
   164  		tun.tunFile.Close()
   165  		return nil, err
   166  	}
   167  
   168  	tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
   169  	if err != nil {
   170  		tun.tunFile.Close()
   171  		return nil, err
   172  	}
   173  
   174  	go tun.routineRouteListener(tunIfindex)
   175  
   176  	currentMTU, err := tun.MTU()
   177  	if err != nil || currentMTU != mtu {
   178  		err = tun.setMTU(mtu)
   179  		if err != nil {
   180  			tun.Close()
   181  			return nil, err
   182  		}
   183  	}
   184  
   185  	return tun, nil
   186  }
   187  
   188  func (tun *NativeTun) Name() (string, error) {
   189  	gostat, err := tun.tunFile.Stat()
   190  	if err != nil {
   191  		tun.name = ""
   192  		return "", err
   193  	}
   194  	stat := gostat.Sys().(*syscall.Stat_t)
   195  	tun.name = fmt.Sprintf("tun%d", stat.Rdev%256)
   196  	return tun.name, nil
   197  }
   198  
   199  func (tun *NativeTun) File() *os.File {
   200  	return tun.tunFile
   201  }
   202  
   203  func (tun *NativeTun) Events() chan Event {
   204  	return tun.events
   205  }
   206  
   207  func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
   208  	select {
   209  	case err := <-tun.errors:
   210  		return 0, err
   211  	default:
   212  		buff := buff[offset-4:]
   213  		n, err := tun.tunFile.Read(buff[:])
   214  		if n < 4 {
   215  			return 0, err
   216  		}
   217  		return n - 4, err
   218  	}
   219  }
   220  
   221  func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
   222  	// reserve space for header
   223  
   224  	buff = buff[offset-4:]
   225  
   226  	// add packet information header
   227  
   228  	buff[0] = 0x00
   229  	buff[1] = 0x00
   230  	buff[2] = 0x00
   231  
   232  	if buff[4]>>4 == ipv6.Version {
   233  		buff[3] = unix.AF_INET6
   234  	} else {
   235  		buff[3] = unix.AF_INET
   236  	}
   237  
   238  	// write
   239  
   240  	return tun.tunFile.Write(buff)
   241  }
   242  
   243  func (tun *NativeTun) Flush() error {
   244  	// TODO: can flushing be implemented by buffering and using sendmmsg?
   245  	return nil
   246  }
   247  
   248  func (tun *NativeTun) Close() error {
   249  	var err1, err2 error
   250  	tun.closeOnce.Do(func() {
   251  		err1 = tun.tunFile.Close()
   252  		if tun.routeSocket != -1 {
   253  			unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
   254  			err2 = unix.Close(tun.routeSocket)
   255  			tun.routeSocket = -1
   256  		} else if tun.events != nil {
   257  			close(tun.events)
   258  		}
   259  	})
   260  	if err1 != nil {
   261  		return err1
   262  	}
   263  	return err2
   264  }
   265  
   266  func (tun *NativeTun) setMTU(n int) error {
   267  	// open datagram socket
   268  
   269  	var fd int
   270  
   271  	fd, err := unix.Socket(
   272  		unix.AF_INET,
   273  		unix.SOCK_DGRAM,
   274  		0,
   275  	)
   276  	if err != nil {
   277  		return err
   278  	}
   279  
   280  	defer unix.Close(fd)
   281  
   282  	// do ioctl call
   283  
   284  	var ifr ifreq_mtu
   285  	copy(ifr.Name[:], tun.name)
   286  	ifr.MTU = uint32(n)
   287  
   288  	_, _, errno := unix.Syscall(
   289  		unix.SYS_IOCTL,
   290  		uintptr(fd),
   291  		uintptr(unix.SIOCSIFMTU),
   292  		uintptr(unsafe.Pointer(&ifr)),
   293  	)
   294  
   295  	if errno != 0 {
   296  		return fmt.Errorf("failed to set MTU on %s", tun.name)
   297  	}
   298  
   299  	return nil
   300  }
   301  
   302  func (tun *NativeTun) MTU() (int, error) {
   303  	// open datagram socket
   304  
   305  	fd, err := unix.Socket(
   306  		unix.AF_INET,
   307  		unix.SOCK_DGRAM,
   308  		0,
   309  	)
   310  	if err != nil {
   311  		return 0, err
   312  	}
   313  
   314  	defer unix.Close(fd)
   315  
   316  	// do ioctl call
   317  	var ifr ifreq_mtu
   318  	copy(ifr.Name[:], tun.name)
   319  
   320  	_, _, errno := unix.Syscall(
   321  		unix.SYS_IOCTL,
   322  		uintptr(fd),
   323  		uintptr(unix.SIOCGIFMTU),
   324  		uintptr(unsafe.Pointer(&ifr)),
   325  	)
   326  	if errno != 0 {
   327  		return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
   328  	}
   329  
   330  	return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
   331  }