golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/version/official.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package version
     7  
     8  import (
     9  	"errors"
    10  	"os"
    11  	"unsafe"
    12  
    13  	"golang.org/x/sys/windows"
    14  )
    15  
    16  const (
    17  	officialCommonName = "WireGuard LLC"
    18  	evPolicyOid        = "2.23.140.1.3"
    19  	policyExtensionOid = "2.5.29.32"
    20  )
    21  
// These are easily bypassable checks that do not serve any security purpose.
// DO NOT PLACE SECURITY-SENSITIVE FUNCTIONS IN THIS FILE.
    24  
    25  func IsRunningOfficialVersion() bool {
    26  	path, err := os.Executable()
    27  	if err != nil {
    28  		return false
    29  	}
    30  
    31  	names, err := extractCertificateNames(path)
    32  	if err != nil {
    33  		return false
    34  	}
    35  	for _, name := range names {
    36  		if name == officialCommonName {
    37  			return true
    38  		}
    39  	}
    40  	return false
    41  }
    42  
    43  func IsRunningEVSigned() bool {
    44  	path, err := os.Executable()
    45  	if err != nil {
    46  		return false
    47  	}
    48  
    49  	policies, err := extractCertificatePolicies(path, policyExtensionOid)
    50  	if err != nil {
    51  		return false
    52  	}
    53  	for _, policy := range policies {
    54  		if policy == evPolicyOid {
    55  			return true
    56  		}
    57  	}
    58  	return false
    59  }
    60  
    61  func extractCertificateNames(path string) ([]string, error) {
    62  	path16, err := windows.UTF16PtrFromString(path)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	var certStore windows.Handle
    67  	err = windows.CryptQueryObject(windows.CERT_QUERY_OBJECT_FILE, unsafe.Pointer(path16), windows.CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED, windows.CERT_QUERY_FORMAT_FLAG_ALL, 0, nil, nil, nil, &certStore, nil, nil)
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  	defer windows.CertCloseStore(certStore, 0)
    72  	var cert *windows.CertContext
    73  	var names []string
    74  	for {
    75  		cert, err = windows.CertEnumCertificatesInStore(certStore, cert)
    76  		if err != nil {
    77  			if errors.Is(err, windows.Errno(windows.CRYPT_E_NOT_FOUND)) {
    78  				break
    79  			}
    80  			return nil, err
    81  		}
    82  		if cert == nil {
    83  			break
    84  		}
    85  		nameLen := windows.CertGetNameString(cert, windows.CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, nil, nil, 0)
    86  		if nameLen == 0 {
    87  			continue
    88  		}
    89  		name16 := make([]uint16, nameLen)
    90  		if windows.CertGetNameString(cert, windows.CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, nil, &name16[0], nameLen) != nameLen {
    91  			continue
    92  		}
    93  		if name16[0] == 0 {
    94  			continue
    95  		}
    96  		names = append(names, windows.UTF16ToString(name16))
    97  	}
    98  	if names == nil {
    99  		return nil, windows.Errno(windows.CRYPT_E_NOT_FOUND)
   100  	}
   101  	return names, nil
   102  }
   103  
   104  func extractCertificatePolicies(path, oid string) ([]string, error) {
   105  	path16, err := windows.UTF16PtrFromString(path)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	oid8, err := windows.BytePtrFromString(oid)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	var certStore windows.Handle
   114  	err = windows.CryptQueryObject(windows.CERT_QUERY_OBJECT_FILE, unsafe.Pointer(path16), windows.CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED, windows.CERT_QUERY_FORMAT_FLAG_ALL, 0, nil, nil, nil, &certStore, nil, nil)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	defer windows.CertCloseStore(certStore, 0)
   119  	var cert *windows.CertContext
   120  	var policies []string
   121  	for {
   122  		cert, err = windows.CertEnumCertificatesInStore(certStore, cert)
   123  		if err != nil {
   124  			if errors.Is(err, windows.Errno(windows.CRYPT_E_NOT_FOUND)) {
   125  				break
   126  			}
   127  			return nil, err
   128  		}
   129  		if cert == nil {
   130  			break
   131  		}
   132  		ext := windows.CertFindExtension(oid8, cert.CertInfo.CountExtensions, cert.CertInfo.Extensions)
   133  		if ext == nil {
   134  			continue
   135  		}
   136  		var decodedLen uint32
   137  		err = windows.CryptDecodeObject(windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, ext.ObjId, ext.Value.Data, ext.Value.Size, 0, nil, &decodedLen)
   138  		if err != nil {
   139  			return nil, err
   140  		}
   141  		bytes := make([]byte, decodedLen)
   142  		certPoliciesInfo := (*windows.CertPoliciesInfo)(unsafe.Pointer(&bytes[0]))
   143  		err = windows.CryptDecodeObject(windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, ext.ObjId, ext.Value.Data, ext.Value.Size, 0, unsafe.Pointer(&bytes[0]), &decodedLen)
   144  		if err != nil {
   145  			return nil, err
   146  		}
   147  		for i := uintptr(0); i < uintptr(certPoliciesInfo.Count); i++ {
   148  			cp := (*windows.CertPolicyInfo)(unsafe.Add(unsafe.Pointer(certPoliciesInfo.PolicyInfos), i*unsafe.Sizeof(*certPoliciesInfo.PolicyInfos)))
   149  			policies = append(policies, windows.BytePtrToString(cp.Identifier))
   150  		}
   151  	}
   152  	if policies == nil {
   153  		return nil, windows.Errno(windows.CRYPT_E_NOT_FOUND)
   154  	}
   155  	return policies, nil
   156  }