github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/tun/tun_openbsd.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net"
    13  	"os"
    14  	"sync"
    15  	"syscall"
    16  	"unsafe"
    17  
    18  	"golang.org/x/sys/unix"
    19  )
    20  
// ifreq_mtu is the request structure passed to the SIOCGIFMTU and SIOCSIFMTU
// ioctls by MTU and setMTU. It holds an IFNAMSIZ-byte interface name followed
// by a 16-byte area (MTU plus 12 bytes of padding) whose first 4 bytes carry
// the MTU value.
// NOTE(review): assumed to mirror the layout of OpenBSD's struct ifreq (name
// followed by a 16-byte union containing ifru_mtu) — confirm against <net/if.h>.
type ifreq_mtu struct {
	Name [unix.IFNAMSIZ]byte
	MTU  uint32
	Pad0 [12]byte
}
    27  
// _TUNSIFMODE is an ioctl request number. 0x8004745d decodes as
// _IOW('t', 93, <4-byte argument>), i.e. IOC_IN | 4<<16 | 't'<<8 | 93.
// Presumably this is OpenBSD's TUNSIFMODE (set tun interface mode) — confirm
// against <net/if_tun.h>. It is not referenced by the code in this section.
const _TUNSIFMODE = 0x8004745d
    29  
// NativeTun is an OpenBSD tun device plus the state of its route-socket
// listener.
//
// Fields:
//   - name: cached interface name ("tunN"), set by Name and used by MTU/setMTU.
//   - tunFile: the open tun device file packets are read from and written to.
//   - events: EventUp/EventDown/EventMTUUpdate notifications sent by
//     routineRouteListener, which closes the channel when it exits.
//   - errors: errors reported by routineRouteListener; Read returns them.
//   - routeSocket: AF_ROUTE socket watched by routineRouteListener; Close sets
//     it to -1.
//   - closeOnce: ensures Close does its work only once.
type NativeTun struct {
	name        string
	tunFile     *os.File
	events      chan Event
	errors      chan error
	routeSocket int
	closeOnce   sync.Once
}
    38  
    39  func (tun *NativeTun) routineRouteListener(tunIfindex int) {
    40  	var (
    41  		statusUp  bool
    42  		statusMTU int
    43  	)
    44  
    45  	defer close(tun.events)
    46  
    47  	check := func() bool {
    48  		iface, err := net.InterfaceByIndex(tunIfindex)
    49  		if err != nil {
    50  			tun.errors <- err
    51  			return true
    52  		}
    53  
    54  		// Up / Down event
    55  		up := (iface.Flags & net.FlagUp) != 0
    56  		if up != statusUp && up {
    57  			tun.events <- EventUp
    58  		}
    59  		if up != statusUp && !up {
    60  			tun.events <- EventDown
    61  		}
    62  		statusUp = up
    63  
    64  		// MTU changes
    65  		if iface.MTU != statusMTU {
    66  			tun.events <- EventMTUUpdate
    67  		}
    68  		statusMTU = iface.MTU
    69  		return false
    70  	}
    71  
    72  	if check() {
    73  		return
    74  	}
    75  
    76  	data := make([]byte, os.Getpagesize())
    77  	for {
    78  		n, err := unix.Read(tun.routeSocket, data)
    79  		if err != nil {
    80  			if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
    81  				continue
    82  			}
    83  			tun.errors <- err
    84  			return
    85  		}
    86  
    87  		if n < 8 {
    88  			continue
    89  		}
    90  
    91  		if data[3 /* type */] != unix.RTM_IFINFO {
    92  			continue
    93  		}
    94  		ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
    95  		if ifindex != tunIfindex {
    96  			continue
    97  		}
    98  		if check() {
    99  			return
   100  		}
   101  	}
   102  }
   103  
   104  func CreateTUN(name string, mtu int) (Device, error) {
   105  	ifIndex := -1
   106  	if name != "tun" {
   107  		_, err := fmt.Sscanf(name, "tun%d", &ifIndex)
   108  		if err != nil || ifIndex < 0 {
   109  			return nil, fmt.Errorf("Interface name must be tun[0-9]*")
   110  		}
   111  	}
   112  
   113  	var tunfile *os.File
   114  	var err error
   115  
   116  	if ifIndex != -1 {
   117  		tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
   118  	} else {
   119  		for ifIndex = 0; ifIndex < 256; ifIndex++ {
   120  			tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
   121  			if err == nil || !errors.Is(err, syscall.EBUSY) {
   122  				break
   123  			}
   124  		}
   125  	}
   126  
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	tun, err := CreateTUNFromFile(tunfile, mtu)
   132  
   133  	if err == nil && name == "tun" {
   134  		fname := os.Getenv("WG_TUN_NAME_FILE")
   135  		if fname != "" {
   136  			os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
   137  		}
   138  	}
   139  
   140  	return tun, err
   141  }
   142  
   143  func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
   144  	tun := &NativeTun{
   145  		tunFile: file,
   146  		events:  make(chan Event, 10),
   147  		errors:  make(chan error, 1),
   148  	}
   149  
   150  	name, err := tun.Name()
   151  	if err != nil {
   152  		tun.tunFile.Close()
   153  		return nil, err
   154  	}
   155  
   156  	tunIfindex, err := func() (int, error) {
   157  		iface, err := net.InterfaceByName(name)
   158  		if err != nil {
   159  			return -1, err
   160  		}
   161  		return iface.Index, nil
   162  	}()
   163  	if err != nil {
   164  		tun.tunFile.Close()
   165  		return nil, err
   166  	}
   167  
   168  	tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
   169  	if err != nil {
   170  		tun.tunFile.Close()
   171  		return nil, err
   172  	}
   173  
   174  	go tun.routineRouteListener(tunIfindex)
   175  
   176  	currentMTU, err := tun.MTU()
   177  	if err != nil || currentMTU != mtu {
   178  		err = tun.setMTU(mtu)
   179  		if err != nil {
   180  			tun.Close()
   181  			return nil, err
   182  		}
   183  	}
   184  
   185  	return tun, nil
   186  }
   187  
   188  func (tun *NativeTun) Name() (string, error) {
   189  	gostat, err := tun.tunFile.Stat()
   190  	if err != nil {
   191  		tun.name = ""
   192  		return "", err
   193  	}
   194  	stat := gostat.Sys().(*syscall.Stat_t)
   195  	tun.name = fmt.Sprintf("tun%d", stat.Rdev%256)
   196  	return tun.name, nil
   197  }
   198  
