golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/conf/dpapi/dpapi_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package dpapi
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"runtime"
    12  	"unsafe"
    13  
    14  	"golang.org/x/sys/windows"
    15  )
    16  
    17  func bytesToBlob(bytes []byte) *windows.DataBlob {
    18  	blob := &windows.DataBlob{Size: uint32(len(bytes))}
    19  	if len(bytes) > 0 {
    20  		blob.Data = &bytes[0]
    21  	}
    22  	return blob
    23  }
    24  
    25  func Encrypt(data []byte, name string) ([]byte, error) {
    26  	out := windows.DataBlob{}
    27  	err := windows.CryptProtectData(bytesToBlob(data), windows.StringToUTF16Ptr(name), nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out)
    28  	if err != nil {
    29  		return nil, fmt.Errorf("unable to encrypt DPAPI protected data: %w", err)
    30  	}
    31  	ret := make([]byte, out.Size)
    32  	copy(ret, unsafe.Slice(out.Data, out.Size))
    33  	windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data)))
    34  	return ret, nil
    35  }
    36  
    37  func Decrypt(data []byte, name string) ([]byte, error) {
    38  	out := windows.DataBlob{}
    39  	var outName *uint16
    40  	utf16Name, err := windows.UTF16PtrFromString(name)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  	err = windows.CryptUnprotectData(bytesToBlob(data), &outName, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out)
    45  	if err != nil {
    46  		return nil, fmt.Errorf("unable to decrypt DPAPI protected data: %w", err)
    47  	}
    48  	ret := make([]byte, out.Size)
    49  	copy(ret, unsafe.Slice(out.Data, out.Size))
    50  	windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data)))
    51  
    52  	// Note: this ridiculous open-coded strcmp is not constant time.
    53  	different := false
    54  	a := outName
    55  	b := utf16Name
    56  	for {
    57  		if *a != *b {
    58  			different = true
    59  			break
    60  		}
    61  		if *a == 0 || *b == 0 {
    62  			break
    63  		}
    64  		a = (*uint16)(unsafe.Add(unsafe.Pointer(a), 2))
    65  		b = (*uint16)(unsafe.Add(unsafe.Pointer(b), 2))
    66  	}
    67  	runtime.KeepAlive(utf16Name)
    68  	windows.LocalFree(windows.Handle(unsafe.Pointer(outName)))
    69  
    70  	if different {
    71  		return nil, errors.New("input name does not match the stored name")
    72  	}
    73  
    74  	return ret, nil
    75  }