github.com/hyperledger/aries-framework-go@v0.3.2/pkg/doc/jwt/jwt.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package jwt
     8  
     9  import (
    10  	"bytes"
    11  	"encoding/base64"
    12  	"errors"
    13  	"fmt"
    14  	"reflect"
    15  	"strings"
    16  
    17  	"github.com/go-jose/go-jose/v3/json"
    18  	"github.com/go-jose/go-jose/v3/jwt"
    19  
    20  	"github.com/hyperledger/aries-framework-go/pkg/doc/jose"
    21  )
    22  
    23  const (
    24  	// TypeJWT defines JWT type.
    25  	TypeJWT = "JWT"
    26  
    27  	// AlgorithmNone used to indicate unsecured JWT.
    28  	AlgorithmNone = "none"
    29  )
    30  
// Claims defines JSON Web Token Claims (https://tools.ietf.org/html/rfc7519#section-4).
//
// It is a defined type whose underlying type is that of go-jose's jwt.Claims: it shares
// that type's structure but not its method set, so methods declared on jwt.Claims are
// only reachable after converting a Claims value back to jwt.Claims.
type Claims jwt.Claims
    33  
    34  // jwtParseOpts holds options for the JWT parsing.
    35  type parseOpts struct {
    36  	detachedPayload         []byte
    37  	sigVerifier             jose.SignatureVerifier
    38  	ignoreClaimsMapDecoding bool
    39  }
    40  
    41  // ParseOpt is the JWT Parser option.
    42  type ParseOpt func(opts *parseOpts)
    43  
    44  // WithJWTDetachedPayload option is for definition of JWT detached payload.
    45  func WithJWTDetachedPayload(payload []byte) ParseOpt {
    46  	return func(opts *parseOpts) {
    47  		opts.detachedPayload = payload
    48  	}
    49  }
    50  
    51  // WithIgnoreClaimsMapDecoding option is for ignore decoding claims into .Payload map[string]interface.
    52  // Decoding to map[string]interface is pretty expensive, so this option can be used for performance critical operations.
    53  func WithIgnoreClaimsMapDecoding(ignoreClaimsMapDecoding bool) ParseOpt {
    54  	return func(opts *parseOpts) {
    55  		opts.ignoreClaimsMapDecoding = ignoreClaimsMapDecoding
    56  	}
    57  }
    58  
    59  // WithSignatureVerifier option is for definition of JWT detached payload.
    60  func WithSignatureVerifier(signatureVerifier jose.SignatureVerifier) ParseOpt {
    61  	return func(opts *parseOpts) {
    62  		opts.sigVerifier = signatureVerifier
    63  	}
    64  }
    65  
    66  type signatureVerifierFunc func(joseHeaders jose.Headers, payload, signingInput, signature []byte) error
    67  
    68  func (v signatureVerifierFunc) Verify(joseHeaders jose.Headers, payload, signingInput, signature []byte) error {
    69  	return v(joseHeaders, payload, signingInput, signature)
    70  }
    71  
    72  func verifyUnsecuredJWT(joseHeaders jose.Headers, _, _, signature []byte) error {
    73  	alg, ok := joseHeaders.Algorithm()
    74  	if !ok {
    75  		return errors.New("alg is not defined")
    76  	}
    77  
    78  	if alg != AlgorithmNone {
    79  		return errors.New("alg value is not 'none'")
    80  	}
    81  
    82  	if len(signature) > 0 {
    83  		return errors.New("not empty signature")
    84  	}
    85  
    86  	return nil
    87  }
    88  
    89  // UnsecuredJWTVerifier provides verifier for unsecured JWT.
    90  func UnsecuredJWTVerifier() jose.SignatureVerifier {
    91  	return signatureVerifierFunc(verifyUnsecuredJWT)
    92  }
    93  
    94  type unsecuredJWTSigner struct{}
    95  
    96  func (s unsecuredJWTSigner) Sign(_ []byte) ([]byte, error) {
    97  	return []byte(""), nil
    98  }
    99  
   100  func (s unsecuredJWTSigner) Headers() jose.Headers {
   101  	return map[string]interface{}{
   102  		jose.HeaderAlgorithm: AlgorithmNone,
   103  	}
   104  }
   105  