// File returns the tun device file wrapped by this NativeTun.
func (tun *NativeTun) File() *os.File {
	return tun.tunFile
}
   202  
// Events returns the channel on which routineRouteListener reports EventUp,
// EventDown and EventMTUUpdate. The channel is closed when the listener exits.
func (tun *NativeTun) Events() <-chan Event {
	return tun.events
}
   206  
   207  func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
   208  	select {
   209  	case err := <-tun.errors:
   210  		return 0, err
   211  	default:
   212  		buf := bufs[0][offset-4:]
   213  		n, err := tun.tunFile.Read(buf[:])
   214  		if n < 4 {
   215  			return 0, err
   216  		}
   217  		sizes[0] = n - 4
   218  		return 1, err
   219  	}
   220  }
   221  
   222  func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
   223  	if offset < 4 {
   224  		return 0, io.ErrShortBuffer
   225  	}
   226  	for i, buf := range bufs {
   227  		buf = buf[offset-4:]
   228  		buf[0] = 0x00
   229  		buf[1] = 0x00
   230  		buf[2] = 0x00
   231  		switch buf[4] >> 4 {
   232  		case 4:
   233  			buf[3] = unix.AF_INET
   234  		case 6:
   235  			buf[3] = unix.AF_INET6
   236  		default:
   237  			return i, unix.EAFNOSUPPORT
   238  		}
   239  		if _, err := tun.tunFile.Write(buf); err != nil {
   240  			return i, err
   241  		}
   242  	}
   243  	return len(bufs), nil
   244  }
   245  
   246  func (tun *NativeTun) Close() error {
   247  	var err1, err2 error
   248  	tun.closeOnce.Do(func() {
   249  		err1 = tun.tunFile.Close()
   250  		if tun.routeSocket != -1 {
   251  			unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
   252  			err2 = unix.Close(tun.routeSocket)
   253  			tun.routeSocket = -1
   254  		} else if tun.events != nil {
   255  			close(tun.events)
   256  		}
   257  	})
   258  	if err1 != nil {
   259  		return err1
   260  	}
   261  	return err2
   262  }
   263  
   264  func (tun *NativeTun) setMTU(n int) error {
   265  	// open datagram socket
   266  
   267  	var fd int
   268  
   269  	fd, err := unix.Socket(
   270  		unix.AF_INET,
   271  		unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
   272  		0,
   273  	)
   274  	if err != nil {
   275  		return err
   276  	}
   277  
   278  	defer unix.Close(fd)
   279  
   280  	// do ioctl call
   281  
   282  	var ifr ifreq_mtu
   283  	copy(ifr.Name[:], tun.name)
   284  	ifr.MTU = uint32(n)
   285  
   286  	_, _, errno := unix.Syscall(
   287  		unix.SYS_IOCTL,
   288  		uintptr(fd),
   289  		uintptr(unix.SIOCSIFMTU),
   290  		uintptr(unsafe.Pointer(&ifr)),
   291  	)
   292  
   293  	if errno != 0 {
   294  		return fmt.Errorf("failed to set MTU on %s", tun.name)
   295  	}
   296  
   297  	return nil
   298  }
   299  
   300  func (tun *NativeTun) MTU() (int, error) {
   301  	// open datagram socket
   302  
   303  	fd, err := unix.Socket(
   304  		unix.AF_INET,
   305  		unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
   306  		0,
   307  	)
   308  	if err != nil {
   309  		return 0, err
   310  	}
   311  
   312  	defer unix.Close(fd)
   313  
   314  	// do ioctl call
   315  	var ifr ifreq_mtu
   316  	copy(ifr.Name[:], tun.name)
   317  
   318  	_, _, errno := unix.Syscall(
   319  		unix.SYS_IOCTL,
   320  		uintptr(fd),
   321  		uintptr(unix.SIOCGIFMTU),
   322  		uintptr(unsafe.Pointer(&ifr)),
   323  	)
   324  	if errno != 0 {
   325  		return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
   326  	}
   327  
   328  	return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
   329  }
   330  
// BatchSize returns 1 because Read delivers at most one packet per call (it
// fills only bufs[0]).
func (tun *NativeTun) BatchSize() int {
	return 1
}