github.com/rajeev159/opa@v0.45.0/topdown/crypto.go (about)

     1  // Copyright 2018 The OPA Authors.  All rights reserved.
     2  // Use of this source code is governed by an Apache2
     3  // license that can be found in the LICENSE file.
     4  
     5  package topdown
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/hmac"
    10  	"crypto/md5"
    11  	"crypto/sha1"
    12  	"crypto/sha256"
    13  	"crypto/sha512"
    14  	"crypto/x509"
    15  	"encoding/base64"
    16  	"encoding/json"
    17  	"encoding/pem"
    18  	"fmt"
    19  	"hash"
    20  	"io/ioutil"
    21  	"os"
    22  	"strings"
    23  
    24  	"github.com/open-policy-agent/opa/ast"
    25  	"github.com/open-policy-agent/opa/internal/jwx/jwk"
    26  	"github.com/open-policy-agent/opa/topdown/builtins"
    27  	"github.com/open-policy-agent/opa/util"
    28  )
    29  
const (
	// blockTypeCertificate indicates this PEM block contains the signed certificate.
	// Kept as a named constant for use by tests.
	blockTypeCertificate = "CERTIFICATE"
	// blockTypeCertificateRequest indicates this PEM block contains a certificate
	// request. Kept as a named constant for use by tests.
	blockTypeCertificateRequest = "CERTIFICATE REQUEST"
	// blockTypeRSAPrivateKey indicates this PEM block contains an RSA private key
	// in PKCS#1 form (this file parses it with x509.ParsePKCS1PrivateKey).
	// Kept as a named constant for use by tests.
	blockTypeRSAPrivateKey = "RSA PRIVATE KEY"
	// blockTypePrivateKey indicates this PEM block contains a private key in
	// PKCS#8 form (this file parses it with x509.ParsePKCS8PrivateKey).
	// Kept as a named constant for use by tests.
	blockTypePrivateKey = "PRIVATE KEY"
)
    44  
    45  func builtinCryptoX509ParseCertificates(a ast.Value) (ast.Value, error) {
    46  	input, err := builtins.StringOperand(a, 1)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	certs, err := getX509CertsFromString(string(input))
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	return ast.InterfaceToValue(certs)
    57  }
    58  
    59  func builtinCryptoX509ParseAndVerifyCertificates(
    60  	_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
    61  
    62  	a := args[0].Value
    63  	input, err := builtins.StringOperand(a, 1)
    64  	if err != nil {
    65  		return err
    66  	}
    67  
    68  	invalid := ast.ArrayTerm(
    69  		ast.BooleanTerm(false),
    70  		ast.NewTerm(ast.NewArray()),
    71  	)
    72  
    73  	certs, err := getX509CertsFromString(string(input))
    74  	if err != nil {
    75  		return iter(invalid)
    76  	}
    77  
    78  	verified, err := verifyX509CertificateChain(certs)
    79  	if err != nil {
    80  		return iter(invalid)
    81  	}
    82  
    83  	value, err := ast.InterfaceToValue(verified)
    84  	if err != nil {
    85  		return err
    86  	}
    87  
    88  	valid := ast.ArrayTerm(
    89  		ast.BooleanTerm(true),
    90  		ast.NewTerm(value),
    91  	)
    92  
    93  	return iter(valid)
    94  }
    95  
    96  func builtinCryptoX509ParseCertificateRequest(a ast.Value) (ast.Value, error) {
    97  
    98  	input, err := builtins.StringOperand(a, 1)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	// data to be passed to x509.ParseCertificateRequest
   104  	bytes := []byte(input)
   105  
   106  	// if the input is not a PEM string, attempt to decode b64
   107  	if str := string(input); !strings.HasPrefix(str, "-----BEGIN CERTIFICATE REQUEST-----") {
   108  		bytes, err = base64.StdEncoding.DecodeString(str)
   109  		if err != nil {
   110  			return nil, err
   111  		}
   112  	}
   113  
   114  	p, _ := pem.Decode(bytes)
   115  	if p != nil && p.Type != blockTypeCertificateRequest {
   116  		return nil, fmt.Errorf("invalid PEM-encoded certificate signing request")
   117  	}
   118  	if p != nil {
   119  		bytes = p.Bytes
   120  	}
   121  
   122  	csr, err := x509.ParseCertificateRequest(bytes)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  
   127  	bs, err := json.Marshal(csr)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  
   132  	var x interface{}
   133  	if err := util.UnmarshalJSON(bs, &x); err != nil {
   134  		return nil, err
   135  	}
   136  	return ast.InterfaceToValue(x)
   137  }
   138  
   139  func builtinCryptoX509ParseRSAPrivateKey(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
   140  
   141  	a := args[0].Value
   142  	input, err := builtins.StringOperand(a, 1)
   143  	if err != nil {
   144  		return err
   145  	}
   146  
   147  	// get the raw private key
   148  	rawKey, err := getRSAPrivateKeyFromString(string(input))
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	rsaPrivateKey, err := jwk.New(rawKey)
   154  	if err != nil {
   155  		return err
   156  	}
   157  
   158  	jsonKey, err := json.Marshal(rsaPrivateKey)
   159  	if err != nil {
   160  		return err
   161  	}
   162  
   163  	var x interface{}
   164  	if err := util.UnmarshalJSON(jsonKey, &x); err != nil {
   165  		return err
   166  	}
   167  
   168  	value, err := ast.InterfaceToValue(x)
   169  	if err != nil {
   170  		return err
   171  	}
   172  
   173  	return iter(ast.NewTerm(value))
   174  }
   175  
   176  func hashHelper(a ast.Value, h func(ast.String) string) (ast.Value, error) {
   177  	s, err := builtins.StringOperand(a, 1)
   178  	if err != nil {
   179  		return nil, err
   180  	}
   181  	return ast.String(h(s)), nil
   182  }
   183  
   184  func builtinCryptoMd5(a ast.Value) (ast.Value, error) {
   185  	return hashHelper(a, func(s ast.String) string { return fmt.Sprintf("%x", md5.Sum([]byte(s))) })
   186  }
   187  
   188  func builtinCryptoSha1(a ast.Value) (ast.Value, error) {
   189  	return hashHelper(a, func(s ast.String) string { return fmt.Sprintf("%x", sha1.Sum([]byte(s))) })
   190  }
   191  
   192  func builtinCryptoSha256(a ast.Value) (ast.Value, error) {
   193  	return hashHelper(a, func(s ast.String) string { return fmt.Sprintf("%x", sha256.Sum256([]byte(s))) })
   194  }
   195  
   196  func hmacHelper(args []*ast.Term, iter func(*ast.Term) error, h func() hash.Hash) error {
   197  	a1 := args[0].Value
   198  	message, err := builtins.StringOperand(a1, 1)
   199  	if err != nil {
   200  		return err
   201  	}
   202  
   203  	a2 := args[1].Value
   204  	key, err := builtins.StringOperand(a2, 2)
   205  	if err != nil {
   206  		return err
   207  	}
   208  
   209  	mac := hmac.New(h, []byte(key))
   210  	mac.Write([]byte(message))
   211  	messageDigest := mac.Sum(nil)
   212  
   213  	return iter(ast.StringTerm(fmt.Sprintf("%x", messageDigest)))
   214  }
   215  
   216  func builtinCryptoHmacMd5(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
   217  	return hmacHelper(args, iter, md5.New)
   218  }
   219  
   220  func builtinCryptoHmacSha1(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
   221  	return hmacHelper(args, iter, sha1.New)
   222  }
   223  
   224  func builtinCryptoHmacSha256(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
   225  	return hmacHelper(args, iter, sha256.New)
   226  }
   227  
   228  func builtinCryptoHmacSha512(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
   229  	return hmacHelper(args, iter, sha512.New)
   230  }
   231  
