go-micro.dev/v5@v5.12.0/cmd/micro/server/util_jwt.go (about)

     1  package server
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/x509"
     7  	"encoding/pem"
     8  	"errors"
     9  	"os"
    10  	"time"
    11  
    12  	"github.com/golang-jwt/jwt/v5"
    13  )
    14  
// RSA key pair shared by the JWT helpers in this file. Both stay nil until
// InitJWTKeys assigns them; GenerateJWT signs tokens with jwtPrivateKey and
// ParseJWT hands jwtPublicKey to the parser for signature verification.
var (
	jwtPrivateKey *rsa.PrivateKey
	jwtPublicKey  *rsa.PublicKey
)
    19  
    20  // Load or generate RSA keys for JWT
    21  func InitJWTKeys(privPath, pubPath string) error {
    22  	var err error
    23  	if _, err = os.Stat(privPath); os.IsNotExist(err) {
    24  		priv, _ := rsa.GenerateKey(rand.Reader, 2048)
    25  		privBytes := x509.MarshalPKCS1PrivateKey(priv)
    26  		privPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: privBytes})
    27  		os.WriteFile(privPath, privPem, 0600)
    28  		pubBytes, _ := x509.MarshalPKIXPublicKey(&priv.PublicKey)
    29  		pubPem := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes})
    30  		os.WriteFile(pubPath, pubPem, 0644)
    31  	}
    32  	privPem, err := os.ReadFile(privPath)
    33  	if err != nil {
    34  		return err
    35  	}
    36  	block, _ := pem.Decode(privPem)
    37  	if block == nil {
    38  		return errors.New("invalid private key PEM")
    39  	}
    40  	jwtPrivateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
    41  	if err != nil {
    42  		return err
    43  	}
    44  	pubPem, err := os.ReadFile(pubPath)
    45  	if err != nil {
    46  		return err
    47  	}
    48  	block, _ = pem.Decode(pubPem)
    49  	if block == nil {
    50  		return errors.New("invalid public key PEM")
    51  	}
    52  	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
    53  	if err != nil {
    54  		return err
    55  	}
    56  	var ok bool
    57  	jwtPublicKey, ok = pub.(*rsa.PublicKey)
    58  	if !ok {
    59  		return errors.New("not RSA public key")
    60  	}
    61  	return nil
    62  }
    63  
    64  // Generate a JWT for a user
    65  func GenerateJWT(userID, userType string, scopes []string, expiry time.Duration) (string, error) {
    66  	claims := jwt.MapClaims{
    67  		"sub":    userID,
    68  		"type":   userType,
    69  		"scopes": scopes,
    70  		"exp":    time.Now().Add(expiry).Unix(),
    71  	}
    72  	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
    73  	return token.SignedString(jwtPrivateKey)
    74  }
    75  
    76  // Parse and validate a JWT, returns claims if valid
    77  func ParseJWT(tokenStr string) (jwt.MapClaims, error) {
    78  	token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
    79  		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
    80  			return nil, errors.New("unexpected signing method")
    81  		}
    82  		return jwtPublicKey, nil
    83  	})
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  	if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
    88  		return claims, nil
    89  	}
    90  	return nil, errors.New("invalid token")
    91  }