github.com/emmansun/gmsm@v0.29.1/smx509/cert_pool.go (about)

     1  package smx509
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"encoding/pem"
     7  	"sync"
     8  )
     9  
// sum224 is a SHA-224 digest (as returned by sha256.Sum224) of a
// certificate's raw bytes (Certificate.Raw). CertPool uses it as a compact
// map key to detect certificates that were already added.
type sum224 [sha256.Size224]byte
    11  
// CertPool is a set of certificates. Certificates are stored lazily (see
// lazyCert) and indexed by their raw subject so that candidate issuers of a
// certificate can be found by its raw issuer during chain building.
type CertPool struct {
	// byName maps cert.RawSubject to the indexes into lazyCerts of every
	// certificate carrying that subject.
	byName map[string][]int // cert.RawSubject => index into lazyCerts

	// lazyCerts contains funcs that return a certificate,
	// lazily parsing/decompressing it as needed.
	lazyCerts []lazyCert

	// haveSum maps from sum224(cert.Raw) to true. It is the pool's
	// membership set: AddCert duplicate detection, contains, Equal and
	// Clone all work from it. Because it is keyed by digest, those paths
	// never need to call getCert, which would otherwise negate the savings
	// from lazy getCert funcs.
	haveSum map[sum224]bool

	// systemPool indicates whether this is a special pool derived from the
	// system roots. If it includes additional roots, it requires doing two
	// verifications, one using the roots provided by the caller, and one using
	// the system platform verifier.
	systemPool bool
}
    33  
// lazyCert is minimal metadata about a certificate plus a func to retrieve
// it in its normal expanded *Certificate form.
type lazyCert struct {
	// rawSubject is the Certificate.RawSubject value.
	// It's the same as the CertPool.byName key, but in []byte
	// form to make CertPool.Subjects (as used by crypto/tls) do
	// fewer allocations.
	rawSubject []byte

	// constraint, when non-nil, is run against a candidate chain that would
	// be rooted at this certificate; a non-nil error rejects that chain.
	// This allows adding arbitrary constraints that are not specified in
	// the certificate itself. AddCert and AppendCertsFromPEM store nil here.
	constraint func([]*Certificate) error

	// getCert returns the certificate.
	//
	// It is not meant to do network operations or anything else
	// where a failure is likely; the func is meant to lazily
	// parse/decompress data that is already known to be good. The
	// error in the signature is primarily meant for the case where
	// a cert file that existed on local disk when the program started
	// is deleted before it's read.
	getCert func() (*Certificate, error)
}
    58  
    59  // NewCertPool returns a new, empty CertPool.
    60  func NewCertPool() *CertPool {
    61  	return &CertPool{
    62  		byName:  make(map[string][]int),
    63  		haveSum: make(map[sum224]bool),
    64  	}
    65  }
    66  
    67  // len returns the number of certs in the set.
    68  // A nil set is a valid empty set.
    69  func (s *CertPool) len() int {
    70  	if s == nil {
    71  		return 0
    72  	}
    73  	return len(s.lazyCerts)
    74  }
    75  
    76  // cert returns cert index n in s.
    77  func (s *CertPool) cert(n int) (*Certificate, func([]*Certificate) error, error) {
    78  	cert, err := s.lazyCerts[n].getCert()
    79  	return cert, s.lazyCerts[n].constraint, err
    80  }
    81  
    82  // Clone returns a copy of s.
    83  func (s *CertPool) Clone() *CertPool {
    84  	p := &CertPool{
    85  		byName:     make(map[string][]int, len(s.byName)),
    86  		lazyCerts:  make([]lazyCert, len(s.lazyCerts)),
    87  		haveSum:    make(map[sum224]bool, len(s.haveSum)),
    88  		systemPool: s.systemPool,
    89  	}
    90  	for k, v := range s.byName {
    91  		indexes := make([]int, len(v))
    92  		copy(indexes, v)
    93  		p.byName[k] = indexes
    94  	}
    95  	for k := range s.haveSum {
    96  		p.haveSum[k] = true
    97  	}
    98  	copy(p.lazyCerts, s.lazyCerts)
    99  	return p
   100  }
   101  
   102  // SystemCertPool returns a copy of the system cert pool.
   103  //
   104  // On Unix systems other than macOS the environment variables SSL_CERT_FILE and
   105  // SSL_CERT_DIR can be used to override the system default locations for the SSL
   106  // certificate file and SSL certificate files directory, respectively. The
   107  // latter can be a colon-separated list.
   108  //
   109  // Any mutations to the returned pool are not written to disk and do not affect
   110  // any other pool returned by SystemCertPool.
   111  //
   112  // New changes in the system cert pool might not be reflected in subsequent calls.
   113  func SystemCertPool() (*CertPool, error) {
   114  	if sysRoots := systemRootsPool(); sysRoots != nil {
   115  		return sysRoots.Clone(), nil
   116  	}
   117  
   118  	return loadSystemRoots()
   119  }
   120  
