github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/pkg/cert/certutils.go (about) 1 package cert 2 3 import ( 4 "crypto/rsa" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/pem" 8 "fmt" 9 "regexp" 10 "strings" 11 12 "github.com/kyma-incubator/compass/components/director/internal/util" 13 14 "github.com/pkg/errors" 15 16 "github.com/google/uuid" 17 ) 18 19 const ( 20 // ConsumerTypeExtraField is the consumer type json field in the auth session body extra. 21 ConsumerTypeExtraField = "consumer_type" 22 // AccessLevelsExtraField is the tenant access levels json field in the auth session body extra. 23 AccessLevelsExtraField = "tenant_access_levels" 24 // InternalConsumerIDField is the internal consumer id json field in the auth session body extra. 25 InternalConsumerIDField = "internal_consumer_id" 26 ) 27 28 // GetOrganization returns the O part of the subject 29 func GetOrganization(subject string) string { 30 return getRegexMatch("O=([^(,|+)]+)", subject) 31 } 32 33 // GetOrganizationalUnit returns the first OU of the subject 34 func GetOrganizationalUnit(subject string) string { 35 return getRegexMatch("OU=([^(,|+)]+)", subject) 36 } 37 38 // GetUUIDOrganizationalUnit returns the OU that is a valid UUID or empty string if there is no OU that is a valid UUID 39 func GetUUIDOrganizationalUnit(subject string) string { 40 orgUnits := GetAllOrganizationalUnits(subject) 41 for _, orgUnit := range orgUnits { 42 if _, err := uuid.Parse(orgUnit); err == nil { 43 return orgUnit 44 } 45 } 46 return "" 47 } 48 49 // GetRemainingOrganizationalUnit returns the OU that is remaining after matching previously expected ones based on a given pattern 50 func GetRemainingOrganizationalUnit(organizationalUnitPattern string, ouRegionPattern string) func(string) string { 51 return func(subject string) string { 52 regex := ConstructOURegex(organizationalUnitPattern, ouRegionPattern) 53 orgUnitRegex := regexp.MustCompile(regex) 54 orgUnits := GetAllOrganizationalUnits(subject) 55 56 remainingOrgUnit := "" 57 var matchedOrgUnits int 58 for _, orgUnit := range orgUnits { 59 if !orgUnitRegex.MatchString(orgUnit) { 60 remainingOrgUnit = orgUnit 61 } else { 62 matchedOrgUnits++ 63 } 64 } 65 66 if len(orgUnits)-matchedOrgUnits == 1 { 67 return remainingOrgUnit 68 } 69 70 return "" 71 } 72 } 73 74 // ConstructOURegex returns regex which is used to determine authID from cert subject 75 func ConstructOURegex(patterns ...string) string { 76 nonEmptyStr := make([]string, 0) 77 for _, pattern := range patterns { 78 if len(pattern) > 0 { 79 nonEmptyStr = append(nonEmptyStr, fmt.Sprintf("\\b%s\\b", pattern)) 80 } 81 } 82 return strings.Join(nonEmptyStr, "|") 83 } 84 85 // GetAllOrganizationalUnits returns all OU parts of the subject 86 func GetAllOrganizationalUnits(subject string) []string { 87 return getAllRegexMatches("OU=([^(,|+)]+)", subject) 88 } 89 90 // GetCountry returns the C part of the subject 91 func GetCountry(subject string) string { 92 return getRegexMatch("C=([^(,|+)]+)", subject) 93 } 94 95 // GetProvince returns the ST part of the subject 96 func GetProvince(subject string) string { 97 return getRegexMatch("ST=([^(,|+)]+)", subject) 98 } 99 100 // GetLocality returns the L part of the subject 101 func GetLocality(subject string) string { 102 return getRegexMatch("L=([^(,|+)]+)", subject) 103 } 104 105 // GetCommonName returns the CN part of the subject 106 func GetCommonName(subject string) string { 107 return getRegexMatch("CN=([^,]+)", subject) 108 } 109 110 // GetAuthSessionExtra returns an appropriate auth session extra for the given consumerType, accessLevel and internalConsumerID 111 func GetAuthSessionExtra(consumerType, internalConsumerID string, accessLevels []string) map[string]interface{} { 112 return map[string]interface{}{ 113 ConsumerTypeExtraField: consumerType, 114 AccessLevelsExtraField: accessLevels, 115 InternalConsumerIDField: internalConsumerID, 116 } 117 } 118 119 // DecodeCertificates accepts raw certificate chain and return slice of parsed certificates 120 func DecodeCertificates(pemCertChain []byte) ([]*x509.Certificate, error) { 121 if pemCertChain == nil { 122 return nil, errors.New("Certificate data is empty") 123 } 124 125 var certificates []*x509.Certificate 126 127 for block, rest := pem.Decode(pemCertChain); block != nil && rest != nil; { 128 cert, err := x509.ParseCertificate(block.Bytes) 129 if err != nil { 130 return nil, errors.Wrap(err, "Failed to decode one of the pem blocks") 131 } 132 133 certificates = append(certificates, cert) 134 135 block, rest = pem.Decode(rest) 136 } 137 138 if len(certificates) == 0 { 139 return nil, errors.New("No certificates found in the pem block") 140 } 141 142 return certificates, nil 143 } 144 145 // NewTLSCertificate creates tls certificate from given certificate chain in form of slice of certificates 146 func NewTLSCertificate(key *rsa.PrivateKey, certificates ...*x509.Certificate) tls.Certificate { 147 rawCerts := make([][]byte, len(certificates)) 148 for i, c := range certificates { 149 rawCerts[i] = c.Raw 150 } 151 152 return tls.Certificate{ 153 Certificate: rawCerts, 154 PrivateKey: key, 155 } 156 } 157 158 func getRegexMatch(regex, text string) string { 159 matches := getAllRegexMatches(regex, text) 160 if len(matches) > 0 { 161 return matches[0] 162 } 163 return "" 164 } 165 166 func getAllRegexMatches(regex, text string) []string { 167 cnRegex := regexp.MustCompile(regex) 168 matches := cnRegex.FindAllStringSubmatch(text, -1) 169 170 result := make([]string, 0, len(matches)) 171 for _, match := range matches { 172 if len(match) != 2 { 173 continue 174 } 175 result = append(result, match[1]) 176 } 177 178 return result 179 } 180 181 // ParseCertificate creates a tls.Certificate from certificate and key 182 // The cert/key can be in PEM format or can be base64 encoded 183 func ParseCertificate(cert string, key string) (*tls.Certificate, error) { 184 if cert == "" || key == "" { 185 return nil, errors.New("The cert/key is required") 186 } 187 188 certChainBytes := util.TryDecodeBase64(cert) 189 privateKeyBytes := util.TryDecodeBase64(key) 190 191 return ParseCertificateBytes(certChainBytes, privateKeyBytes) 192 } 193 194 // ParseCertificateBytes creates a tls.Certificate from certificate and key 195 func ParseCertificateBytes(cert []byte, key []byte) (*tls.Certificate, error) { 196 certs, err := DecodeCertificates(cert) 197 if err != nil { 198 return nil, errors.Wrap(err, "Error while decoding certificate pem block") 199 } 200 201 privateKeyPem, _ := pem.Decode(key) 202 if privateKeyPem == nil { 203 return nil, errors.New("Error while decoding private key pem block") 204 } 205 206 privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyPem.Bytes) 207 if err != nil { 208 pkcs8PrivateKey, err := x509.ParsePKCS8PrivateKey(privateKeyPem.Bytes) 209 if err != nil { 210 return nil, err 211 } 212 var ok bool 213 privateKey, ok = pkcs8PrivateKey.(*rsa.PrivateKey) 214 if !ok { 215 return nil, err 216 } 217 } 218 219 tlsCert := NewTLSCertificate(privateKey, certs...) 220 221 return &tlsCert, nil 222 }