code.gitea.io/gitea@v1.21.7/services/auth/source/oauth2/jwtsigningkey.go (about)

     1  // Copyright 2021 The Gitea Authors. All rights reserved.
     2  // SPDX-License-Identifier: MIT
     3  
     4  package oauth2
     5  
     6  import (
     7  	"crypto/ecdsa"
     8  	"crypto/ed25519"
     9  	"crypto/elliptic"
    10  	"crypto/rand"
    11  	"crypto/rsa"
    12  	"crypto/x509"
    13  	"encoding/base64"
    14  	"encoding/pem"
    15  	"fmt"
    16  	"math/big"
    17  	"os"
    18  	"path/filepath"
    19  	"strings"
    20  
    21  	"code.gitea.io/gitea/modules/log"
    22  	"code.gitea.io/gitea/modules/setting"
    23  	"code.gitea.io/gitea/modules/util"
    24  
    25  	"github.com/golang-jwt/jwt/v5"
    26  )
    27  
    28  // ErrInvalidAlgorithmType represents an invalid algorithm error.
    29  type ErrInvalidAlgorithmType struct {
    30  	Algorithm string
    31  }
    32  
    33  func (err ErrInvalidAlgorithmType) Error() string {
    34  	return fmt.Sprintf("JWT signing algorithm is not supported: %s", err.Algorithm)
    35  }
    36  
    37  // JWTSigningKey represents a algorithm/key pair to sign JWTs
    38  type JWTSigningKey interface {
    39  	IsSymmetric() bool
    40  	SigningMethod() jwt.SigningMethod
    41  	SignKey() any
    42  	VerifyKey() any
    43  	ToJWK() (map[string]string, error)
    44  	PreProcessToken(*jwt.Token)
    45  }
    46  
    47  type hmacSigningKey struct {
    48  	signingMethod jwt.SigningMethod
    49  	secret        []byte
    50  }
    51  
    52  func (key hmacSigningKey) IsSymmetric() bool {
    53  	return true
    54  }
    55  
    56  func (key hmacSigningKey) SigningMethod() jwt.SigningMethod {
    57  	return key.signingMethod
    58  }
    59  
    60  func (key hmacSigningKey) SignKey() any {
    61  	return key.secret
    62  }
    63  
    64  func (key hmacSigningKey) VerifyKey() any {
    65  	return key.secret
    66  }
    67  
    68  func (key hmacSigningKey) ToJWK() (map[string]string, error) {
    69  	return map[string]string{
    70  		"kty": "oct",
    71  		"alg": key.SigningMethod().Alg(),
    72  	}, nil
    73  }
    74  
    75  func (key hmacSigningKey) PreProcessToken(*jwt.Token) {}
    76  
    77  type rsaSingingKey struct {
    78  	signingMethod jwt.SigningMethod
    79  	key           *rsa.PrivateKey
    80  	id            string
    81  }
    82  
    83  func newRSASingingKey(signingMethod jwt.SigningMethod, key *rsa.PrivateKey) (rsaSingingKey, error) {
    84  	kid, err := util.CreatePublicKeyFingerprint(key.Public().(*rsa.PublicKey))
    85  	if err != nil {
    86  		return rsaSingingKey{}, err
    87  	}
    88  
    89  	return rsaSingingKey{
    90  		signingMethod,
    91  		key,
    92  		base64.RawURLEncoding.EncodeToString(kid),
    93  	}, nil
    94  }
    95  
    96  func (key rsaSingingKey) IsSymmetric() bool {
    97  	return false
    98  }
    99  
   100  func (key rsaSingingKey) SigningMethod() jwt.SigningMethod {
   101  	return key.signingMethod
   102  }
   103  
   104  func (key rsaSingingKey) SignKey() any {
   105  	return key.key
   106  }
   107  
   108  func (key rsaSingingKey) VerifyKey() any {
   109  	return key.key.Public()
   110  }
   111  
   112  func (key rsaSingingKey) ToJWK() (map[string]string, error) {
   113  	pubKey := key.key.Public().(*rsa.PublicKey)
   114  
   115  	return map[string]string{
   116  		"kty": "RSA",
   117  		"alg": key.SigningMethod().Alg(),
   118  		"kid": key.id,
   119  		"e":   base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pubKey.E)).Bytes()),
   120  		"n":   base64.RawURLEncoding.EncodeToString(pubKey.N.Bytes()),
   121  	}, nil
   122  }
   123  
   124  func (key rsaSingingKey) PreProcessToken(token *jwt.Token) {
   125  	token.Header["kid"] = key.id
   126  }
   127  
   128  type eddsaSigningKey struct {
   129  	signingMethod jwt.SigningMethod
   130  	key           ed25519.PrivateKey
   131  	id            string
   132  }
   133  
   134  func newEdDSASingingKey(signingMethod jwt.SigningMethod, key ed25519.PrivateKey) (eddsaSigningKey, error) {
   135  	kid, err := util.CreatePublicKeyFingerprint(key.Public().(ed25519.PublicKey))
   136  	if err != nil {
   137  		return eddsaSigningKey{}, err
   138  	}
   139  
   140  	return eddsaSigningKey{
   141  		signingMethod,
   142  		key,
   143  		base64.RawURLEncoding.EncodeToString(kid),
   144  	}, nil
   145  }
   146  
   147  func (key eddsaSigningKey) IsSymmetric() bool {
   148  	return false
   149  }
   150  
   151  func (key eddsaSigningKey) SigningMethod() jwt.SigningMethod {
   152  	return key.signingMethod
   153  }
   154  
   155  func (key eddsaSigningKey) SignKey() any {
   156  	return key.key
   157  }
   158  
   159  func (key eddsaSigningKey) VerifyKey() any {
   160  	return key.key.Public()
   161  }
   162  
   163  func (key eddsaSigningKey) ToJWK() (map[string]string, error) {
   164  	pubKey := key.key.Public().(ed25519.PublicKey)
   165  
   166  	return map[string]string{
   167  		"alg": key.SigningMethod().Alg(),
   168  		"kid": key.id,
   169  		"kty": "OKP",
   170  		"crv": "Ed25519",
   171  		"x":   base64.RawURLEncoding.EncodeToString(pubKey),
   172  	}, nil
   173  }
   174  
   175  func (key eddsaSigningKey) PreProcessToken(token *jwt.Token) {
   176  	token.Header["kid"] = key.id
   177  }
   178  
   179  type ecdsaSingingKey struct {
   180  	signingMethod jwt.SigningMethod
   181  	key           *ecdsa.PrivateKey
   182  	id            string
   183  }
   184  
   185  func newECDSASingingKey(signingMethod jwt.SigningMethod, key *ecdsa.PrivateKey) (ecdsaSingingKey, error) {
   186  	kid, err := util.CreatePublicKeyFingerprint(key.Public().(*ecdsa.PublicKey))
   187  	if err != nil {
   188  		return ecdsaSingingKey{}, err
   189  	}
   190  
   191  	return ecdsaSingingKey{
   192  		signingMethod,
   193  		key,
   194  		base64.RawURLEncoding.EncodeToString(kid),
   195  	}, nil
   196  }
   197  
   198  func (key ecdsaSingingKey) IsSymmetric() bool {
   199  	return false
   200  }
   201  
   202  func (key ecdsaSingingKey) SigningMethod() jwt.SigningMethod {
   203  	return key.signingMethod
   204  }
   205  
   206  func (key ecdsaSingingKey) SignKey() any {
   207  	return key.key
   208  }
   209  
   210  func (key ecdsaSingingKey) VerifyKey() any {
   211  	return key.key.Public()
   212  }
   213  
   214  func (key ecdsaSingingKey) ToJWK() (map[string]string, error) {
   215  	pubKey := key.key.Public().(*ecdsa.PublicKey)
   216  
   217  	return map[string]string{
   218  		"kty": "EC",
   219  		"alg": key.SigningMethod().Alg(),
   220  		"kid": key.id,
   221  		"crv": pubKey.Params().Name,
   222  		"x":   base64.RawURLEncoding.EncodeToString(pubKey.X.Bytes()),
   223  		"y":   base64.RawURLEncoding.EncodeToString(pubKey.Y.Bytes()),
   224  	}, nil
   225  }
   226  
   227  func (key ecdsaSingingKey) PreProcessToken(token *jwt.Token) {
   228  	token.Header["kid"] = key.id
   229  }
   230  
   231  // CreateJWTSigningKey creates a signing key from an algorithm / key pair.
   232  func CreateJWTSigningKey(algorithm string, key any) (JWTSigningKey, error) {
   233  	var signingMethod jwt.SigningMethod
   234  	switch algorithm {
   235  	case "HS256":
   236  		signingMethod = jwt.SigningMethodHS256
   237  	case "HS384":
   238  		signingMethod = jwt.SigningMethodHS384
   239  	case "HS512":
   240  		signingMethod = jwt.SigningMethodHS512
   241  
   242  	case "RS256":
   243  		signingMethod = jwt.SigningMethodRS256
   244  	case "RS384":
   245  		signingMethod = jwt.SigningMethodRS384
   246  	case "RS512":
   247  		signingMethod = jwt.SigningMethodRS512
   248  
   249  	case "ES256":
   250  		signingMethod = jwt.SigningMethodES256
   251  	case "ES384":
   252  		signingMethod = jwt.SigningMethodES384
   253  	case "ES512":
   254  		signingMethod = jwt.SigningMethodES512
   255  	case "EdDSA":
   256  		signingMethod = jwt.SigningMethodEdDSA
   257  	default:
   258  		return nil, ErrInvalidAlgorithmType{algorithm}
   259  	}
   260  
   261  	switch signingMethod.(type) {
   262  	case *jwt.SigningMethodEd25519:
   263  		privateKey, ok := key.(ed25519.PrivateKey)
   264  		if !ok {
   265  			return nil, jwt.ErrInvalidKeyType
   266  		}
   267  		return newEdDSASingingKey(signingMethod, privateKey)
   268  	case *jwt.SigningMethodECDSA:
   269  		privateKey, ok := key.(*ecdsa.PrivateKey)
   270  		if !ok {
   271  			return nil, jwt.ErrInvalidKeyType
   272  		}
   273  		return newECDSASingingKey(signingMethod, privateKey)
   274  	case *jwt.SigningMethodRSA:
   275  		privateKey, ok := key.(*rsa.PrivateKey)
   276  		if !ok {
   277  			return nil, jwt.ErrInvalidKeyType
   278  		}
   279  		return newRSASingingKey(signingMethod, privateKey)
   280  	default:
   281  		secret, ok := key.([]byte)
   282  		if !ok {
   283  			return nil, jwt.ErrInvalidKeyType
   284  		}
   285  		return hmacSigningKey{signingMethod, secret}, nil
   286  	}
   287  }
   288  
   289  // DefaultSigningKey is the default signing key for JWTs.
   290  var DefaultSigningKey JWTSigningKey
   291  
   292  // InitSigningKey creates the default signing key from settings or creates a random key.
   293  func InitSigningKey() error {
   294  	var err error
   295  	var key any
   296  
   297  	switch setting.OAuth2.JWTSigningAlgorithm {
   298  	case "HS256":
   299  		fallthrough
   300  	case "HS384":
   301  		fallthrough
   302  	case "HS512":
   303  		key = setting.GetGeneralTokenSigningSecret()
   304  	case "RS256":
   305  		fallthrough
   306  	case "RS384":
   307  		fallthrough
   308  	case "RS512":
   309  		fallthrough
   310  	case "ES256":
   311  		fallthrough
   312  	case "ES384":
   313  		fallthrough
   314  	case "ES512":
   315  		fallthrough
   316  	case "EdDSA":
   317  		key, err = loadOrCreateAsymmetricKey()
   318  	default:
   319  		return ErrInvalidAlgorithmType{setting.OAuth2.JWTSigningAlgorithm}
   320  	}
   321  
   322  	if err != nil {
   323  		return fmt.Errorf("Error while loading or creating JWT key: %w", err)
   324  	}
   325  
   326  	signingKey, err := CreateJWTSigningKey(setting.OAuth2.JWTSigningAlgorithm, key)
   327  	if err != nil {
   328  		return err
   329  	}
   330  
   331  	DefaultSigningKey = signingKey
   332  
   333  	return nil
   334  }
   335  
   336  // loadOrCreateAsymmetricKey checks if the configured private key exists.
   337  // If it does not exist a new random key gets generated and saved on the configured path.
   338  func loadOrCreateAsymmetricKey() (any, error) {
   339  	keyPath := setting.OAuth2.JWTSigningPrivateKeyFile
   340  
   341  	isExist, err := util.IsExist(keyPath)
   342  	if err != nil {
   343  		log.Fatal("Unable to check if %s exists. Error: %v", keyPath, err)
   344  	}
   345  	if !isExist {
   346  		err := func() error {
   347  			key, err := func() (any, error) {
   348  				switch {
   349  				case strings.HasPrefix(setting.OAuth2.JWTSigningAlgorithm, "RS"):
   350  					return rsa.GenerateKey(rand.Reader, 4096)
   351  				case setting.OAuth2.JWTSigningAlgorithm == "EdDSA":
   352  					_, pk, err := ed25519.GenerateKey(rand.Reader)
   353  					return pk, err
   354  				default:
   355  					return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   356  				}
   357  			}()
   358  			if err != nil {
   359  				return err
   360  			}
   361  
   362  			bytes, err := x509.MarshalPKCS8PrivateKey(key)
   363  			if err != nil {
   364  				return err
   365  			}
   366  
   367  			privateKeyPEM := &pem.Block{Type: "PRIVATE KEY", Bytes: bytes}
   368  
   369  			if err := os.MkdirAll(filepath.Dir(keyPath), os.ModePerm); err != nil {
   370  				return err
   371  			}
   372  
   373  			f, err := os.OpenFile(keyPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600)
   374  			if err != nil {
   375  				return err
   376  			}
   377  			defer func() {
   378  				if err = f.Close(); err != nil {
   379  					log.Error("Close: %v", err)
   380  				}
   381  			}()
   382  
   383  			return pem.Encode(f, privateKeyPEM)
   384  		}()
   385  		if err != nil {
   386  			log.Fatal("Error generating private key: %v", err)
   387  			return nil, err
   388  		}
   389  	}
   390  
   391  	bytes, err := os.ReadFile(keyPath)
   392  	if err != nil {
   393  		return nil, err
   394  	}
   395  
   396  	block, _ := pem.Decode(bytes)
   397  	if block == nil {
   398  		return nil, fmt.Errorf("no valid PEM data found in %s", keyPath)
   399  	} else if block.Type != "PRIVATE KEY" {
   400  		return nil, fmt.Errorf("expected PRIVATE KEY, got %s in %s", block.Type, keyPath)
   401  	}
   402  
   403  	return x509.ParsePKCS8PrivateKey(block.Bytes)
   404  }