github.com/free5gc/openapi@v1.0.8/oauth/oauth.go (about)

     1  package oauth
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"crypto/rsa"
     7  	"crypto/x509"
     8  	"crypto/x509/pkix"
     9  	"encoding/pem"
    10  	"math/big"
    11  	"net/url"
    12  	"os"
    13  	"path/filepath"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/golang-jwt/jwt"
    18  	"github.com/pkg/errors"
    19  
    20  	"github.com/free5gc/openapi/models"
    21  )
    22  
    23  type CCAClaims struct {
    24  	Iat int32
    25  	Exp int32
    26  	jwt.StandardClaims
    27  }
    28  
    29  func GenerateClientCredentialAssertion(
    30  	sub, aud, keyPath string,
    31  ) (string, error) {
    32  	var expiration int32 = 1000
    33  	now := int32(time.Now().Unix())
    34  
    35  	accessTokenClaims := CCAClaims{
    36  		Iat: now,
    37  		Exp: now + expiration, // access_token is authorized for use
    38  		StandardClaims: jwt.StandardClaims{
    39  			Subject:  sub,
    40  			Audience: aud,
    41  		},
    42  	}
    43  
    44  	// Use RSA as a signing method
    45  	signKey, err := ParsePrivateKeyFromPEM(keyPath)
    46  	if err != nil {
    47  		return "", errors.Wrapf(err, "gen CCAClaims")
    48  	}
    49  	token := jwt.NewWithClaims(jwt.GetSigningMethod("RS512"), accessTokenClaims)
    50  	accessToken, err := token.SignedString(signKey)
    51  	if err != nil {
    52  		return "", errors.Wrapf(err, "gen CCAClaims")
    53  	}
    54  	return accessToken, nil
    55  }
    56  
    57  func VerifyOAuth(
    58  	authorization, serviceName, certPath string,
    59  ) error {
    60  	verifyKey, err := ParsePublicKeyFromPEM(certPath)
    61  	if err != nil {
    62  		return errors.Wrapf(err, "verify OAuth")
    63  	}
    64  
    65  	auth_fields := strings.Fields(authorization)
    66  	if len(auth_fields) < 2 {
    67  		return errors.Errorf("verify OAuth Authorization header invalid")
    68  	}
    69  
    70  	access_token := auth_fields[1]
    71  	token, err := jwt.ParseWithClaims(
    72  		access_token,
    73  		&models.AccessTokenClaims{},
    74  		func(token *jwt.Token) (interface{}, error) {
    75  			if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
    76  				return nil, errors.Wrapf(err, "Unexpected signing method")
    77  			}
    78  			if token.Header["alg"] != "RS512" {
    79  				return nil, errors.Wrapf(err, "Unexpected signing method")
    80  			}
    81  			return verifyKey, nil
    82  		})
    83  	if err != nil {
    84  		return errors.Wrapf(err, "verify OAuth parse")
    85  	}
    86  
    87  	if !verifyScope(token.Claims.(*models.AccessTokenClaims).Scope, serviceName) {
    88  		return errors.Wrapf(err, "verify OAuth scope")
    89  	}
    90  	return nil
    91  }
    92  
    93  func verifyScope(scope, serviceName string) bool {
    94  	if len(serviceName) == 0 {
    95  		return true
    96  	}
    97  	if len(scope) != 0 {
    98  		scopeSplit := strings.Fields(scope)
    99  		found := false
   100  		for _, item := range scopeSplit {
   101  			if item == serviceName {
   102  				found = true
   103  				break
   104  			}
   105  		}
   106  		if !found {
   107  			return false
   108  		}
   109  	} else {
   110  		return false
   111  	}
   112  	return true
   113  }
   114  
   115  func GenerateRootCertificate(
   116  	rootCertPath string,
   117  	rootPrivKey *rsa.PrivateKey,
   118  ) (*x509.Certificate, error) {
   119  	rootCert, err := GenerateCertificate(
   120  		"", "", rootCertPath, &rootPrivKey.PublicKey, nil, rootPrivKey)
   121  	if err != nil {
   122  		return nil, errors.Wrapf(err, "gen root cert")
   123  	}
   124  	return rootCert, nil
   125  }
   126  
   127  func GenerateCertificate(
   128  	nfType, nfId, certPemPath string,
   129  	pubKey *rsa.PublicKey,
   130  	rootCert *x509.Certificate,
   131  	rootPrivKey *rsa.PrivateKey,
   132  ) (*x509.Certificate, error) {
   133  	max := new(big.Int)
   134  	max.Exp(big.NewInt(16), big.NewInt(40), nil)
   135  	sn, _ := rand.Int(rand.Reader, max)
   136  
   137  	temp := &x509.Certificate{
   138  		SerialNumber: sn,
   139  		Subject: pkix.Name{
   140  			Country:            []string{"TW"},
   141  			Province:           []string{"Taiwan"},
   142  			Locality:           []string{"Hsinchu"},
   143  			Organization:       []string{"free5gc"},
   144  			OrganizationalUnit: []string{"free5gc"},
   145  		},
   146  		NotBefore: time.Now(),
   147  		NotAfter:  time.Now().AddDate(10, 0, 0),
   148  	}
   149  
   150  	if nfType != "" {
   151  		temp.Subject.CommonName = strings.ToUpper(nfType)
   152  		temp.DNSNames = []string{nfType}
   153  	}
   154  	if nfId != "" {
   155  		uri, err := url.Parse("urn:uuid:" + nfId)
   156  		if err != nil {
   157  			return nil, errors.Wrapf(err, "gen cert url")
   158  		}
   159  		temp.URIs = []*url.URL{uri}
   160  	}
   161  	if rootCert == nil {
   162  		// generate self-signed certificate
   163  		rootCert = temp
   164  	}
   165  
   166  	b, err := x509.CreateCertificate(
   167  		rand.Reader, temp, rootCert, pubKey, rootPrivKey)
   168  	if err != nil {
   169  		return nil, errors.Wrapf(err, "gen cert create")
   170  	}
   171  
   172  	cert, err := x509.ParseCertificate(b)
   173  	if err != nil {
   174  		return nil, errors.Wrapf(err, "gen cert parse")
   175  	}
   176  
   177  	if certPemPath != "" {
   178  		out := &bytes.Buffer{}
   179  		err = pem.Encode(out, &pem.Block{Type: "CERTIFICATE", Bytes: b})
   180  		if err != nil {
   181  			return nil, errors.Wrapf(err, "gen cert file encode")
   182  		}
   183  
   184  		certFile, err := os.Create(certPemPath)
   185  		if err != nil {
   186  			return nil, errors.Wrapf(err, "gen cert file create")
   187  		}
   188  		defer certFile.Close() // nolint
   189  
   190  		_, err = out.WriteTo(certFile)
   191  		if err != nil {
   192  			return nil, errors.Wrapf(err, "gen cert file write")
   193  		}
   194  	}
   195  
   196  	return cert, nil
   197  }
   198  
   199  func ParsePublicKeyFromPEM(pubPemPath string) (*rsa.PublicKey, error) {
   200  	b, err := os.ReadFile(pubPemPath)
   201  	if err != nil {
   202  		return nil, errors.Wrapf(err, "pubkey read")
   203  	}
   204  
   205  	pubKey, err := jwt.ParseRSAPublicKeyFromPEM(b)
   206  	if err != nil {
   207  		return nil, errors.Wrapf(err, "pubkey parse")
   208  	}
   209  
   210  	return pubKey, nil
   211  }
   212  
   213  func ParsePrivateKeyFromPEM(privPemPath string) (*rsa.PrivateKey, error) {
   214  	b, err := os.ReadFile(privPemPath)
   215  	if err != nil {
   216  		return nil, errors.Wrapf(err, "privkey read")
   217  	}
   218  
   219  	privKey, err := jwt.ParseRSAPrivateKeyFromPEM(b)
   220  	if err != nil {
   221  		return nil, errors.Wrapf(err, "privkey parse")
   222  	}
   223  
   224  	return privKey, nil
   225  }
   226  
   227  func ParseCertFromPEM(certPemPath string) (*x509.Certificate, error) {
   228  	b, err := os.ReadFile(certPemPath)
   229  	if err != nil {
   230  		return nil, errors.Wrapf(err, "read cert pem")
   231  	}
   232  
   233  	block, _ := pem.Decode(b)
   234  	cert, err := x509.ParseCertificate(block.Bytes)
   235  	if err != nil {
   236  		return nil, errors.Wrapf(err, "parse cert pem")
   237  	}
   238  
   239  	return cert, nil
   240  }
   241  
   242  func GenerateRSAKeyPair(pubPemPath, privPemPath string) (*rsa.PrivateKey, error) {
   243  	const rsaKeyBitSize = 2048
   244  
   245  	// generate key
   246  	privKey, err := rsa.GenerateKey(rand.Reader, rsaKeyBitSize)
   247  	if err != nil {
   248  		return nil, errors.Wrapf(err, "generate rsa key")
   249  	}
   250  
   251  	if pubPemPath != "" {
   252  		// dump public key to file
   253  		pubPem, err := os.Create(pubPemPath)
   254  		if err != nil {
   255  			return nil, errors.Wrapf(err, "generate rsa pub key create file")
   256  		}
   257  		defer pubPem.Close() // nolint
   258  
   259  		pubKey := &privKey.PublicKey
   260  		pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
   261  		if err != nil {
   262  			return nil, errors.Wrapf(err, "generate rsa pub key marshal")
   263  		}
   264  		pubKeyBlock := &pem.Block{
   265  			Type:  "PUBLIC KEY",
   266  			Bytes: pubKeyBytes,
   267  		}
   268  
   269  		err = pem.Encode(pubPem, pubKeyBlock)
   270  		if err != nil {
   271  			return nil, errors.Wrapf(err, "generate rsa pub key pem")
   272  		}
   273  	}
   274  
   275  	if privPemPath != "" {
   276  		// dump private key to file
   277  		privPem, err := os.Create(privPemPath)
   278  		if err != nil {
   279  			return nil, errors.Wrapf(err, "generate rsa priv key create file")
   280  		}
   281  		defer privPem.Close() // nolint
   282  
   283  		privKeyBlock := &pem.Block{
   284  			Type:  "PRIVATE KEY",
   285  			Bytes: x509.MarshalPKCS1PrivateKey(privKey),
   286  		}
   287  
   288  		err = pem.Encode(privPem, privKeyBlock)
   289  		if err != nil {
   290  			return nil, errors.Wrapf(err, "generate rsa priv key pem")
   291  		}
   292  	}
   293  	return privKey, nil
   294  }
   295  
   296  func GetNFCertFileName(nfType, nfId string) string {
   297  	if nfId != "" {
   298  		return strings.ToLower(nfType) + "_" + nfId + ".pem"
   299  	}
   300  	return strings.ToLower(nfType) + ".pem"
   301  }
   302  
   303  func GetNFCertPath(base, nfType, nfId string) string {
   304  	// Note: NF's cert should be put in the same base path
   305  	return filepath.Join(base, GetNFCertFileName(nfType, nfId))
   306  }