golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/conf/filewriter_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conf
     7  
     8  import (
     9  	"crypto/rand"
    10  	"encoding/hex"
    11  	"path/filepath"
    12  	"sync/atomic"
    13  	"unsafe"
    14  
    15  	"golang.org/x/sys/windows"
    16  )
    17  
// encryptedFileSd caches the *windows.SECURITY_DESCRIPTOR used by
// writeLockedDownFile for the files it creates. It is nil until the first
// successful parse of the descriptor string. In this file it is read and
// written only through sync/atomic LoadPointer/StorePointer.
var encryptedFileSd unsafe.Pointer
    19  
    20  func randomFileName() string {
    21  	var randBytes [32]byte
    22  	_, err := rand.Read(randBytes[:])
    23  	if err != nil {
    24  		panic(err)
    25  	}
    26  	return hex.EncodeToString(randBytes[:]) + ".tmp"
    27  }
    28  
    29  func writeLockedDownFile(destination string, overwrite bool, contents []byte) error {
    30  	var err error
    31  	sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{}))}
    32  	sa.SecurityDescriptor = (*windows.SECURITY_DESCRIPTOR)(atomic.LoadPointer(&encryptedFileSd))
    33  	if sa.SecurityDescriptor == nil {
    34  		sa.SecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;;FA;;;SY)(A;;SD;;;BA)")
    35  		if err != nil {
    36  			return err
    37  		}
    38  		atomic.StorePointer(&encryptedFileSd, unsafe.Pointer(sa.SecurityDescriptor))
    39  	}
    40  	destination16, err := windows.UTF16FromString(destination)
    41  	if err != nil {
    42  		return err
    43  	}
    44  	tmpDestination := filepath.Join(filepath.Dir(destination), randomFileName())
    45  	tmpDestination16, err := windows.UTF16PtrFromString(tmpDestination)
    46  	if err != nil {
    47  		return err
    48  	}
    49  	handle, err := windows.CreateFile(tmpDestination16, windows.GENERIC_WRITE|windows.DELETE, windows.FILE_SHARE_READ, sa, windows.CREATE_ALWAYS, windows.FILE_ATTRIBUTE_NORMAL, 0)
    50  	if err != nil {
    51  		return err
    52  	}
    53  	defer windows.CloseHandle(handle)
    54  	deleteIt := func() {
    55  		yes := byte(1)
    56  		windows.SetFileInformationByHandle(handle, windows.FileDispositionInfo, &yes, 1)
    57  	}
    58  	n, err := windows.Write(handle, contents)
    59  	if err != nil {
    60  		deleteIt()
    61  		return err
    62  	}
    63  	if n != len(contents) {
    64  		deleteIt()
    65  		return windows.ERROR_IO_INCOMPLETE
    66  	}
    67  	fileRenameInfo := &struct {
    68  		replaceIfExists byte
    69  		rootDirectory   windows.Handle
    70  		fileNameLength  uint32
    71  		fileName        [windows.MAX_PATH]uint16
    72  	}{replaceIfExists: func() byte {
    73  		if overwrite {
    74  			return 1
    75  		} else {
    76  			return 0
    77  		}
    78  	}(), fileNameLength: uint32(len(destination16) - 1)}
    79  	if len(destination16) > len(fileRenameInfo.fileName) {
    80  		deleteIt()
    81  		return windows.ERROR_BUFFER_OVERFLOW
    82  	}
    83  	copy(fileRenameInfo.fileName[:], destination16[:])
    84  	err = windows.SetFileInformationByHandle(handle, windows.FileRenameInfo, (*byte)(unsafe.Pointer(fileRenameInfo)), uint32(unsafe.Sizeof(*fileRenameInfo)))
    85  	if err != nil {
    86  		deleteIt()
    87  		return err
    88  	}
    89  	return nil
    90  }