golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/driver/dll_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package driver
     7  
     8  import (
     9  	"fmt"
    10  	"sync"
    11  	"sync/atomic"
    12  	"unsafe"
    13  
    14  	"golang.org/x/sys/windows"
    15  )
    16  
    17  func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
    18  	return &lazyDLL{Name: name, onLoad: onLoad}
    19  }
    20  
    21  func (d *lazyDLL) NewProc(name string) *lazyProc {
    22  	return &lazyProc{dll: d, Name: name}
    23  }
    24  
    25  type lazyProc struct {
    26  	Name string
    27  	mu   sync.Mutex
    28  	dll  *lazyDLL
    29  	addr uintptr
    30  }
    31  
    32  func (p *lazyProc) Find() error {
    33  	if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
    34  		return nil
    35  	}
    36  	p.mu.Lock()
    37  	defer p.mu.Unlock()
    38  	if p.addr != 0 {
    39  		return nil
    40  	}
    41  
    42  	err := p.dll.Load()
    43  	if err != nil {
    44  		return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
    45  	}
    46  	addr, err := p.nameToAddr()
    47  	if err != nil {
    48  		return fmt.Errorf("Error getting %v address: %w", p.Name, err)
    49  	}
    50  
    51  	atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
    52  	return nil
    53  }
    54  
    55  func (p *lazyProc) Addr() uintptr {
    56  	err := p.Find()
    57  	if err != nil {
    58  		panic(err)
    59  	}
    60  	return p.addr
    61  }
    62  
    63  // Version returns the version of the driver DLL.
    64  func Version() string {
    65  	if modwireguard.Load() != nil {
    66  		return "unknown"
    67  	}
    68  	resInfo, err := windows.FindResource(modwireguard.Base, windows.ResourceID(1), windows.RT_VERSION)
    69  	if err != nil {
    70  		return "unknown"
    71  	}
    72  	data, err := windows.LoadResourceData(modwireguard.Base, resInfo)
    73  	if err != nil {
    74  		return "unknown"
    75  	}
    76  
    77  	var fixedInfo *windows.VS_FIXEDFILEINFO
    78  	fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo))
    79  	err = windows.VerQueryValue(unsafe.Pointer(&data[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen)
    80  	if err != nil {
    81  		return "unknown"
    82  	}
    83  	version := fmt.Sprintf("%d.%d", (fixedInfo.FileVersionMS>>16)&0xff, (fixedInfo.FileVersionMS>>0)&0xff)
    84  	if nextNibble := (fixedInfo.FileVersionLS >> 16) & 0xff; nextNibble != 0 {
    85  		version += fmt.Sprintf(".%d", nextNibble)
    86  	}
    87  	if nextNibble := (fixedInfo.FileVersionLS >> 0) & 0xff; nextNibble != 0 {
    88  		version += fmt.Sprintf(".%d", nextNibble)
    89  	}
    90  	return version
    91  }