github.com/emmansun/gmsm@v0.29.1/smx509/root_windows.go (about)

     1  package smx509
     2  
     3  import (
     4  	"crypto/x509"
     5  	"errors"
     6  	"syscall"
     7  	"unsafe"
     8  )
     9  
    10  func loadSystemRoots() (*CertPool, error) {
    11  	return &CertPool{systemPool: true}, nil
    12  }
    13  
    14  // Creates a new *syscall.CertContext representing the leaf certificate in an in-memory
    15  // certificate store containing itself and all of the intermediate certificates specified
    16  // in the opts.Intermediates CertPool.
    17  //
    18  // A pointer to the in-memory store is available in the returned CertContext's Store field.
    19  // The store is automatically freed when the CertContext is freed using
    20  // syscall.CertFreeCertificateContext.
    21  func createStoreContext(leaf *Certificate, opts *VerifyOptions) (*syscall.CertContext, error) {
    22  	var storeCtx *syscall.CertContext
    23  
    24  	leafCtx, err := syscall.CertCreateCertificateContext(syscall.X509_ASN_ENCODING|syscall.PKCS_7_ASN_ENCODING, &leaf.Raw[0], uint32(len(leaf.Raw)))
    25  	if err != nil {
    26  		return nil, err
    27  	}
    28  	defer syscall.CertFreeCertificateContext(leafCtx)
    29  
    30  	handle, err := syscall.CertOpenStore(syscall.CERT_STORE_PROV_MEMORY, 0, 0, syscall.CERT_STORE_DEFER_CLOSE_UNTIL_LAST_FREE_FLAG, 0)
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  	defer syscall.CertCloseStore(handle, 0)
    35  
    36  	err = syscall.CertAddCertificateContextToStore(handle, leafCtx, syscall.CERT_STORE_ADD_ALWAYS, &storeCtx)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  
    41  	if opts.Intermediates != nil {
    42  		for i := 0; i < opts.Intermediates.len(); i++ {
    43  			intermediate, _, err := opts.Intermediates.cert(i)
    44  			if err != nil {
    45  				return nil, err
    46  			}
    47  			ctx, err := syscall.CertCreateCertificateContext(syscall.X509_ASN_ENCODING|syscall.PKCS_7_ASN_ENCODING, &intermediate.Raw[0], uint32(len(intermediate.Raw)))
    48  			if err != nil {
    49  				return nil, err
    50  			}
    51  
    52  			err = syscall.CertAddCertificateContextToStore(handle, ctx, syscall.CERT_STORE_ADD_ALWAYS, nil)
    53  			syscall.CertFreeCertificateContext(ctx)
    54  			if err != nil {
    55  				return nil, err
    56  			}
    57  		}
    58  	}
    59  
    60  	return storeCtx, nil
    61  }
    62  
    63  // extractSimpleChain extracts the final certificate chain from a CertSimpleChain.
    64  func extractSimpleChain(simpleChain **syscall.CertSimpleChain, count int) (chain []*Certificate, err error) {
    65  	if simpleChain == nil || count == 0 {
    66  		return nil, errors.New("x509: invalid simple chain")
    67  	}
    68  
    69  	simpleChains := (*[1 << 20]*syscall.CertSimpleChain)(unsafe.Pointer(simpleChain))[:count:count]
    70  	lastChain := simpleChains[count-1]
    71  	elements := (*[1 << 20]*syscall.CertChainElement)(unsafe.Pointer(lastChain.Elements))[:lastChain.NumElements:lastChain.NumElements]
    72  	for i := 0; i < int(lastChain.NumElements); i++ {
    73  		// Copy the buf, since ParseCertificate does not create its own copy.
    74  		cert := elements[i].CertContext
    75  		encodedCert := (*[1 << 20]byte)(unsafe.Pointer(cert.EncodedCert))[:cert.Length:cert.Length]
    76  		buf := make([]byte, cert.Length)
    77  		copy(buf, encodedCert)
    78  		parsedCert, err := ParseCertificate(buf)
    79  		if err != nil {
    80  			return nil, err
    81  		}
    82  		chain = append(chain, parsedCert)
    83  	}
    84  
    85  	return chain, nil
    86  }
    87  
    88  // checkChainTrustStatus checks the trust status of the certificate chain, translating
    89  // any errors it finds into Go errors in the process.
    90  func checkChainTrustStatus(c *Certificate, chainCtx *syscall.CertChainContext) error {
    91  	if chainCtx.TrustStatus.ErrorStatus != syscall.CERT_TRUST_NO_ERROR {
    92  		status := chainCtx.TrustStatus.ErrorStatus
    93  		switch status {
    94  		case syscall.CERT_TRUST_IS_NOT_TIME_VALID:
    95  			return CertificateInvalidError{Cert: c.asX509(), Reason: Expired, Detail: ""}
    96  		case syscall.CERT_TRUST_IS_NOT_VALID_FOR_USAGE:
    97  			return CertificateInvalidError{Cert: c.asX509(), Reason: IncompatibleUsage, Detail: ""}
    98  		// TODO(filippo): surface more error statuses.
    99  		default:
   100  			return UnknownAuthorityError{c, nil, nil}
   101  		}
   102  	}
   103  	return nil
   104  }
   105  
   106  // checkChainSSLServerPolicy checks that the certificate chain in chainCtx is valid for
   107  // use as a certificate chain for a SSL/TLS server.
   108  func checkChainSSLServerPolicy(c *Certificate, chainCtx *syscall.CertChainContext, opts *VerifyOptions) error {
   109  	servernamep, err := syscall.UTF16PtrFromString(opts.DNSName)
   110  	if err != nil {
   111  		return err
   112  	}
   113  	sslPara := &syscall.SSLExtraCertChainPolicyPara{
   114  		AuthType:   syscall.AUTHTYPE_SERVER,
   115  		ServerName: servernamep,
   116  	}
   117  	sslPara.Size = uint32(unsafe.Sizeof(*sslPara))
   118  
   119  	para := &syscall.CertChainPolicyPara{
   120  		ExtraPolicyPara: (syscall.Pointer)(unsafe.Pointer(sslPara)),
   121  	}
   122  	para.Size = uint32(unsafe.Sizeof(*para))
   123  
   124  	status := syscall.CertChainPolicyStatus{}
   125  	err = syscall.CertVerifyCertificateChainPolicy(syscall.CERT_CHAIN_POLICY_SSL, chainCtx, para, &status)
   126  	if err != nil {
   127  		return err
   128  	}
   129  
   130  	// TODO(mkrautz): use the lChainIndex and lElementIndex fields
   131  	// of the CertChainPolicyStatus to provide proper context, instead
   132  	// using c.
   133  	if status.Error != 0 {
   134  		switch status.Error {
   135  		case syscall.CERT_E_EXPIRED:
   136  			return CertificateInvalidError{Cert: c.asX509(), Reason: Expired, Detail: ""}
   137  		case syscall.CERT_E_CN_NO_MATCH:
   138  			return x509.HostnameError{Certificate: c.asX509(), Host: opts.DNSName}
   139  		case syscall.CERT_E_UNTRUSTEDROOT:
   140  			return UnknownAuthorityError{c, nil, nil}
   141  		default:
   142  			return UnknownAuthorityError{c, nil, nil}
   143  		}
   144  	}
   145  
   146  	return nil
   147  }
   148  
   149  // windowsExtKeyUsageOIDs are the C NUL-terminated string representations of the
   150  // OIDs for use with the Windows API.
   151  var windowsExtKeyUsageOIDs = make(map[ExtKeyUsage][]byte, len(extKeyUsageOIDs))
   152  
   153  func init() {
   154  	for _, eku := range extKeyUsageOIDs {
   155  		windowsExtKeyUsageOIDs[eku.extKeyUsage] = []byte(eku.oid.String() + "\x00")
   156  	}
   157  }
   158  
   159  func verifyChain(c *Certificate, chainCtx *syscall.CertChainContext, opts *VerifyOptions) (chain []*Certificate, err error) {
   160  	err = checkChainTrustStatus(c, chainCtx)
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  
   165  	if opts != nil && len(opts.DNSName) > 0 {
   166  		err = checkChainSSLServerPolicy(c, chainCtx, opts)
   167  		if err != nil {
   168  			return nil, err
   169  		}
   170  	}
   171  
   172  	chain, err = extractSimpleChain(chainCtx.Chains, int(chainCtx.ChainCount))
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	if len(chain) == 0 {
   177  		return nil, errors.New("x509: internal error: system verifier returned an empty chain")
   178  	}
   179  
   180  	// Mitigate CVE-2020-0601, where the Windows system verifier might be
   181  	// tricked into using custom curve parameters for a trusted root, by
   182  	// double-checking all ECDSA signatures. If the system was tricked into
   183  	// using spoofed parameters, the signature will be invalid for the correct
   184  	// ones we parsed. (We don't support custom curves ourselves.)
   185  	for i, parent := range chain[1:] {
   186  		if parent.PublicKeyAlgorithm != ECDSA {
   187  			continue
   188  		}
   189  		if err := parent.CheckSignature(chain[i].SignatureAlgorithm,
   190  			chain[i].RawTBSCertificate, chain[i].Signature); err != nil {
   191  			return nil, err
   192  		}
   193  	}
   194  	return chain, nil
   195  }
   196  
   197  // systemVerify is like Verify, except that it uses CryptoAPI calls
   198  // to build certificate chains and verify them.
   199  func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
   200  	storeCtx, err := createStoreContext(c, opts)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  	defer syscall.CertFreeCertificateContext(storeCtx)
   205  
   206  	para := new(syscall.CertChainPara)
   207  	para.Size = uint32(unsafe.Sizeof(*para))
   208  
   209  	keyUsages := opts.KeyUsages
   210  	if len(keyUsages) == 0 {
   211  		keyUsages = []ExtKeyUsage{ExtKeyUsageServerAuth}
   212  	}
   213  	oids := make([]*byte, 0, len(keyUsages))
   214  	for _, eku := range keyUsages {
   215  		if eku == ExtKeyUsageAny {
   216  			oids = nil
   217  			break
   218  		}
   219  		if oid, ok := windowsExtKeyUsageOIDs[eku]; ok {
   220  			oids = append(oids, &oid[0])
   221  		}
   222  	}
   223  	if oids != nil {
   224  		para.RequestedUsage.Type = syscall.USAGE_MATCH_TYPE_OR
   225  		para.RequestedUsage.Usage.Length = uint32(len(oids))
   226  		para.RequestedUsage.Usage.UsageIdentifiers = &oids[0]
   227  	} else {
   228  		para.RequestedUsage.Type = syscall.USAGE_MATCH_TYPE_AND
   229  		para.RequestedUsage.Usage.Length = 0
   230  		para.RequestedUsage.Usage.UsageIdentifiers = nil
   231  	}
   232  
   233  	var verifyTime *syscall.Filetime
   234  	if opts != nil && !opts.CurrentTime.IsZero() {
   235  		ft := syscall.NsecToFiletime(opts.CurrentTime.UnixNano())
   236  		verifyTime = &ft
   237  	}
   238  
   239  	// The default is to return only the highest quality chain,
   240  	// setting this flag will add additional lower quality contexts.
   241  	// These are returned in the LowerQualityChains field.
   242  	const CERT_CHAIN_RETURN_LOWER_QUALITY_CONTEXTS = 0x00000080
   243  
   244  	// CertGetCertificateChain will traverse Windows's root stores in an attempt to build a verified certificate chain
   245  	var topCtx *syscall.CertChainContext
   246  	err = syscall.CertGetCertificateChain(syscall.Handle(0), storeCtx, verifyTime, storeCtx.Store, para, CERT_CHAIN_RETURN_LOWER_QUALITY_CONTEXTS, 0, &topCtx)
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  	defer syscall.CertFreeCertificateChain(topCtx)
   251  
   252  	chain, topErr := verifyChain(c, topCtx, opts)
   253  	if topErr == nil {
   254  		chains = append(chains, chain)
   255  	}
   256  
   257  	if lqCtxCount := topCtx.LowerQualityChainCount; lqCtxCount > 0 {
   258  		lqCtxs := (*[1 << 20]*syscall.CertChainContext)(unsafe.Pointer(topCtx.LowerQualityChains))[:lqCtxCount:lqCtxCount]
   259  
   260  		for _, ctx := range lqCtxs {
   261  			chain, err := verifyChain(c, ctx, opts)
   262  			if err == nil {
   263  				chains = append(chains, chain)
   264  			}
   265  		}
   266  	}
   267  
   268  	if len(chains) == 0 {
   269  		// Return the error from the highest quality context.
   270  		return nil, topErr
   271  	}
   272  
   273  	return chains, nil
   274  }