github.com/hyperledger/aries-framework-go@v0.3.2/pkg/doc/jwt/jwt_test.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package jwt
     8  
     9  import (
    10  	"crypto/ed25519"
    11  	"crypto/rand"
    12  	"crypto/rsa"
    13  	"encoding/base64"
    14  	"fmt"
    15  	"strings"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/go-jose/go-jose/v3/json"
    20  	"github.com/go-jose/go-jose/v3/jwt"
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/require"
    23  
    24  	"github.com/hyperledger/aries-framework-go/pkg/doc/jose"
    25  )
    26  
// CustomClaim is the claims set used by these tests: the standard JWT claims
// (embedded via *Claims) plus one private claim, so that round-trip tests can
// check that non-standard claims are preserved alongside the registered ones.
type CustomClaim struct {
	*Claims

	// PrivateClaim1 is serialized under the key "privateClaim1" and omitted
	// when empty.
	PrivateClaim1 string `json:"privateClaim1,omitempty"`
}
    32  
    33  func TestNewSigned(t *testing.T) {
    34  	claims := createClaims()
    35  
    36  	t.Run("Create JWS signed by EdDSA", func(t *testing.T) {
    37  		r := require.New(t)
    38  
    39  		pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
    40  		r.NoError(err)
    41  
    42  		token, err := NewSigned(claims, nil, NewEd25519Signer(privKey))
    43  		r.NoError(err)
    44  		jws, err := token.Serialize(false)
    45  		require.NoError(t, err)
    46  
    47  		var parsedClaims CustomClaim
    48  		err = verifyEd25519ViaGoJose(jws, pubKey, &parsedClaims)
    49  		r.NoError(err)
    50  		r.Equal(*claims, parsedClaims)
    51  
    52  		err = verifyEd25519(jws, pubKey)
    53  		r.NoError(err)
    54  	})
    55  
    56  	t.Run("Create JWS signed by RS256", func(t *testing.T) {
    57  		r := require.New(t)
    58  
    59  		privKey, err := rsa.GenerateKey(rand.Reader, 2048)
    60  		r.NoError(err)
    61  
    62  		pubKey := &privKey.PublicKey
    63  
    64  		token, err := NewSigned(claims, nil, NewRS256Signer(privKey, nil))
    65  		r.NoError(err)
    66  		jws, err := token.Serialize(false)
    67  		require.NoError(t, err)
    68  
    69  		var parsedClaims CustomClaim
    70  		err = verifyRS256ViaGoJose(jws, pubKey, &parsedClaims)
    71  		r.NoError(err)
    72  		r.Equal(*claims, parsedClaims)
    73  
    74  		err = verifyRS256(jws, pubKey)
    75  		r.NoError(err)
    76  	})
    77  }
    78  
    79  func TestNewUnsecured(t *testing.T) {
    80  	claims := createClaims()
    81  
    82  	t.Run("Create unsecured JWT", func(t *testing.T) {
    83  		r := require.New(t)
    84  
    85  		token, err := NewUnsecured(claims, map[string]interface{}{"custom": "ok"})
    86  		r.NoError(err)
    87  		jwtUnsecured, err := token.Serialize(false)
    88  		r.NoError(err)
    89  		r.NotEmpty(jwtUnsecured)
    90  
    91  		parsedJWT, _, err := Parse(jwtUnsecured, WithSignatureVerifier(UnsecuredJWTVerifier()))
    92  		r.NoError(err)
    93  		r.NotNil(parsedJWT)
    94  
    95  		var parsedClaims CustomClaim
    96  		err = parsedJWT.DecodeClaims(&parsedClaims)
    97  		r.NoError(err)
    98  		r.Equal(*claims, parsedClaims)
    99  	})
   100  
   101  	t.Run("Invalid claims", func(t *testing.T) {
   102  		token, err := NewUnsecured("not JSON claims", nil)
   103  		require.Error(t, err)
   104  		require.Nil(t, token)
   105  		require.Contains(t, err.Error(), "unmarshallable claims")
   106  
   107  		token, err = NewUnsecured(getUnmarshallableMap(), nil)
   108  		require.Error(t, err)
   109  		require.Nil(t, token)
   110  		require.Contains(t, err.Error(), "marshal JWT claims")
   111  
   112  		token, err = NewUnsecured(claims, getUnmarshallableMap())
   113  		require.Error(t, err)
   114  		require.Nil(t, token)
   115  		require.Contains(t, err.Error(), "create JWS")
   116  	})
   117  }
   118  
   119  func TestWithJWTDetachedPayload(t *testing.T) {
   120  	detachedPayloadOpt := WithJWTDetachedPayload([]byte("payload"))
   121  	require.NotNil(t, detachedPayloadOpt)
   122  
   123  	opts := &parseOpts{}
   124  	detachedPayloadOpt(opts)
   125  	require.Equal(t, []byte("payload"), opts.detachedPayload)
   126  }
   127  
   128  func TestParse(t *testing.T) {
   129  	r := require.New(t)
   130  
   131  	pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
   132  	r.NoError(err)
   133  
   134  	signer := NewEd25519Signer(privKey)
   135  	claims := map[string]interface{}{"iss": "Albert"}
   136  
   137  	token, err := NewSigned(claims, nil, signer)
   138  	r.NoError(err)
   139  	jws, err := token.Serialize(false)
   140  	r.NoError(err)
   141  
   142  	verifier, err := NewEd25519Verifier(pubKey)
   143  	r.NoError(err)
   144  
   145  	jsonWebToken, _, err := Parse(jws, WithSignatureVerifier(verifier))
   146  	r.NoError(err)
   147  
   148  	var parsedClaims map[string]interface{}
   149  	err = jsonWebToken.DecodeClaims(&parsedClaims)
   150  	r.NoError(err)
   151  
   152  	r.Equal(claims, parsedClaims)
   153  
   154  	// parse without .Payload data
   155  	jsonWebToken, _, err = Parse(jws, WithSignatureVerifier(verifier), WithIgnoreClaimsMapDecoding(true))
   156  	r.NoError(err)
   157  	assert.Nil(t, jsonWebToken.Payload)
   158  
   159  	// parse detached JWT
   160  	jwsParts := strings.Split(jws, ".")
   161  	jwsDetached := fmt.Sprintf("%s..%s", jwsParts[0], jwsParts[2])
   162  
   163  	jwsPayload, err := base64.RawURLEncoding.DecodeString(jwsParts[1])
   164  	require.NoError(t, err)
   165  
   166  	jsonWebToken, _, err = Parse(jwsDetached,
   167  		WithSignatureVerifier(verifier), WithJWTDetachedPayload(jwsPayload))
   168  	r.NoError(err)
   169  	r.NotNil(r, jsonWebToken)
   170  
   171  	// claims is not JSON
   172  	jws, err = buildJWS(signer, "not JSON")
   173  	r.NoError(err)
   174  	token, _, err = Parse(jws, WithSignatureVerifier(verifier))
   175  	r.Error(err)
   176  	r.Contains(err.Error(), "read JWT claims from JWS payload")
   177  	r.Nil(token)
   178  
   179  	// type is not JWT
   180  	signer.headers = map[string]interface{}{"alg": "EdDSA", "typ": "JWM"}
   181  	jws, err = buildJWS(signer, map[string]interface{}{"iss": "Albert"})
   182  	r.NoError(err)
   183  	token, _, err = Parse(jws, WithSignatureVerifier(verifier))
   184  	r.Error(err)
   185  	r.Contains(err.Error(), "typ is not JWT")
   186  	r.Nil(token)
   187  
   188  	// content type is not empty (equals to JWT)
   189  	signer.headers = map[string]interface{}{"alg": "EdDSA", "typ": "JWT", "cty": "JWT"}
   190  	jws, err = buildJWS(signer, map[string]interface{}{"iss": "Albert"})
   191  	r.NoError(err)
   192  	token, _, err = Parse(jws, WithSignatureVerifier(verifier))
   193  	r.Error(err)
   194  	r.Contains(err.Error(), "nested JWT is not supported")
   195  	r.Nil(token)
   196  
   197  	// handle compact JWS of invalid form
   198  	token, _, err = Parse("invalid.compact.JWS")
   199  	r.Error(err)
   200  	r.Contains(err.Error(), "parse JWT from compact JWS")
   201  	r.Nil(token)
   202  
   203  	// pass not compact JWS
   204  	token, _, err = Parse("invalid jws")
   205  	r.Error(err)
   206  	r.EqualError(err, "JWT of compacted JWS form is supported only")
   207  	r.Nil(token)
   208  }
   209  
   210  func buildJWS(signer jose.Signer, claims interface{}) (string, error) {
   211  	claimsBytes, err := json.Marshal(claims)
   212  	if err != nil {
   213  		return "", err
   214  	}
   215  
   216  	jws, err := jose.NewJWS(nil, nil, claimsBytes, signer)
   217  	if err != nil {
   218  		return "", err
   219  	}
   220  
   221  	return jws.SerializeCompact(false)
   222  }
   223  
   224  func TestJSONWebToken_DecodeClaims(t *testing.T) {
   225  	token, err := getValidJSONWebToken()
   226  	require.NoError(t, err)
   227  
   228  	var tokensMap map[string]interface{}
   229  
   230  	err = token.DecodeClaims(&tokensMap)
   231  	require.NoError(t, err)
   232  	require.Equal(t, map[string]interface{}{"iss": "Albert"}, tokensMap)
   233  
   234  	var claims Claims
   235  
   236  	err = token.DecodeClaims(&claims)
   237  	require.NoError(t, err)
   238  	require.Equal(t, Claims{Issuer: "Albert"}, claims)
   239  
   240  	token, err = getJSONWebTokenWithInvalidPayload()
   241  	require.NoError(t, err)
   242  
   243  	err = token.DecodeClaims(&claims)
   244  	require.Error(t, err)
   245  }
   246  
   247  func TestJSONWebToken_LookupStringHeader(t *testing.T) {
   248  	token, err := getValidJSONWebToken()
   249  	require.NoError(t, err)
   250  
   251  	require.Equal(t, "JWT", token.LookupStringHeader("typ"))
   252  
   253  	require.Empty(t, token.LookupStringHeader("undef"))
   254  
   255  	token.Headers["not_str"] = 55
   256  	require.Empty(t, token.LookupStringHeader("not_str"))
   257  }
   258  
   259  func TestJSONWebToken_Serialize(t *testing.T) {
   260  	token, err := getValidJSONWebToken()
   261  	require.NoError(t, err)
   262  
   263  	tokenSerialized, err := token.Serialize(false)
   264  	require.NoError(t, err)
   265  	require.NotEmpty(t, tokenSerialized)
   266  
   267  	// cannot serialize without signature
   268  	token.jws = nil
   269  	tokenSerialized, err = token.Serialize(false)
   270  	require.Error(t, err)
   271  	require.EqualError(t, err, "JWS serialization is supported only")
   272  	require.Empty(t, tokenSerialized)
   273  }
   274  
   275  func TestUnsecuredJWTVerifier(t *testing.T) {
   276  	verifier := UnsecuredJWTVerifier()
   277  
   278  	err := verifier.Verify(map[string]interface{}{"alg": "none"}, nil, nil, nil)
   279  	require.NoError(t, err)
   280  
   281  	err = verifier.Verify(map[string]interface{}{}, nil, nil, nil)
   282  	require.Error(t, err)
   283  	require.EqualError(t, err, "alg is not defined")
   284  
   285  	err = verifier.Verify(map[string]interface{}{"alg": "EdDSA"}, nil, nil, nil)
   286  	require.Error(t, err)
   287  	require.EqualError(t, err, "alg value is not 'none'")
   288  
   289  	err = verifier.Verify(map[string]interface{}{"alg": "none"}, nil, nil, []byte("unexpected signature"))
   290  	require.Error(t, err)
   291  	require.EqualError(t, err, "not empty signature")
   292  }
   293  
   294  func Test_IsJWS(t *testing.T) {
   295  	b64 := base64.RawURLEncoding.EncodeToString([]byte("not json"))
   296  	j, err := json.Marshal(map[string]string{"alg": "none"})
   297  	require.NoError(t, err)
   298  
   299  	jb64 := base64.RawURLEncoding.EncodeToString(j)
   300  
   301  	type args struct {
   302  		data string
   303  	}
   304  
   305  	tests := []struct {
   306  		name string
   307  		args args
   308  		want bool
   309  	}{
   310  		{
   311  			name: "two parts only",
   312  			args: args{"two parts.only"},
   313  			want: false,
   314  		},
   315  		{
   316  			name: "empty third part",
   317  			args: args{"empty third.part."},
   318  			want: false,
   319  		},
   320  		{
   321  			name: "part 1 is not base64 decoded",
   322  			args: args{"not base64.part2.part3"},
   323  			want: false,
   324  		},
   325  		{
   326  			name: "part 1 is not JSON",
   327  			args: args{fmt.Sprintf("%s.part2.part3", b64)},
   328  			want: false,
   329  		},
   330  		{
   331  			name: "part 2 is not base64 decoded",
   332  			args: args{fmt.Sprintf("%s.not base64.part3", jb64)},
   333  			want: false,
   334  		},
   335  		{
   336  			name: "part 2 is not JSON",
   337  			args: args{fmt.Sprintf("%s.%s.part3", jb64, b64)},
   338  			want: false,
   339  		},
   340  		{
   341  			name: "is JWS",
   342  			args: args{fmt.Sprintf("%s.%s.signature", jb64, jb64)},
   343  			want: true,
   344  		},
   345  	}
   346  
   347  	for i := range tests {
   348  		tt := tests[i]
   349  		t.Run(tt.name, func(t *testing.T) {
   350  			if got := IsJWS(tt.args.data); got != tt.want {
   351  				t.Errorf("isJWS() = %v, want %v", got, tt.want)
   352  			}
   353  		})
   354  	}
   355  }
   356  
   357  func Test_IsJWTUnsecured(t *testing.T) {
   358  	b64 := base64.RawURLEncoding.EncodeToString([]byte("not json"))
   359  	j, err := json.Marshal(map[string]string{"alg": "none"})
   360  	require.NoError(t, err)
   361  
   362  	jb64 := base64.RawURLEncoding.EncodeToString(j)
   363  
   364  	type args struct {
   365  		data string
   366  	}
   367  
   368  	tests := []struct {
   369  		name string
   370  		args args
   371  		want bool
   372  	}{
   373  		{
   374  			name: "two parts only",
   375  			args: args{"two parts.only"},
   376  			want: false,
   377  		},
   378  		{
   379  			name: "not empty third part",
   380  			args: args{"third.part.not-empty"},
   381  			want: false,
   382  		},
   383  		{
   384  			name: "part 1 is not base64 decoded",
   385  			args: args{"not base64.part2.part3"},
   386  			want: false,
   387  		},
   388  		{
   389  			name: "part 1 is not JSON",
   390  			args: args{fmt.Sprintf("%s.part2.part3", b64)},
   391  			want: false,
   392  		},
   393  		{
   394  			name: "part 2 is not base64 decoded",
   395  			args: args{fmt.Sprintf("%s.not base64.part3", jb64)},
   396  			want: false,
   397  		},
   398  		{
   399  			name: "part 2 is not JSON",
   400  			args: args{fmt.Sprintf("%s.%s.part3", jb64, b64)},
   401  			want: false,
   402  		},
   403  		{
   404  			name: "is JWT unsecured",
   405  			args: args{fmt.Sprintf("%s.%s.", jb64, jb64)},
   406  			want: true,
   407  		},
   408  	}
   409  
   410  	for i := range tests {
   411  		tt := tests[i]
   412  		t.Run(tt.name, func(t *testing.T) {
   413  			if got := IsJWTUnsecured(tt.args.data); got != tt.want {
   414  				t.Errorf("isJWTUnsecured() = %v, want %v", got, tt.want)
   415  			}
   416  		})
   417  	}
   418  }
   419  
