github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/platform/mmap_windows.go (about)

     1  package platform
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"syscall"
     7  	"unsafe"
     8  )
     9  
    10  var (
    11  	kernel32           = syscall.NewLazyDLL("kernel32.dll")
    12  	procVirtualAlloc   = kernel32.NewProc("VirtualAlloc")
    13  	procVirtualProtect = kernel32.NewProc("VirtualProtect")
    14  	procVirtualFree    = kernel32.NewProc("VirtualFree")
    15  )
    16  
    17  const (
    18  	windows_MEM_COMMIT             uintptr = 0x00001000
    19  	windows_MEM_RELEASE            uintptr = 0x00008000
    20  	windows_PAGE_READWRITE         uintptr = 0x00000004
    21  	windows_PAGE_EXECUTE_READ      uintptr = 0x00000020
    22  	windows_PAGE_EXECUTE_READWRITE uintptr = 0x00000040
    23  )
    24  
    25  func munmapCodeSegment(code []byte) error {
    26  	return freeMemory(code)
    27  }
    28  
    29  // allocateMemory commits the memory region via the "VirtualAlloc" function.
    30  // See https://docs.microsoft.com/en-us/windows/win32/api/memoryapi/nf-memoryapi-virtualalloc
    31  func allocateMemory(size uintptr, protect uintptr) (uintptr, error) {
    32  	address := uintptr(0) // system determines where to allocate the region.
    33  	alloctype := windows_MEM_COMMIT
    34  	if r, _, err := procVirtualAlloc.Call(address, size, alloctype, protect); r == 0 {
    35  		return 0, fmt.Errorf("compiler: VirtualAlloc error: %w", ensureErr(err))
    36  	} else {
    37  		return r, nil
    38  	}
    39  }
    40  
    41  // freeMemory releases the memory region via the "VirtualFree" function.
    42  // See https://docs.microsoft.com/en-us/windows/win32/api/memoryapi/nf-memoryapi-virtualfree
    43  func freeMemory(code []byte) error {
    44  	address := unsafe.Pointer(&code[0])
    45  	size := uintptr(0) // size must be 0 because we're using MEM_RELEASE.
    46  	freetype := windows_MEM_RELEASE
    47  	if r, _, err := procVirtualFree.Call(uintptr(address), size, freetype); r == 0 {
    48  		return fmt.Errorf("compiler: VirtualFree error: %w", ensureErr(err))
    49  	}
    50  	return nil
    51  }
    52  
    53  func virtualProtect(address, size, newprotect uintptr, oldprotect *uint32) error {
    54  	if r, _, err := procVirtualProtect.Call(address, size, newprotect, uintptr(unsafe.Pointer(oldprotect))); r == 0 {
    55  		return fmt.Errorf("compiler: VirtualProtect error: %w", ensureErr(err))
    56  	}
    57  	return nil
    58  }
    59  
    60  func mmapCodeSegmentAMD64(size int) ([]byte, error) {
    61  	p, err := allocateMemory(uintptr(size), windows_PAGE_EXECUTE_READWRITE)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	var mem []byte
    67  	sh := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
    68  	sh.Data = p
    69  	sh.Len = size
    70  	sh.Cap = size
    71  	return mem, err
    72  }
    73  
    74  func mmapCodeSegmentARM64(size int) ([]byte, error) {
    75  	p, err := allocateMemory(uintptr(size), windows_PAGE_READWRITE)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	var mem []byte
    81  	sh := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
    82  	sh.Data = p
    83  	sh.Len = size
    84  	sh.Cap = size
    85  	return mem, nil
    86  }
    87  
    88  var old = uint32(windows_PAGE_READWRITE)
    89  
    90  func MprotectRX(b []byte) (err error) {
    91  	err = virtualProtect(uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), windows_PAGE_EXECUTE_READ, &old)
    92  	return
    93  }
    94  
    95  // ensureErr returns syscall.EINVAL when the input error is nil.
    96  //
    97  // We are supposed to use "GetLastError" which is more precise, but it is not safe to execute in goroutines. While
    98  // "GetLastError" is thread-local, goroutines are not pinned to threads.
    99  //
   100  // See https://docs.microsoft.com/en-us/windows/win32/api/errhandlingapi/nf-errhandlingapi-getlasterror
   101  func ensureErr(err error) error {
   102  	if err != nil {
   103  		return err
   104  	}
   105  	return syscall.EINVAL
   106  }