// potentialParent is a candidate issuer produced by findPotentialParents: a
// certificate from the pool paired with the constraint that was registered
// alongside it (nil when none was given).
type potentialParent struct {
	cert       *Certificate
	constraint func([]*Certificate) error
}
   125  
   126  // findPotentialParents returns the indexes of certificates in s which might
   127  // have signed cert.
   128  func (s *CertPool) findPotentialParents(cert *Certificate) []potentialParent {
   129  	if s == nil {
   130  		return nil
   131  	}
   132  
   133  	// consider all candidates where cert.Issuer matches cert.Subject.
   134  	// when picking possible candidates the list is built in the order
   135  	// of match plausibility as to save cycles in buildChains:
   136  	//   AKID and SKID match
   137  	//   AKID present, SKID missing / AKID missing, SKID present
   138  	//   AKID and SKID don't match
   139  	var matchingKeyID, oneKeyID, mismatchKeyID []potentialParent
   140  	for _, c := range s.byName[string(cert.RawIssuer)] {
   141  		candidate, constraint, err := s.cert(c)
   142  		if err != nil {
   143  			continue
   144  		}
   145  		kidMatch := bytes.Equal(candidate.SubjectKeyId, cert.AuthorityKeyId)
   146  		switch {
   147  		case kidMatch:
   148  			matchingKeyID = append(matchingKeyID, potentialParent{candidate, constraint})
   149  		case (len(candidate.SubjectKeyId) == 0 && len(cert.AuthorityKeyId) > 0) ||
   150  			(len(candidate.SubjectKeyId) > 0 && len(cert.AuthorityKeyId) == 0):
   151  			oneKeyID = append(oneKeyID, potentialParent{candidate, constraint})
   152  		default:
   153  			mismatchKeyID = append(mismatchKeyID, potentialParent{candidate, constraint})
   154  		}
   155  	}
   156  
   157  	found := len(matchingKeyID) + len(oneKeyID) + len(mismatchKeyID)
   158  	if found == 0 {
   159  		return nil
   160  	}
   161  	candidates := make([]potentialParent, 0, found)
   162  	candidates = append(candidates, matchingKeyID...)
   163  	candidates = append(candidates, oneKeyID...)
   164  	candidates = append(candidates, mismatchKeyID...)
   165  	return candidates
   166  }
   167  
   168  func (s *CertPool) contains(cert *Certificate) bool {
   169  	if s == nil {
   170  		return false
   171  	}
   172  	return s.haveSum[sha256.Sum224(cert.Raw)]
   173  }
   174  
   175  // AddCert adds a certificate to a pool.
   176  func (s *CertPool) AddCert(cert *Certificate) {
   177  	if cert == nil {
   178  		panic("adding nil Certificate to CertPool")
   179  	}
   180  	s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
   181  		return cert, nil
   182  	}, nil)
   183  }
   184  
   185  // addCertFunc adds metadata about a certificate to a pool, along with
   186  // a func to fetch that certificate later when needed.
   187  //
   188  // The rawSubject is Certificate.RawSubject and must be non-empty.
   189  // The getCert func may be called 0 or more times.
   190  func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func() (*Certificate, error), constraint func([]*Certificate) error) {
   191  	if getCert == nil {
   192  		panic("getCert can't be nil")
   193  	}
   194  
   195  	// Check that the certificate isn't being added twice.
   196  	if s.haveSum[rawSum224] {
   197  		return
   198  	}
   199  
   200  	s.haveSum[rawSum224] = true
   201  	s.lazyCerts = append(s.lazyCerts, lazyCert{
   202  		rawSubject: []byte(rawSubject),
   203  		getCert:    getCert,
   204  		constraint: constraint,
   205  	})
   206  	s.byName[rawSubject] = append(s.byName[rawSubject], len(s.lazyCerts)-1)
   207  }
   208  
   209  // AppendCertsFromPEM attempts to parse a series of PEM encoded certificates.
   210  // It appends any certificates found to s and reports whether any certificates
   211  // were successfully parsed.
   212  //
   213  // On many Linux systems, /etc/ssl/cert.pem will contain the system wide set
   214  // of root CAs in a format suitable for this function.
   215  func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
   216  	for len(pemCerts) > 0 {
   217  		var block *pem.Block
   218  		block, pemCerts = pem.Decode(pemCerts)
   219  		if block == nil {
   220  			break
   221  		}
   222  		if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
   223  			continue
   224  		}
   225  
   226  		certBytes := block.Bytes
   227  		cert, err := ParseCertificate(certBytes)
   228  		if err != nil {
   229  			continue
   230  		}
   231  		var lazyCert struct {
   232  			sync.Once
   233  			v *Certificate
   234  		}
   235  		s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
   236  			lazyCert.Do(func() {
   237  				// This can't fail, as the same bytes already parsed above.
   238  				lazyCert.v, _ = ParseCertificate(certBytes)
   239  				certBytes = nil
   240  			})
   241  			return lazyCert.v, nil
   242  		}, nil)
   243  		ok = true
   244  	}
   245  
   246  	return ok
   247  }
   248  
   249  // Subjects returns a list of the DER-encoded subjects of
   250  // all of the certificates in the pool.
   251  //
   252  // Deprecated: if s was returned by SystemCertPool, Subjects
   253  // will not include the system roots.
   254  func (s *CertPool) Subjects() [][]byte {
   255  	res := make([][]byte, s.len())
   256  	for i, lc := range s.lazyCerts {
   257  		res[i] = lc.rawSubject
   258  	}
   259  	return res
   260  }
   261  
   262  // Equal reports whether s and other are equal.
   263  func (s *CertPool) Equal(other *CertPool) bool {
   264  	if s == nil || other == nil {
   265  		return s == other
   266  	}
   267  	if s.systemPool != other.systemPool || len(s.haveSum) != len(other.haveSum) {
   268  		return false
   269  	}
   270  	for h := range s.haveSum {
   271  		if !other.haveSum[h] {
   272  			return false
   273  		}
   274  	}
   275  	return true
   276  }
   277  
   278  // AddCertWithConstraint adds a certificate to the pool with the additional
   279  // constraint. When Certificate.Verify builds a chain which is rooted by cert,
   280  // it will additionally pass the whole chain to constraint to determine its
   281  // validity. If constraint returns a non-nil error, the chain will be discarded.
   282  // constraint may be called concurrently from multiple goroutines.
   283  func (s *CertPool) AddCertWithConstraint(cert *Certificate, constraint func([]*Certificate) error) {
   284  	if cert == nil {
   285  		panic("adding nil Certificate to CertPool")
   286  	}
   287  	s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
   288  		return cert, nil
   289  	}, constraint)
   290  }