// testToMapStruct is a JSON-tagged struct that Test_toMap passes to
// PayloadToMap to check that struct inputs are converted using their json tags.
type testToMapStruct struct {
	// TestField is serialized under the key "a".
	TestField string `json:"a"`
}
   423  
   424  func Test_toMap(t *testing.T) {
   425  	inputMap := map[string]interface{}{"a": "b"}
   426  
   427  	r := require.New(t)
   428  
   429  	// pass map
   430  	resultMap, err := PayloadToMap(inputMap)
   431  	r.NoError(err)
   432  	r.Equal(inputMap, resultMap)
   433  
   434  	// pass []byte
   435  	inputMapBytes, err := json.Marshal(inputMap)
   436  	r.NoError(err)
   437  	resultMap, err = PayloadToMap(inputMapBytes)
   438  	r.NoError(err)
   439  	r.Equal(inputMap, resultMap)
   440  
   441  	// pass string
   442  	inputMapStr := string(inputMapBytes)
   443  	resultMap, err = PayloadToMap(inputMapStr)
   444  	r.NoError(err)
   445  	r.Equal(inputMap, resultMap)
   446  
   447  	// pass struct
   448  	s := testToMapStruct{TestField: "b"}
   449  	resultMap, err = PayloadToMap(s)
   450  	r.NoError(err)
   451  	r.Equal(inputMap, resultMap)
   452  
   453  	// pass invalid []byte
   454  	resultMap, err = PayloadToMap([]byte("not JSON"))
   455  	r.Error(err)
   456  	r.Contains(err.Error(), "convert to map")
   457  	r.Nil(resultMap)
   458  
   459  	// pass invalid structure
   460  	resultMap, err = PayloadToMap(make(chan int))
   461  	r.Error(err)
   462  	r.Contains(err.Error(), "marshal interface[chan int]: json: unsupported type: chan int")
   463  	r.Nil(resultMap)
   464  }
   465  
   466  func getValidJSONWebToken() (*JSONWebToken, error) {
   467  	headers := map[string]interface{}{"typ": "JWT", "alg": "EdDSA"}
   468  	claims := map[string]interface{}{"iss": "Albert"}
   469  
   470  	_, privKey, err := ed25519.GenerateKey(rand.Reader)
   471  	if err != nil {
   472  		return nil, err
   473  	}
   474  
   475  	signer := NewEd25519Signer(privKey)
   476  
   477  	return NewSigned(claims, headers, signer)
   478  }
   479  
   480  func getJSONWebTokenWithInvalidPayload() (*JSONWebToken, error) {
   481  	token, err := getValidJSONWebToken()
   482  	if err != nil {
   483  		return nil, err
   484  	}
   485  
   486  	// hack the token
   487  	token.Payload = getUnmarshallableMap()
   488  
   489  	return token, nil
   490  }
   491  
   492  func verifyEd25519ViaGoJose(jws string, pubKey ed25519.PublicKey, claims interface{}) error {
   493  	jwtToken, err := jwt.ParseSigned(jws)
   494  	if err != nil {
   495  		return fmt.Errorf("parse VC from signed JWS: %w", err)
   496  	}
   497  
   498  	if err = jwtToken.Claims(pubKey, claims); err != nil {
   499  		return fmt.Errorf("verify JWT signature: %w", err)
   500  	}
   501  
   502  	return nil
   503  }
   504  
   505  func verifyRS256ViaGoJose(jws string, pubKey *rsa.PublicKey, claims interface{}) error {
   506  	jwtToken, err := jwt.ParseSigned(jws)
   507  	if err != nil {
   508  		return fmt.Errorf("parse VC from signed JWS: %w", err)
   509  	}
   510  
   511  	if err = jwtToken.Claims(pubKey, claims); err != nil {
   512  		return fmt.Errorf("verify JWT signature: %w", err)
   513  	}
   514  
   515  	return nil
   516  }
   517  
   518  func getUnmarshallableMap() map[string]interface{} {
   519  	return map[string]interface{}{"error": map[chan int]interface{}{make(chan int): 6}}
   520  }
   521  
   522  func createClaims() *CustomClaim {
   523  	issued := time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)
   524  	expiry := time.Date(2022, time.January, 1, 0, 0, 0, 0, time.UTC)
   525  	notBefore := time.Date(2021, time.January, 1, 0, 0, 0, 0, time.UTC)
   526  
   527  	return &CustomClaim{
   528  		Claims: &Claims{
   529  			Issuer:    "iss",
   530  			Subject:   "sub",
   531  			Audience:  []string{"aud"},
   532  			Expiry:    jwt.NewNumericDate(expiry),
   533  			NotBefore: jwt.NewNumericDate(notBefore),
   534  			IssuedAt:  jwt.NewNumericDate(issued),
   535  			ID:        "id",
   536  		},
   537  
   538  		PrivateClaim1: "private claim",
   539  	}
   540  }