github.com/avenga/couper@v1.12.2/eval/lib/jwt.go (about)

     1  package lib
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/x509"
     6  	"encoding/json"
     7  	"encoding/pem"
     8  	"fmt"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/golang-jwt/jwt/v4"
    13  	"github.com/hashicorp/hcl/v2"
    14  	"github.com/zclconf/go-cty/cty"
    15  	"github.com/zclconf/go-cty/cty/function"
    16  	"github.com/zclconf/go-cty/cty/function/stdlib"
    17  
    18  	acjwt "github.com/avenga/couper/accesscontrol/jwt"
    19  	"github.com/avenga/couper/config"
    20  	"github.com/avenga/couper/config/reader"
    21  	"github.com/avenga/couper/internal/seetie"
    22  )
    23  
    24  const FnJWTSign = "jwt_sign"
    25  
    26  type JWTSigningConfig struct {
    27  	Claims             config.Claims
    28  	Headers            hcl.Expression
    29  	Key                interface{}
    30  	SignatureAlgorithm string
    31  	TTL                int64
    32  }
    33  
    34  func checkData(ttl, signatureAlgorithm string) (int64, acjwt.Algorithm, error) {
    35  	alg := acjwt.NewAlgorithm(signatureAlgorithm)
    36  	if alg == acjwt.AlgorithmUnknown {
    37  		return 0, alg, fmt.Errorf("algorithm is not supported")
    38  	}
    39  
    40  	if ttl != "0" {
    41  		dur, err := time.ParseDuration(ttl)
    42  		if err != nil {
    43  			return 0, alg, err
    44  		}
    45  		return int64(dur.Seconds()), alg, nil
    46  	}
    47  
    48  	return 0, alg, nil
    49  }
    50  
    51  func getKey(keyBytes []byte, signatureAlgorithm string) (interface{}, error) {
    52  	var (
    53  		key      interface{}
    54  		parseErr error
    55  	)
    56  	key = keyBytes
    57  	if strings.HasPrefix(signatureAlgorithm, "RS") {
    58  		key, parseErr = jwt.ParseRSAPrivateKeyFromPEM(keyBytes)
    59  	} else if strings.HasPrefix(signatureAlgorithm, "ES") {
    60  		key, parseErr = parseECPrivateKeyFromPEM(keyBytes)
    61  	}
    62  
    63  	return key, parseErr
    64  }
    65  
    66  func NewJWTSigningConfigFromJWTSigningProfile(j *config.JWTSigningProfile, algCheckFunc func(alg acjwt.Algorithm) error) (*JWTSigningConfig, error) {
    67  	ttl, alg, err := checkData(j.TTL, j.SignatureAlgorithm)
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  
    72  	if algCheckFunc != nil {
    73  		if err = algCheckFunc(alg); err != nil {
    74  			return nil, err
    75  		}
    76  	}
    77  
    78  	keyBytes, err := reader.ReadFromAttrFile("jwt_signing_profile key", j.Key, j.KeyFile)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	key, err := getKey(keyBytes, j.SignatureAlgorithm)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	c := &JWTSigningConfig{
    89  		Claims:             j.Claims,
    90  		Headers:            j.Headers,
    91  		Key:                key,
    92  		SignatureAlgorithm: j.SignatureAlgorithm,
    93  		TTL:                ttl,
    94  	}
    95  	return c, nil
    96  }
    97  
    98  func NewJWTSigningConfigFromJWT(j *config.JWT) (*JWTSigningConfig, error) {
    99  	if j.SigningTTL == "" {
   100  		return nil, nil
   101  	}
   102  
   103  	ttl, alg, err := checkData(j.SigningTTL, j.SignatureAlgorithm)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	var signingKey, signingKeyFile string
   109  
   110  	if alg.IsHMAC() {
   111  		signingKey = j.Key
   112  		signingKeyFile = j.KeyFile
   113  	} else {
   114  		signingKey = j.SigningKey
   115  		signingKeyFile = j.SigningKeyFile
   116  	}
   117  	keyBytes, err := reader.ReadFromAttrFile("jwt signing key", signingKey, signingKeyFile)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	key, err := getKey(keyBytes, j.SignatureAlgorithm)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  
   127  	c := &JWTSigningConfig{
   128  		Claims:             j.Claims,
   129  		Key:                key,
   130  		SignatureAlgorithm: j.SignatureAlgorithm,
   131  		TTL:                ttl,
   132  	}
   133  	return c, nil
   134  }
   135  
   136  var NoOpJwtSignFunction = function.New(&function.Spec{
   137  	Params: []function.Parameter{
   138  		{
   139  			Name: "jwt_signing_profile_label",
   140  			Type: cty.String,
   141  		},
   142  		{
   143  			Name: "claims",
   144  			Type: cty.DynamicPseudoType,
   145  		},
   146  	},
   147  	Type: function.StaticReturnType(cty.String),
   148  	Impl: func(args []cty.Value, _ cty.Type) (ret cty.Value, err error) {
   149  		if len(args) > 0 {
   150  			return cty.StringVal(""), fmt.Errorf("missing jwt_signing_profile or jwt (with signing_ttl) block with referenced label %q", args[0].AsString())
   151  		}
   152  		return cty.StringVal(""), fmt.Errorf("missing jwt_signing_profile or jwt (with signing_ttl) definitions")
   153  	},
   154  })
   155  
   156  func NewJwtSignFunction(ctx *hcl.EvalContext, jwtSigningConfigs map[string]*JWTSigningConfig,
   157  	evalFn func(*hcl.EvalContext, hcl.Expression) (cty.Value, error)) function.Function {
   158  	return function.New(&function.Spec{
   159  		Params: []function.Parameter{
   160  			{
   161  				Name: "jwt_signing_profile_label",
   162  				Type: cty.String,
   163  			},
   164  			{
   165  				Name: "claims",
   166  				Type: cty.DynamicPseudoType,
   167  			},
   168  		},
   169  		Type: function.StaticReturnType(cty.String),
   170  		Impl: func(args []cty.Value, _ cty.Type) (ret cty.Value, err error) {
   171  			label := args[0].AsString()
   172  			signingConfig, exist := jwtSigningConfigs[label]
   173  			if !exist {
   174  				return NoOpJwtSignFunction.Call(args)
   175  			}
   176  
   177  			var claims, argumentClaims, headers map[string]interface{}
   178  
   179  			if signingConfig.Headers != nil {
   180  				h, diags := evalFn(ctx, signingConfig.Headers)
   181  				if diags != nil {
   182  					return cty.StringVal(""), diags
   183  				}
   184  				headers = seetie.ValueToMap(h)
   185  			}
   186  
   187  			// get claims from signing profile
   188  			if signingConfig.Claims != nil {
   189  				v, diags := evalFn(ctx, signingConfig.Claims)
   190  				if diags != nil {
   191  					return cty.StringVal(""), err
   192  				}
   193  				claims = seetie.ValueToMap(v)
   194  			} else {
   195  				claims = make(map[string]interface{})
   196  			}
   197  
   198  			if signingConfig.TTL != 0 {
   199  				claims["exp"] = time.Now().Unix() + signingConfig.TTL
   200  			}
   201  
   202  			// get claims from function argument
   203  			jsonClaims, err := stdlib.JSONEncode(args[1])
   204  			if err != nil {
   205  				return cty.StringVal(""), err
   206  			}
   207  
   208  			err = json.Unmarshal([]byte(jsonClaims.AsString()), &argumentClaims)
   209  			if err != nil {
   210  				return cty.StringVal(""), err
   211  			}
   212  
   213  			for k, v := range argumentClaims {
   214  				claims[k] = v
   215  			}
   216  
   217  			tokenString, err := CreateJWT(signingConfig.SignatureAlgorithm, signingConfig.Key, claims, headers)
   218  			if err != nil {
   219  				return cty.StringVal(""), err
   220  			}
   221  
   222  			return cty.StringVal(tokenString), nil
   223  		},
   224  	})
   225  }
   226  
   227  func CreateJWT(signatureAlgorithm string, key interface{}, mapClaims jwt.MapClaims, headers map[string]interface{}) (string, error) {
   228  	signingMethod := jwt.GetSigningMethod(signatureAlgorithm)
   229  	if signingMethod == nil {
   230  		return "", fmt.Errorf("no signing method for given algorithm: %s", signatureAlgorithm)
   231  	}
   232  
   233  	if headers == nil {
   234  		headers = map[string]interface{}{}
   235  	}
   236  
   237  	if _, set := headers["typ"]; !set {
   238  		headers["typ"] = "JWT"
   239  	}
   240  	headers["alg"] = signingMethod.Alg()
   241  
   242  	// create token
   243  	token := &jwt.Token{Header: headers, Claims: mapClaims, Method: signingMethod}
   244  
   245  	// sign token
   246  	return token.SignedString(key)
   247  }
   248  
   249  func parseECPrivateKeyFromPEM(key []byte) (*ecdsa.PrivateKey, error) {
   250  	var err error
   251  
   252  	// Parse PEM block
   253  	var block *pem.Block
   254  	if block, _ = pem.Decode(key); block == nil {
   255  		return nil, jwt.ErrKeyMustBePEMEncoded
   256  	}
   257  
   258  	// Parse the key
   259  	var parsedKey interface{}
   260  	if parsedKey, err = x509.ParseECPrivateKey(block.Bytes); err != nil {
   261  		if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
   262  			return nil, err
   263  		}
   264  	}
   265  
   266  	var pkey *ecdsa.PrivateKey
   267  	var ok bool
   268  	if pkey, ok = parsedKey.(*ecdsa.PrivateKey); !ok {
   269  		return nil, jwt.ErrNotECPrivateKey
   270  	}
   271  
   272  	return pkey, nil
   273  }