golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/driver/driver_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package driver
     7  
     8  import (
     9  	"log"
    10  	"runtime"
    11  	"syscall"
    12  	"unsafe"
    13  
    14  	"golang.org/x/sys/windows"
    15  	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
    16  )
    17  
    18  type loggerLevel int
    19  
    20  const (
    21  	logInfo loggerLevel = iota
    22  	logWarn
    23  	logErr
    24  )
    25  
    26  const AdapterNameMax = 128
    27  
    28  type Adapter struct {
    29  	handle           uintptr
    30  	lastGetGuessSize uint32
    31  }
    32  
    33  var (
    34  	modwireguard                         = newLazyDLL("wireguard.dll", setupLogger)
    35  	procWireGuardCreateAdapter           = modwireguard.NewProc("WireGuardCreateAdapter")
    36  	procWireGuardOpenAdapter             = modwireguard.NewProc("WireGuardOpenAdapter")
    37  	procWireGuardCloseAdapter            = modwireguard.NewProc("WireGuardCloseAdapter")
    38  	procWireGuardDeleteDriver            = modwireguard.NewProc("WireGuardDeleteDriver")
    39  	procWireGuardGetAdapterLUID          = modwireguard.NewProc("WireGuardGetAdapterLUID")
    40  	procWireGuardGetRunningDriverVersion = modwireguard.NewProc("WireGuardGetRunningDriverVersion")
    41  	procWireGuardSetAdapterLogging       = modwireguard.NewProc("WireGuardSetAdapterLogging")
    42  )
    43  
    44  type TimestampedWriter interface {
    45  	WriteWithTimestamp(p []byte, ts int64) (n int, err error)
    46  }
    47  
    48  func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int {
    49  	if tw, ok := log.Default().Writer().(TimestampedWriter); ok {
    50  		tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100)
    51  	} else {
    52  		log.Println(windows.UTF16PtrToString(msg))
    53  	}
    54  	return 0
    55  }
    56  
    57  func setupLogger(dll *lazyDLL) {
    58  	var callback uintptr
    59  	if runtime.GOARCH == "386" {
    60  		callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int {
    61  			return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
    62  		})
    63  	} else if runtime.GOARCH == "arm" {
    64  		callback = windows.NewCallback(func(level loggerLevel, _, timestampLow, timestampHigh uint32, msg *uint16) int {
    65  			return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
    66  		})
    67  	} else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" {
    68  		callback = windows.NewCallback(logMessage)
    69  	}
    70  	syscall.SyscallN(dll.NewProc("WireGuardSetLogger").Addr(), callback)
    71  }
    72  
    73  func closeAdapter(wireguard *Adapter) {
    74  	syscall.SyscallN(procWireGuardCloseAdapter.Addr(), wireguard.handle)
    75  }
    76  
    77  // CreateAdapter creates a WireGuard adapter. name is the cosmetic name of the adapter.
    78  // tunnelType represents the type of adapter and should be "WireGuard". requestedGUID is
    79  // the GUID of the created network adapter, which then influences NLA generation
    80  // deterministically. If it is set to nil, the GUID is chosen by the system at random,
    81  // and hence a new NLA entry is created for each new adapter.
    82  func CreateAdapter(name, tunnelType string, requestedGUID *windows.GUID) (wireguard *Adapter, err error) {
    83  	var name16 *uint16
    84  	name16, err = windows.UTF16PtrFromString(name)
    85  	if err != nil {
    86  		return
    87  	}
    88  	var tunnelType16 *uint16
    89  	tunnelType16, err = windows.UTF16PtrFromString(tunnelType)
    90  	if err != nil {
    91  		return
    92  	}
    93  	r0, _, e1 := syscall.SyscallN(procWireGuardCreateAdapter.Addr(), uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID)))
    94  	if r0 == 0 {
    95  		err = e1
    96  		return
    97  	}
    98  	wireguard = &Adapter{handle: r0}
    99  	runtime.SetFinalizer(wireguard, closeAdapter)
   100  	return
   101  }
   102  
   103  // OpenAdapter opens an existing WireGuard adapter by name.
   104  func OpenAdapter(name string) (wireguard *Adapter, err error) {
   105  	var name16 *uint16
   106  	name16, err = windows.UTF16PtrFromString(name)
   107  	if err != nil {
   108  		return
   109  	}
   110  	r0, _, e1 := syscall.SyscallN(procWireGuardOpenAdapter.Addr(), uintptr(unsafe.Pointer(name16)))
   111  	if r0 == 0 {
   112  		err = e1
   113  		return
   114  	}
   115  	wireguard = &Adapter{handle: r0}
   116  	runtime.SetFinalizer(wireguard, closeAdapter)
   117  	return
   118  }
   119  
   120  // Close closes a WireGuard adapter.
   121  func (wireguard *Adapter) Close() (err error) {
   122  	runtime.SetFinalizer(wireguard, nil)
   123  	r1, _, e1 := syscall.SyscallN(procWireGuardCloseAdapter.Addr(), wireguard.handle)
   124  	if r1 == 0 {
   125  		err = e1
   126  	}
   127  	return
   128  }
   129  
   130  // Uninstall removes the driver from the system if no drivers are currently in use.
   131  func Uninstall() (err error) {
   132  	r1, _, e1 := syscall.SyscallN(procWireGuardDeleteDriver.Addr())
   133  	if r1 == 0 {
   134  		err = e1
   135  	}
   136  	return
   137  }
   138  
   139  type AdapterLogState uint32
   140  
   141  const (
   142  	AdapterLogOff          AdapterLogState = 0
   143  	AdapterLogOn           AdapterLogState = 1
   144  	AdapterLogOnWithPrefix AdapterLogState = 2
   145  )
   146  
   147  // SetLogging enables or disables logging on the WireGuard adapter.
   148  func (wireguard *Adapter) SetLogging(logState AdapterLogState) (err error) {
   149  	r1, _, e1 := syscall.SyscallN(procWireGuardSetAdapterLogging.Addr(), wireguard.handle, uintptr(logState))
   150  	if r1 == 0 {
   151  		err = e1
   152  	}
   153  	return
   154  }
   155  
   156  // RunningVersion returns the version of the loaded driver.
   157  func RunningVersion() (version uint32, err error) {
   158  	r0, _, e1 := syscall.SyscallN(procWireGuardGetRunningDriverVersion.Addr())
   159  	version = uint32(r0)
   160  	if version == 0 {
   161  		err = e1
   162  	}
   163  	return
   164  }
   165  
   166  // LUID returns the LUID of the adapter.
   167  func (wireguard *Adapter) LUID() (luid winipcfg.LUID) {
   168  	syscall.SyscallN(procWireGuardGetAdapterLUID.Addr(), wireguard.handle, uintptr(unsafe.Pointer(&luid)))
   169  	return
   170  }