github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/pkg/cert/certutils.go (about)

     1  package cert
     2  
     3  import (
     4  	"crypto/rsa"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/pem"
     8  	"fmt"
     9  	"regexp"
    10  	"strings"
    11  
    12  	"github.com/kyma-incubator/compass/components/director/internal/util"
    13  
    14  	"github.com/pkg/errors"
    15  
    16  	"github.com/google/uuid"
    17  )
    18  
    19  const (
    20  	// ConsumerTypeExtraField is the consumer type json field in the auth session body extra.
    21  	ConsumerTypeExtraField = "consumer_type"
    22  	// AccessLevelsExtraField is the tenant access levels json field in the auth session body extra.
    23  	AccessLevelsExtraField = "tenant_access_levels"
    24  	// InternalConsumerIDField is the internal consumer id json field in the auth session body extra.
    25  	InternalConsumerIDField = "internal_consumer_id"
    26  )
    27  
    28  // GetOrganization returns the O part of the subject
    29  func GetOrganization(subject string) string {
    30  	return getRegexMatch("O=([^(,|+)]+)", subject)
    31  }
    32  
    33  // GetOrganizationalUnit returns the first OU of the subject
    34  func GetOrganizationalUnit(subject string) string {
    35  	return getRegexMatch("OU=([^(,|+)]+)", subject)
    36  }
    37  
    38  // GetUUIDOrganizationalUnit returns the OU that is a valid UUID or empty string if there is no OU that is a valid UUID
    39  func GetUUIDOrganizationalUnit(subject string) string {
    40  	orgUnits := GetAllOrganizationalUnits(subject)
    41  	for _, orgUnit := range orgUnits {
    42  		if _, err := uuid.Parse(orgUnit); err == nil {
    43  			return orgUnit
    44  		}
    45  	}
    46  	return ""
    47  }
    48  
    49  // GetRemainingOrganizationalUnit returns the OU that is remaining after matching previously expected ones based on a given pattern
    50  func GetRemainingOrganizationalUnit(organizationalUnitPattern string, ouRegionPattern string) func(string) string {
    51  	return func(subject string) string {
    52  		regex := ConstructOURegex(organizationalUnitPattern, ouRegionPattern)
    53  		orgUnitRegex := regexp.MustCompile(regex)
    54  		orgUnits := GetAllOrganizationalUnits(subject)
    55  
    56  		remainingOrgUnit := ""
    57  		var matchedOrgUnits int
    58  		for _, orgUnit := range orgUnits {
    59  			if !orgUnitRegex.MatchString(orgUnit) {
    60  				remainingOrgUnit = orgUnit
    61  			} else {
    62  				matchedOrgUnits++
    63  			}
    64  		}
    65  
    66  		if len(orgUnits)-matchedOrgUnits == 1 {
    67  			return remainingOrgUnit
    68  		}
    69  
    70  		return ""
    71  	}
    72  }
    73  
    74  // ConstructOURegex returns regex which is used to determine authID from cert subject
    75  func ConstructOURegex(patterns ...string) string {
    76  	nonEmptyStr := make([]string, 0)
    77  	for _, pattern := range patterns {
    78  		if len(pattern) > 0 {
    79  			nonEmptyStr = append(nonEmptyStr, fmt.Sprintf("\\b%s\\b", pattern))
    80  		}
    81  	}
    82  	return strings.Join(nonEmptyStr, "|")
    83  }
    84  
    85  // GetAllOrganizationalUnits returns all OU parts of the subject
    86  func GetAllOrganizationalUnits(subject string) []string {
    87  	return getAllRegexMatches("OU=([^(,|+)]+)", subject)
    88  }
    89  
    90  // GetCountry returns the C part of the subject
    91  func GetCountry(subject string) string {
    92  	return getRegexMatch("C=([^(,|+)]+)", subject)
    93  }
    94  
    95  // GetProvince returns the ST part of the subject
    96  func GetProvince(subject string) string {
    97  	return getRegexMatch("ST=([^(,|+)]+)", subject)
    98  }
    99  
   100  // GetLocality returns the L part of the subject
   101  func GetLocality(subject string) string {
   102  	return getRegexMatch("L=([^(,|+)]+)", subject)
   103  }
   104  
   105  // GetCommonName returns the CN part of the subject
   106  func GetCommonName(subject string) string {
   107  	return getRegexMatch("CN=([^,]+)", subject)
   108  }
   109  
   110  // GetAuthSessionExtra returns an appropriate auth session extra for the given consumerType, accessLevel and internalConsumerID
   111  func GetAuthSessionExtra(consumerType, internalConsumerID string, accessLevels []string) map[string]interface{} {
   112  	return map[string]interface{}{
   113  		ConsumerTypeExtraField:  consumerType,
   114  		AccessLevelsExtraField:  accessLevels,
   115  		InternalConsumerIDField: internalConsumerID,
   116  	}
   117  }
   118  
   119  // DecodeCertificates accepts raw certificate chain and return slice of parsed certificates
   120  func DecodeCertificates(pemCertChain []byte) ([]*x509.Certificate, error) {
   121  	if pemCertChain == nil {
   122  		return nil, errors.New("Certificate data is empty")
   123  	}
   124  
   125  	var certificates []*x509.Certificate
   126  
   127  	for block, rest := pem.Decode(pemCertChain); block != nil && rest != nil; {
   128  		cert, err := x509.ParseCertificate(block.Bytes)
   129  		if err != nil {
   130  			return nil, errors.Wrap(err, "Failed to decode one of the pem blocks")
   131  		}
   132  
   133  		certificates = append(certificates, cert)
   134  
   135  		block, rest = pem.Decode(rest)
   136  	}
   137  
   138  	if len(certificates) == 0 {
   139  		return nil, errors.New("No certificates found in the pem block")
   140  	}
   141  
   142  	return certificates, nil
   143  }
   144  
   145  // NewTLSCertificate creates tls certificate from given certificate chain in form of slice of certificates
   146  func NewTLSCertificate(key *rsa.PrivateKey, certificates ...*x509.Certificate) tls.Certificate {
   147  	rawCerts := make([][]byte, len(certificates))
   148  	for i, c := range certificates {
   149  		rawCerts[i] = c.Raw
   150  	}
   151  
   152  	return tls.Certificate{
   153  		Certificate: rawCerts,
   154  		PrivateKey:  key,
   155  	}
   156  }
   157  
   158  func getRegexMatch(regex, text string) string {
   159  	matches := getAllRegexMatches(regex, text)
   160  	if len(matches) > 0 {
   161  		return matches[0]
   162  	}
   163  	return ""
   164  }
   165  
   166  func getAllRegexMatches(regex, text string) []string {
   167  	cnRegex := regexp.MustCompile(regex)
   168  	matches := cnRegex.FindAllStringSubmatch(text, -1)
   169  
   170  	result := make([]string, 0, len(matches))
   171  	for _, match := range matches {
   172  		if len(match) != 2 {
   173  			continue
   174  		}
   175  		result = append(result, match[1])
   176  	}
   177  
   178  	return result
   179  }
   180  
   181  // ParseCertificate creates a tls.Certificate from certificate and key
   182  // The cert/key can be in PEM format or can be base64 encoded
   183  func ParseCertificate(cert string, key string) (*tls.Certificate, error) {
   184  	if cert == "" || key == "" {
   185  		return nil, errors.New("The cert/key is required")
   186  	}
   187  
   188  	certChainBytes := util.TryDecodeBase64(cert)
   189  	privateKeyBytes := util.TryDecodeBase64(key)
   190  
   191  	return ParseCertificateBytes(certChainBytes, privateKeyBytes)
   192  }
   193  
   194  // ParseCertificateBytes creates a tls.Certificate from certificate and key
   195  func ParseCertificateBytes(cert []byte, key []byte) (*tls.Certificate, error) {
   196  	certs, err := DecodeCertificates(cert)
   197  	if err != nil {
   198  		return nil, errors.Wrap(err, "Error while decoding certificate pem block")
   199  	}
   200  
   201  	privateKeyPem, _ := pem.Decode(key)
   202  	if privateKeyPem == nil {
   203  		return nil, errors.New("Error while decoding private key pem block")
   204  	}
   205  
   206  	privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyPem.Bytes)
   207  	if err != nil {
   208  		pkcs8PrivateKey, err := x509.ParsePKCS8PrivateKey(privateKeyPem.Bytes)
   209  		if err != nil {
   210  			return nil, err
   211  		}
   212  		var ok bool
   213  		privateKey, ok = pkcs8PrivateKey.(*rsa.PrivateKey)
   214  		if !ok {
   215  			return nil, err
   216  		}
   217  	}
   218  
   219  	tlsCert := NewTLSCertificate(privateKey, certs...)
   220  
   221  	return &tlsCert, nil
   222  }