// JSONWebToken defines JSON Web Token (https://tools.ietf.org/html/rfc7519).
type JSONWebToken struct {
	// Headers are the token's JOSE headers. For tokens produced by Parse, NewSigned and
	// NewUnsecured these are the protected headers of the underlying JWS.
	Headers jose.Headers

	// Payload holds the claims decoded into a map. Parse leaves it unpopulated when
	// called with WithIgnoreClaimsMapDecoding(true).
	Payload map[string]interface{}

	// jws is the underlying JWS; Serialize returns an error when it is nil.
	jws *jose.JSONWebSignature
}
   114  
   115  // Parse parses input JWT in serialized form into JSON Web Token.
   116  // Currently JWS and unsecured JWT is supported.
   117  func Parse(jwtSerialized string, opts ...ParseOpt) (*JSONWebToken, []byte, error) {
   118  	if !jose.IsCompactJWS(jwtSerialized) {
   119  		return nil, nil, errors.New("JWT of compacted JWS form is supported only")
   120  	}
   121  
   122  	pOpts := &parseOpts{}
   123  
   124  	for _, opt := range opts {
   125  		opt(pOpts)
   126  	}
   127  
   128  	return parseJWS(jwtSerialized, pOpts)
   129  }
   130  
   131  // DecodeClaims fills input c with claims of a token.
   132  func (j *JSONWebToken) DecodeClaims(c interface{}) error {
   133  	pBytes, err := json.Marshal(j.Payload)
   134  	if err != nil {
   135  		return err
   136  	}
   137  
   138  	return json.Unmarshal(pBytes, c)
   139  }
   140  
   141  // LookupStringHeader makes look up of particular header with string value.
   142  func (j *JSONWebToken) LookupStringHeader(name string) string {
   143  	if headerValue, ok := j.Headers[name]; ok {
   144  		if headerStrValue, ok := headerValue.(string); ok {
   145  			return headerStrValue
   146  		}
   147  	}
   148  
   149  	return ""
   150  }
   151  
   152  // Serialize makes (compact) serialization of token.
   153  func (j *JSONWebToken) Serialize(detached bool) (string, error) {
   154  	if j.jws == nil {
   155  		return "", errors.New("JWS serialization is supported only")
   156  	}
   157  
   158  	return j.jws.SerializeCompact(detached)
   159  }
   160  
   161  func parseJWS(jwtSerialized string, opts *parseOpts) (*JSONWebToken, []byte, error) {
   162  	jwsOpts := make([]jose.JWSParseOpt, 0)
   163  
   164  	if opts.detachedPayload != nil {
   165  		jwsOpts = append(jwsOpts, jose.WithJWSDetachedPayload(opts.detachedPayload))
   166  	}
   167  
   168  	jws, err := jose.ParseJWS(jwtSerialized, opts.sigVerifier, jwsOpts...)
   169  	if err != nil {
   170  		return nil, nil, fmt.Errorf("parse JWT from compact JWS: %w", err)
   171  	}
   172  
   173  	return mapJWSToJWT(jws, opts)
   174  }
   175  
   176  func mapJWSToJWT(jws *jose.JSONWebSignature, opts *parseOpts) (*JSONWebToken, []byte, error) {
   177  	headers := jws.ProtectedHeaders
   178  
   179  	err := checkHeaders(headers)
   180  	if err != nil {
   181  		return nil, nil, fmt.Errorf("check JWT headers: %w", err)
   182  	}
   183  
   184  	token := &JSONWebToken{
   185  		Headers: headers,
   186  		jws:     jws,
   187  	}
   188  
   189  	if !opts.ignoreClaimsMapDecoding {
   190  		claims, err := PayloadToMap(jws.Payload)
   191  		if err != nil {
   192  			return nil, nil, fmt.Errorf("read JWT claims from JWS payload: %w", err)
   193  		}
   194  
   195  		token.Payload = claims
   196  	}
   197  
   198  	return token, jws.Payload, nil
   199  }
   200  
   201  // NewSigned creates new signed JSON Web Token based on input claims.
   202  func NewSigned(claims interface{}, headers jose.Headers, signer jose.Signer) (*JSONWebToken, error) {
   203  	return newSigned(claims, headers, signer)
   204  }
   205  
   206  // NewUnsecured creates new unsecured JSON Web Token based on input claims.
   207  func NewUnsecured(claims interface{}, headers jose.Headers) (*JSONWebToken, error) {
   208  	return newSigned(claims, headers, &unsecuredJWTSigner{})
   209  }
   210  
   211  func newSigned(claims interface{}, headers jose.Headers, signer jose.Signer) (*JSONWebToken, error) {
   212  	payloadMap, err := PayloadToMap(claims)
   213  	if err != nil {
   214  		return nil, fmt.Errorf("unmarshallable claims: %w", err)
   215  	}
   216  
   217  	payloadBytes, err := json.Marshal(payloadMap)
   218  	if err != nil {
   219  		return nil, fmt.Errorf("marshal JWT claims: %w", err)
   220  	}
   221  
   222  	// JWS compact serialization uses only protected headers (https://tools.ietf.org/html/rfc7515#section-3.1).
   223  	jws, err := jose.NewJWS(headers, nil, payloadBytes, signer)
   224  	if err != nil {
   225  		return nil, fmt.Errorf("create JWS: %w", err)
   226  	}
   227  
   228  	return &JSONWebToken{
   229  		Headers: jws.ProtectedHeaders,
   230  		Payload: payloadMap,
   231  		jws:     jws,
   232  	}, nil
   233  }
   234  
   235  // IsJWS checks if JWT is a JWS of valid structure.
   236  func IsJWS(s string) bool {
   237  	parts := strings.Split(s, ".")
   238  
   239  	return len(parts) == 3 &&
   240  		isValidJSON(parts[0]) &&
   241  		isValidJSON(parts[1]) &&
   242  		parts[2] != ""
   243  }
   244  
   245  // IsJWTUnsecured checks if JWT is an unsecured JWT of valid structure.
   246  func IsJWTUnsecured(s string) bool {
   247  	parts := strings.Split(s, ".")
   248  
   249  	return len(parts) == 3 &&
   250  		isValidJSON(parts[0]) &&
   251  		isValidJSON(parts[1]) &&
   252  		parts[2] == ""
   253  }
   254  
   255  func isValidJSON(s string) bool {
   256  	b, err := base64.RawURLEncoding.DecodeString(s)
   257  	if err != nil {
   258  		return false
   259  	}
   260  
   261  	var j map[string]interface{}
   262  	err = json.Unmarshal(b, &j)
   263  
   264  	return err == nil
   265  }
   266  
   267  func checkHeaders(headers map[string]interface{}) error {
   268  	if _, ok := headers[jose.HeaderAlgorithm]; !ok {
   269  		return errors.New("alg header is not defined")
   270  	}
   271  
   272  	typ, ok := headers[jose.HeaderType]
   273  	if ok && typ != TypeJWT {
   274  		return errors.New("typ is not JWT")
   275  	}
   276  
   277  	cty, ok := headers[jose.HeaderContentType]
   278  	if ok && cty == TypeJWT { // https://tools.ietf.org/html/rfc7519#section-5.2
   279  		return errors.New("nested JWT is not supported")
   280  	}
   281  
   282  	return nil
   283  }
   284  
   285  // PayloadToMap transforms interface to map.
   286  func PayloadToMap(i interface{}) (map[string]interface{}, error) {
   287  	if reflect.ValueOf(i).Kind() == reflect.Map {
   288  		return i.(map[string]interface{}), nil
   289  	}
   290  
   291  	var (
   292  		b   []byte
   293  		err error
   294  	)
   295  
   296  	switch cv := i.(type) {
   297  	case []byte:
   298  		b = cv
   299  	case string:
   300  		b = []byte(cv)
   301  	default:
   302  		b, err = json.Marshal(i)
   303  		if err != nil {
   304  			return nil, fmt.Errorf("marshal interface[%T]: %w", i, err)
   305  		}
   306  	}
   307  
   308  	var m map[string]interface{}
   309  
   310  	d := json.NewDecoder(bytes.NewReader(b))
   311  	d.UseNumber()
   312  
   313  	if err := d.Decode(&m); err != nil {
   314  		return nil, fmt.Errorf("convert to map: %w", err)
   315  	}
   316  
   317  	return m, nil
   318  }