// init registers the crypto builtins implemented in this file under the names
// declared in the ast package. Single-operand builtins that return a value go
// through RegisterFunctionalBuiltin1; builtins that take the
// (BuiltinContext, args, iter) form go through RegisterBuiltinFunc.
// NOTE(review): the registration helpers are defined elsewhere; presumably
// they add entries to a package-level registry — confirm before reordering.
func init() {
	RegisterFunctionalBuiltin1(ast.CryptoX509ParseCertificates.Name, builtinCryptoX509ParseCertificates)
	RegisterBuiltinFunc(ast.CryptoX509ParseAndVerifyCertificates.Name, builtinCryptoX509ParseAndVerifyCertificates)
	// Plain digests.
	RegisterFunctionalBuiltin1(ast.CryptoMd5.Name, builtinCryptoMd5)
	RegisterFunctionalBuiltin1(ast.CryptoSha1.Name, builtinCryptoSha1)
	RegisterFunctionalBuiltin1(ast.CryptoSha256.Name, builtinCryptoSha256)
	RegisterFunctionalBuiltin1(ast.CryptoX509ParseCertificateRequest.Name, builtinCryptoX509ParseCertificateRequest)
	RegisterBuiltinFunc(ast.CryptoX509ParseRSAPrivateKey.Name, builtinCryptoX509ParseRSAPrivateKey)
	// Keyed digests (HMAC).
	RegisterBuiltinFunc(ast.CryptoHmacMd5.Name, builtinCryptoHmacMd5)
	RegisterBuiltinFunc(ast.CryptoHmacSha1.Name, builtinCryptoHmacSha1)
	RegisterBuiltinFunc(ast.CryptoHmacSha256.Name, builtinCryptoHmacSha256)
	RegisterBuiltinFunc(ast.CryptoHmacSha512.Name, builtinCryptoHmacSha512)
}
   245  
   246  func verifyX509CertificateChain(certs []*x509.Certificate) ([]*x509.Certificate, error) {
   247  	if len(certs) < 2 {
   248  		return nil, builtins.NewOperandErr(1, "must supply at least two certificates to be able to verify")
   249  	}
   250  
   251  	// first cert is the root
   252  	roots := x509.NewCertPool()
   253  	roots.AddCert(certs[0])
   254  
   255  	// all other certs except the last are intermediates
   256  	intermediates := x509.NewCertPool()
   257  	for i := 1; i < len(certs)-1; i++ {
   258  		intermediates.AddCert(certs[i])
   259  	}
   260  
   261  	// last cert is the leaf
   262  	leaf := certs[len(certs)-1]
   263  
   264  	// verify the cert chain back to the root
   265  	verifyOpts := x509.VerifyOptions{
   266  		Roots:         roots,
   267  		Intermediates: intermediates,
   268  	}
   269  	chains, err := leaf.Verify(verifyOpts)
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  
   274  	return chains[0], nil
   275  }
   276  
   277  func getX509CertsFromString(certs string) ([]*x509.Certificate, error) {
   278  	// if the input is PEM handle that
   279  	if strings.HasPrefix(certs, "-----BEGIN") {
   280  		return getX509CertsFromPem([]byte(certs))
   281  	}
   282  
   283  	// assume input is base64 if not PEM
   284  	b64, err := base64.StdEncoding.DecodeString(certs)
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  
   289  	// handle if the decoded base64 contains PEM rather than the expected DER
   290  	if bytes.HasPrefix(b64, []byte("-----BEGIN")) {
   291  		return getX509CertsFromPem(b64)
   292  	}
   293  
   294  	// otherwise assume the contents are DER
   295  	return x509.ParseCertificates(b64)
   296  }
   297  
   298  func getX509CertsFromPem(pemBlocks []byte) ([]*x509.Certificate, error) {
   299  	var decodedCerts []byte
   300  	for len(pemBlocks) > 0 {
   301  		p, r := pem.Decode(pemBlocks)
   302  		if p != nil && p.Type != blockTypeCertificate {
   303  			return nil, fmt.Errorf("PEM block type is '%s', expected %s", p.Type, blockTypeCertificate)
   304  		}
   305  
   306  		if p == nil {
   307  			break
   308  		}
   309  
   310  		pemBlocks = r
   311  		decodedCerts = append(decodedCerts, p.Bytes...)
   312  	}
   313  
   314  	return x509.ParseCertificates(decodedCerts)
   315  }
   316  
   317  func getRSAPrivateKeyFromString(key string) (interface{}, error) {
   318  	// if the input is PEM handle that
   319  	if strings.HasPrefix(key, "-----BEGIN") {
   320  		return getRSAPrivateKeyFromPEM([]byte(key))
   321  	}
   322  
   323  	// assume input is base64 if not PEM
   324  	b64, err := base64.StdEncoding.DecodeString(key)
   325  	if err != nil {
   326  		return nil, err
   327  	}
   328  
   329  	return getRSAPrivateKeyFromPEM(b64)
   330  }
   331  
   332  func getRSAPrivateKeyFromPEM(pemBlocks []byte) (interface{}, error) {
   333  
   334  	// decode the pem into the Block struct
   335  	p, _ := pem.Decode(pemBlocks)
   336  	if p == nil {
   337  		return nil, fmt.Errorf("failed to parse PEM block containing the key")
   338  	}
   339  
   340  	// if the key is in PKCS1 format
   341  	if p.Type == blockTypeRSAPrivateKey {
   342  		return x509.ParsePKCS1PrivateKey(p.Bytes)
   343  	}
   344  
   345  	// if the key is in PKCS8 format
   346  	if p.Type == blockTypePrivateKey {
   347  		return x509.ParsePKCS8PrivateKey(p.Bytes)
   348  	}
   349  
   350  	// unsupported key format
   351  	return nil, fmt.Errorf("PEM block type is '%s', expected %s or %s", p.Type, blockTypeRSAPrivateKey,
   352  		blockTypePrivateKey)
   353  
   354  }
   355  
   356  // addCACertsFromFile adds CA certificates from filePath into the given pool.
   357  // If pool is nil, it creates a new x509.CertPool. pool is returned.
   358  func addCACertsFromFile(pool *x509.CertPool, filePath string) (*x509.CertPool, error) {
   359  	if pool == nil {
   360  		pool = x509.NewCertPool()
   361  	}
   362  
   363  	caCert, err := readCertFromFile(filePath)
   364  	if err != nil {
   365  		return nil, err
   366  	}
   367  
   368  	if ok := pool.AppendCertsFromPEM(caCert); !ok {
   369  		return nil, fmt.Errorf("could not append CA certificates from %q", filePath)
   370  	}
   371  
   372  	return pool, nil
   373  }
   374  
   375  // addCACertsFromBytes adds CA certificates from pemBytes into the given pool.
   376  // If pool is nil, it creates a new x509.CertPool. pool is returned.
   377  func addCACertsFromBytes(pool *x509.CertPool, pemBytes []byte) (*x509.CertPool, error) {
   378  	if pool == nil {
   379  		pool = x509.NewCertPool()
   380  	}
   381  
   382  	if ok := pool.AppendCertsFromPEM(pemBytes); !ok {
   383  		return nil, fmt.Errorf("could not append certificates")
   384  	}
   385  
   386  	return pool, nil
   387  }
   388  
   389  // addCACertsFromBytes adds CA certificates from the environment variable named
   390  // by envName into the given pool. If pool is nil, it creates a new x509.CertPool.
   391  // pool is returned.
   392  func addCACertsFromEnv(pool *x509.CertPool, envName string) (*x509.CertPool, error) {
   393  	pool, err := addCACertsFromBytes(pool, []byte(os.Getenv(envName)))
   394  	if err != nil {
   395  		return nil, fmt.Errorf("could not add CA certificates from envvar %q: %w", envName, err)
   396  	}
   397  
   398  	return pool, err
   399  }
   400  
   401  // ReadCertFromFile reads a cert from file
   402  func readCertFromFile(localCertFile string) ([]byte, error) {
   403  	// Read in the cert file
   404  	certPEM, err := ioutil.ReadFile(localCertFile)
   405  	if err != nil {
   406  		return nil, err
   407  	}
   408  	return certPEM, nil
   409  }
   410  
   411  // ReadKeyFromFile reads a key from file
   412  func readKeyFromFile(localKeyFile string) ([]byte, error) {
   413  	// Read in the cert file
   414  	key, err := ioutil.ReadFile(localKeyFile)
   415  	if err != nil {
   416  		return nil, err
   417  	}
   418  	return key, nil
   419  }