golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/tunnel/firewall/helpers.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package firewall
     7  
     8  import (
     9  	"fmt"
    10  	"os"
    11  	"runtime"
    12  	"syscall"
    13  	"unsafe"
    14  
    15  	"golang.org/x/sys/windows"
    16  )
    17  
    18  func runTransaction(session uintptr, operation wfpObjectInstaller) error {
    19  	err := fwpmTransactionBegin0(session, 0)
    20  	if err != nil {
    21  		return wrapErr(err)
    22  	}
    23  
    24  	err = operation(session)
    25  	if err != nil {
    26  		fwpmTransactionAbort0(session)
    27  		return wrapErr(err)
    28  	}
    29  
    30  	err = fwpmTransactionCommit0(session)
    31  	if err != nil {
    32  		fwpmTransactionAbort0(session)
    33  		return wrapErr(err)
    34  	}
    35  
    36  	return nil
    37  }
    38  
    39  func createWtFwpmDisplayData0(name, description string) (*wtFwpmDisplayData0, error) {
    40  	namePtr, err := windows.UTF16PtrFromString(name)
    41  	if err != nil {
    42  		return nil, wrapErr(err)
    43  	}
    44  
    45  	descriptionPtr, err := windows.UTF16PtrFromString(description)
    46  	if err != nil {
    47  		return nil, wrapErr(err)
    48  	}
    49  
    50  	return &wtFwpmDisplayData0{
    51  		name:        namePtr,
    52  		description: descriptionPtr,
    53  	}, nil
    54  }
    55  
    56  func filterWeight(weight uint8) wtFwpValue0 {
    57  	return wtFwpValue0{
    58  		_type: cFWP_UINT8,
    59  		value: uintptr(weight),
    60  	}
    61  }
    62  
    63  func wrapErr(err error) error {
    64  	if _, ok := err.(syscall.Errno); !ok {
    65  		return err
    66  	}
    67  	_, file, line, ok := runtime.Caller(1)
    68  	if !ok {
    69  		return fmt.Errorf("Firewall error at unknown location: %w", err)
    70  	}
    71  	return fmt.Errorf("Firewall error at %s:%d: %w", file, line, err)
    72  }
    73  
    74  func getCurrentProcessSecurityDescriptor() (*windows.SECURITY_DESCRIPTOR, error) {
    75  	var processToken windows.Token
    76  	err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &processToken)
    77  	if err != nil {
    78  		return nil, wrapErr(err)
    79  	}
    80  	defer processToken.Close()
    81  	gs, err := processToken.GetTokenGroups()
    82  	if err != nil {
    83  		return nil, wrapErr(err)
    84  	}
    85  	var sid *windows.SID
    86  	for _, g := range gs.AllGroups() {
    87  		if g.Attributes != windows.SE_GROUP_ENABLED|windows.SE_GROUP_ENABLED_BY_DEFAULT|windows.SE_GROUP_OWNER {
    88  			continue
    89  		}
    90  		// We could be checking != 6, but hopefully Microsoft will update
    91  		// RtlCreateServiceSid to use SHA2, which will then likely bump
    92  		// this up. So instead just roll with a minimum.
    93  		if !g.Sid.IsValid() || g.Sid.IdentifierAuthority() != windows.SECURITY_NT_AUTHORITY || g.Sid.SubAuthorityCount() < 6 || g.Sid.SubAuthority(0) != 80 {
    94  			continue
    95  		}
    96  		sid = g.Sid
    97  		break
    98  	}
    99  	if sid == nil {
   100  		return nil, wrapErr(windows.ERROR_NO_SUCH_GROUP)
   101  	}
   102  
   103  	access := []windows.EXPLICIT_ACCESS{{
   104  		AccessPermissions: cFWP_ACTRL_MATCH_FILTER,
   105  		AccessMode:        windows.GRANT_ACCESS,
   106  		Trustee: windows.TRUSTEE{
   107  			TrusteeForm:  windows.TRUSTEE_IS_SID,
   108  			TrusteeType:  windows.TRUSTEE_IS_GROUP,
   109  			TrusteeValue: windows.TrusteeValueFromSID(sid),
   110  		},
   111  	}}
   112  	dacl, err := windows.ACLFromEntries(access, nil)
   113  	if err != nil {
   114  		return nil, wrapErr(err)
   115  	}
   116  	sd, err := windows.NewSecurityDescriptor()
   117  	if err != nil {
   118  		return nil, wrapErr(err)
   119  	}
   120  	err = sd.SetDACL(dacl, true, false)
   121  	if err != nil {
   122  		return nil, wrapErr(err)
   123  	}
   124  	sd, err = sd.ToSelfRelative()
   125  	if err != nil {
   126  		return nil, wrapErr(err)
   127  	}
   128  	return sd, nil
   129  }
   130  
   131  func getCurrentProcessAppID() (*wtFwpByteBlob, error) {
   132  	currentFile, err := os.Executable()
   133  	if err != nil {
   134  		return nil, wrapErr(err)
   135  	}
   136  
   137  	curFilePtr, err := windows.UTF16PtrFromString(currentFile)
   138  	if err != nil {
   139  		return nil, wrapErr(err)
   140  	}
   141  
   142  	var appID *wtFwpByteBlob
   143  	err = fwpmGetAppIdFromFileName0(curFilePtr, unsafe.Pointer(&appID))
   144  	if err != nil {
   145  		return nil, wrapErr(err)
   146  	}
   147  	return appID, nil
   148  }