github.com/Tyktechnologies/tyk@v2.9.5+incompatible/gateway/mw_jwt_test.go (about)

     1  package gateway
     2  
     3  import (
     4  	"crypto/md5"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"fmt"
     8  	"net/http"
     9  	"reflect"
    10  	"sort"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	jwt "github.com/dgrijalva/jwt-go"
    16  	"github.com/lonelycode/go-uuid/uuid"
    17  	"github.com/stretchr/testify/assert"
    18  
    19  	"github.com/TykTechnologies/tyk/test"
    20  	"github.com/TykTechnologies/tyk/user"
    21  )
    22  
    23  // openssl rsa -in app.rsa -pubout > app.rsa.pub
    24  const jwtRSAPubKey = `
    25  -----BEGIN PUBLIC KEY-----
    26  MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyqZ4rwKF8qCExS7kpY4c
    27  nJa/37FMkJNkalZ3OuslLB0oRL8T4c94kdF4aeNzSFkSe2n99IBI6Ssl79vbfMZb
    28  +t06L0Q94k+/P37x7+/RJZiff4y1VGjrnrnMI2iu9l4iBBRYzNmG6eblroEMMWlg
    29  k5tysHgxB59CSNIcD9gqk1hx4n/FgOmvKsfQgWHNlPSDTRcWGWGhB2/XgNVYG2pO
    30  lQxAPqLhBHeqGTXBbPfGF9cHzixpsPr6GtbzPwhsQ/8bPxoJ7hdfn+rzztks3d6+
    31  HWURcyNTLRe0mjXjjee9Z6+gZ+H+fS4pnP9tqT7IgU6ePUWTpjoiPtLexgsAa/ct
    32  jQIDAQAB
    33  -----END PUBLIC KEY-----
    34  `
    35  
    36  const jwtRSAPubKeyinvalid = `
    37  -----BEGIN PUBLIC KEY-----
    38  MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyqZ4rwKF8qCExS7kpY4c
    39  nJa/37FMkJNkalZ3OuslLB0oRL8T4c94kdF4aeNzSFkSe2n99IBI6Ssl79vbfMZb
    40  +t06L0Q94k+/P37x7+/RJZiff4y1VGjrnrnMI2iu9l4iBBRYzNmG6eblroEMMWlg
    41  k5tysHgxB59CSNIcD9gqk1hx4n/FgOmvKsfQgWHNlPSDTRcWGWGhB2/XgNVYG2pO
    42  lQxAPqLhBHeqGTXBbPfGF9cHzixpsPr6GtbzPwhsQ/8bPxoJ7hdfn+rzztks3d6+
    43  HWURcyNTLRe0mjXjjee9Z6+gZ+H+fS4pnP9tqT7IgU6ePUWTpjoiPtLexgsAa/ct
    44  jQIDAQAB!!!!
    45  -----END PUBLIC KEY-----
    46  `
    47  
    48  func createJWTSession() *user.SessionState {
    49  	session := new(user.SessionState)
    50  	session.Rate = 1000000.0
    51  	session.Allowance = session.Rate
    52  	session.LastCheck = time.Now().Unix() - 10
    53  	session.Per = 1.0
    54  	session.QuotaRenewalRate = 300 // 5 minutes
    55  	session.QuotaRenews = time.Now().Unix() + 20
    56  	session.QuotaRemaining = 1
    57  	session.QuotaMax = -1
    58  	session.JWTData = user.JWTData{Secret: jwtSecret}
    59  	session.Mutex = &sync.RWMutex{}
    60  	return session
    61  }
    62  
    63  func createJWTSessionWithRSA() *user.SessionState {
    64  	session := createJWTSession()
    65  	session.JWTData.Secret = jwtRSAPubKey
    66  	return session
    67  }
    68  
    69  func createJWTSessionWithECDSA() *user.SessionState {
    70  	session := createJWTSession()
    71  	session.JWTData.Secret = jwtECDSAPublicKey
    72  	return session
    73  }
    74  
    75  func createJWTSessionWithRSAWithPolicy(policyID string) *user.SessionState {
    76  	session := createJWTSessionWithRSA()
    77  	session.SetPolicies(policyID)
    78  	return session
    79  }
    80  
    81  type JwtCreator func() *user.SessionState
    82  
    83  func prepareGenericJWTSession(testName string, method string, claimName string, ApiSkipKid bool) (*APISpec, string) {
    84  	tokenKID := testKey(testName, "token")
    85  
    86  	var jwtToken string
    87  	var sessionFunc JwtCreator
    88  	switch method {
    89  	default:
    90  		log.Warningf("Signing method '%s' is not recognised, defaulting to HMAC signature", method)
    91  		method = HMACSign
    92  		fallthrough
    93  	case HMACSign:
    94  		sessionFunc = createJWTSession
    95  
    96  		jwtToken = createJWKTokenHMAC(func(t *jwt.Token) {
    97  			t.Claims.(jwt.MapClaims)["foo"] = "bar"
    98  			t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
    99  
   100  			if claimName != KID {
   101  				t.Claims.(jwt.MapClaims)[claimName] = tokenKID
   102  				t.Header[KID] = "ignore-this-id"
   103  			} else {
   104  				t.Header[KID] = tokenKID
   105  			}
   106  		})
   107  	case RSASign:
   108  		sessionFunc = createJWTSessionWithRSA
   109  
   110  		jwtToken = CreateJWKToken(func(t *jwt.Token) {
   111  			t.Claims.(jwt.MapClaims)["foo"] = "bar"
   112  			t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
   113  
   114  			if claimName != KID {
   115  				t.Claims.(jwt.MapClaims)[claimName] = tokenKID
   116  				t.Header[KID] = "ignore-this-id"
   117  			} else {
   118  				t.Header[KID] = tokenKID
   119  			}
   120  		})
   121  	case ECDSASign:
   122  		sessionFunc = createJWTSessionWithECDSA
   123  
   124  		jwtToken = CreateJWKTokenECDSA(func(t *jwt.Token) {
   125  			t.Claims.(jwt.MapClaims)["foo"] = "bar"
   126  			t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
   127  
   128  			if claimName != KID {
   129  				t.Claims.(jwt.MapClaims)[claimName] = tokenKID
   130  				t.Header[KID] = "ignore-this-id"
   131  			} else {
   132  				t.Header[KID] = tokenKID
   133  			}
   134  		})
   135  	}
   136  
   137  	spec := BuildAndLoadAPI(func(spec *APISpec) {
   138  		spec.UseKeylessAccess = false
   139  		spec.JWTSigningMethod = method
   140  		spec.EnableJWT = true
   141  		spec.Proxy.ListenPath = "/"
   142  		spec.JWTSkipKid = ApiSkipKid
   143  
   144  		if claimName != KID {
   145  			spec.JWTIdentityBaseField = claimName
   146  		}
   147  	})[0]
   148  	spec.SessionManager.UpdateSession(tokenKID, sessionFunc(), 60, false)
   149  
   150  	return spec, jwtToken
   151  
   152  }
   153  
   154  func TestJWTSessionHMAC(t *testing.T) {
   155  	ts := StartTest()
   156  	defer ts.Close()
   157  
   158  	//If we skip the check then the Id will be taken from SUB and the call will succeed
   159  	_, jwtToken := prepareGenericJWTSession(t.Name(), HMACSign, KID, false)
   160  	defer ResetTestConfig()
   161  
   162  	authHeaders := map[string]string{"authorization": jwtToken}
   163  	t.Run("Request with valid JWT signed with HMAC", func(t *testing.T) {
   164  		ts.Run(t, test.TestCase{
   165  			Headers: authHeaders, Code: http.StatusOK,
   166  		})
   167  	})
   168  }
   169  
   170  func BenchmarkJWTSessionHMAC(b *testing.B) {
   171  	b.ReportAllocs()
   172  
   173  	ts := StartTest()
   174  	defer ts.Close()
   175  
   176  	//If we skip the check then the Id will be taken from SUB and the call will succeed
   177  	_, jwtToken := prepareGenericJWTSession(b.Name(), HMACSign, KID, false)
   178  	defer ResetTestConfig()
   179  
   180  	authHeaders := map[string]string{"authorization": jwtToken}
   181  	for i := 0; i < b.N; i++ {
   182  		ts.Run(b, test.TestCase{
   183  			Headers: authHeaders, Code: http.StatusOK,
   184  		})
   185  	}
   186  }
   187  
   188  func TestJWTHMACIdInSubClaim(t *testing.T) {
   189  
   190  	ts := StartTest()
   191  	defer ts.Close()
   192  
   193  	//Same as above
   194  	_, jwtToken := prepareGenericJWTSession(t.Name(), HMACSign, SUB, true)
   195  	authHeaders := map[string]string{"authorization": jwtToken}
   196  	t.Run("Request with valid JWT/HMAC/Id in SuB/Global-skip-kid/Api-skip-kid", func(t *testing.T) {
   197  		ts.Run(t, test.TestCase{
   198  			Headers: authHeaders, Code: http.StatusOK,
   199  		})
   200  	})
   201  
   202  	// For backward compatibility, if the new config are not set, and the id is in the 'sub' claim while the 'kid' claim
   203  	// in the header is not empty, then the jwt will return 403 - "Key not authorized:token invalid, key not found"
   204  	_, jwtToken = prepareGenericJWTSession(t.Name(), HMACSign, SUB, false)
   205  	authHeaders = map[string]string{"authorization": jwtToken}
   206  	t.Run("Request with valid JWT/HMAC/Id in SuB/Global-dont-skip-kid/Api-dont-skip-kid", func(t *testing.T) {
   207  		ts.Run(t, test.TestCase{
   208  			Headers:   authHeaders,
   209  			Code:      http.StatusForbidden,
   210  			BodyMatch: `Key not authorized:token invalid, key not found`,
   211  		})
   212  	})
   213  
   214  	// Case where the gw always check the 'kid' claim first but if this JWTSkipCheckKidAsId is set on the api level,
   215  	// then it'll work
   216  	_, jwtToken = prepareGenericJWTSession(t.Name(), HMACSign, SUB, true)
   217  	defer ResetTestConfig()
   218  	authHeaders = map[string]string{"authorization": jwtToken}
   219  	t.Run("Request with valid JWT/HMAC/Id in SuB/Global-dont-skip-kid/Api-skip-kid", func(t *testing.T) {
   220  		ts.Run(t, test.TestCase{
   221  			Headers: authHeaders, Code: http.StatusOK,
   222  		})
   223  	})
   224  }
   225  
   226  func TestJWTRSAIdInSubClaim(t *testing.T) {
   227  	ts := StartTest()
   228  	defer ts.Close()
   229  
   230  	_, jwtToken := prepareGenericJWTSession(t.Name(), RSASign, SUB, true)
   231  	authHeaders := map[string]string{"authorization": jwtToken}
   232  	t.Run("Request with valid JWT/RSA/Id in SuB/Global-skip-kid/Api-skip-kid", func(t *testing.T) {
   233  		ts.Run(t, test.TestCase{
   234  			Headers: authHeaders, Code: http.StatusOK,
   235  		})
   236  	})
   237  
   238  	_, jwtToken = prepareGenericJWTSession(t.Name(), RSASign, SUB, false)
   239  	authHeaders = map[string]string{"authorization": jwtToken}
   240  	t.Run("Request with valid JWT/RSA/Id in SuB/Global-dont-skip-kid/Api-dont-skip-kid", func(t *testing.T) {
   241  		ts.Run(t, test.TestCase{
   242  			Headers:   authHeaders,
   243  			Code:      http.StatusForbidden,
   244  			BodyMatch: `Key not authorized:token invalid, key not found`,
   245  		})
   246  	})
   247  
   248  	_, jwtToken = prepareGenericJWTSession(t.Name(), RSASign, SUB, true)
   249  	authHeaders = map[string]string{"authorization": jwtToken}
   250  	t.Run("Request with valid JWT/RSA/Id in SuB/Global-dont-skip-kid/Api-skip-kid", func(t *testing.T) {
   251  		ts.Run(t, test.TestCase{
   252  			Headers: authHeaders, Code: http.StatusOK,
   253  		})
   254  	})
   255  }
   256  
   257  func TestJWTSessionRSA(t *testing.T) {
   258  	ts := StartTest()
   259  	defer ts.Close()
   260  
   261  	//default values, keep backward compatibility
   262  	_, jwtToken := prepareGenericJWTSession(t.Name(), RSASign, KID, false)
   263  	authHeaders := map[string]string{"authorization": jwtToken}
   264  	t.Run("Request with valid JWT", func(t *testing.T) {
   265  		ts.Run(t, test.TestCase{
   266  			Headers: authHeaders, Code: http.StatusOK,
   267  		})
   268  	})
   269  }
   270  
   271  func BenchmarkJWTSessionRSA(b *testing.B) {
   272  	b.ReportAllocs()
   273  
   274  	ts := StartTest()
   275  	defer ts.Close()
   276  
   277  	//default values, keep backward compatibility
   278  	_, jwtToken := prepareGenericJWTSession(b.Name(), RSASign, KID, false)
   279  
   280  	authHeaders := map[string]string{"authorization": jwtToken}
   281  	for i := 0; i < b.N; i++ {
   282  		ts.Run(b, test.TestCase{
   283  			Headers: authHeaders, Code: http.StatusOK,
   284  		})
   285  	}
   286  }
   287  
   288  func TestJWTSessionFailRSA_EmptyJWT(t *testing.T) {
   289  	ts := StartTest()
   290  	defer ts.Close()
   291  
   292  	//default values, same as before (keeps backward compatibility)
   293  	prepareGenericJWTSession(t.Name(), RSASign, KID, false)
   294  
   295  	authHeaders := map[string]string{"authorization": ""}
   296  	t.Run("Request with empty authorization header", func(t *testing.T) {
   297  		ts.Run(t, test.TestCase{
   298  			Headers: authHeaders, Code: 400,
   299  		})
   300  	})
   301  }
   302  
   303  func TestJWTSessionFailRSA_NoAuthHeader(t *testing.T) {
   304  	ts := StartTest()
   305  	defer ts.Close()
   306  
   307  	//default values, same as before (keeps backward compatibility)
   308  	prepareGenericJWTSession(t.Name(), RSASign, KID, false)
   309  
   310  	authHeaders := map[string]string{}
   311  	t.Run("Request without authorization header", func(t *testing.T) {
   312  		ts.Run(t, test.TestCase{
   313  			Headers: authHeaders, Code: http.StatusBadRequest, BodyMatch: `Authorization field missing`,
   314  		})
   315  	})
   316  }
   317  
   318  func TestJWTSessionFailRSA_MalformedJWT(t *testing.T) {
   319  	ts := StartTest()
   320  	defer ts.Close()
   321  
   322  	//default values, same as before (keeps backward compatibility)
   323  	_, jwtToken := prepareGenericJWTSession(t.Name(), RSASign, KID, false)
   324  
   325  	authHeaders := map[string]string{"authorization": jwtToken + "ajhdkjhsdfkjashdkajshdkajhsdkajhsd"}
   326  	t.Run("Request with malformed JWT", func(t *testing.T) {
   327  		ts.Run(t, test.TestCase{
   328  			Headers:   authHeaders,
   329  			Code:      http.StatusForbidden,
   330  			BodyMatch: `Key not authorized:crypto/rsa: verification error`,
   331  		})
   332  	})
   333  }
   334  
   335  func TestJWTSessionFailRSA_MalformedJWT_NOTRACK(t *testing.T) {
   336  	ts := StartTest()
   337  	defer ts.Close()
   338  
   339  	//default values, same as before (keeps backward compatibility)
   340  	spec, jwtToken := prepareGenericJWTSession(t.Name(), RSASign, KID, false)
   341  	spec.DoNotTrack = true
   342  	authHeaders := map[string]string{"authorization": jwtToken + "ajhdkjhsdfkjashdkajshdkajhsdkajhsd"}
   343  
   344  	t.Run("Request with malformed JWT no track", func(t *testing.T) {
   345  		ts.Run(t, test.TestCase{
   346  			Headers:   authHeaders,
   347  			Code:      http.StatusForbidden,
   348  			BodyMatch: `Key not authorized:crypto/rsa: verification error`,
   349  		})
   350  	})
   351  }
   352  
   353  func TestJWTSessionFailRSA_WrongJWT(t *testing.T) {
   354  	ts := StartTest()
   355  	defer ts.Close()
   356  
   357  	//default values, same as before (keeps backward compatibility)
   358  	prepareGenericJWTSession(t.Name(), RSASign, KID, false)
   359  	authHeaders := map[string]string{"authorization": "123"}
   360  
   361  	t.Run("Request with invalid JWT", func(t *testing.T) {
   362  		ts.Run(t, test.TestCase{
   363  			Headers:   authHeaders,
   364  			Code:      http.StatusForbidden,
   365  			BodyMatch: `Key not authorized:token contains an invalid number of segments`,
   366  		})
   367  	})
   368  }
   369  
   370  func TestJWTSessionRSABearer(t *testing.T) {
   371  	ts := StartTest()
   372  	defer ts.Close()
   373  
   374  	//default values, same as before (keeps backward compatibility)
   375  	_, jwtToken := prepareGenericJWTSession(t.Name(), RSASign, KID, false)
   376  	authHeaders := map[string]string{"authorization": "Bearer " + jwtToken}
   377  
   378  	t.Run("Request with valid Bearer", func(t *testing.T) {
   379  		ts.Run(t, test.TestCase{
   380  			Headers: authHeaders, Code: http.StatusOK,
   381  		})
   382  	})
   383  }
   384  
   385  func BenchmarkJWTSessionRSABearer(b *testing.B) {
   386  	b.ReportAllocs()
   387  
   388  	ts := StartTest()
   389  	defer ts.Close()
   390  
   391  	//default values, same as before (keeps backward compatibility)
   392  	_, jwtToken := prepareGenericJWTSession(b.Name(), RSASign, KID, false)
   393  	authHeaders := map[string]string{"authorization": "Bearer " + jwtToken}
   394  
   395  	for i := 0; i < b.N; i++ {
   396  		ts.Run(b, test.TestCase{
   397  			Headers: authHeaders, Code: http.StatusOK,
   398  		})
   399  	}
   400  }
   401  
   402  func TestJWTSessionRSABearerInvalid(t *testing.T) {
   403  	ts := StartTest()
   404  	defer ts.Close()
   405  
   406  	//default values, same as before (keeps backward compatibility)
   407  	_, jwtToken := prepareGenericJWTSession(t.Name(), RSASign, KID, false)
   408  	authHeaders := map[string]string{"authorization": "Bearer: " + jwtToken} // extra ":" makes the value invalid
   409  
   410  	t.Run("Request with invalid Bearer", func(t *testing.T) {
   411  		ts.Run(t, test.TestCase{
   412  			Headers:   authHeaders,
   413  			Code:      http.StatusForbidden,
   414  			BodyMatch: "Key not authorized:illegal base64 data at input byte 6",
   415  		})
   416  	})
   417  }
   418  
   419  func TestJWTSessionRSABearerInvalidTwoBears(t *testing.T) {
   420  	ts := StartTest()
   421  	defer ts.Close()
   422  
   423  	//default values, same as before (keeps backward compatibility)
   424  	_, jwtToken := prepareGenericJWTSession(t.Name(), RSASign, KID, false)
   425  	authHeaders1 := map[string]string{"authorization": "Bearer bearer" + jwtToken}
   426  
   427  	t.Run("Request with Bearer bearer", func(t *testing.T) {
   428  		ts.Run(t, test.TestCase{
   429  			Headers: authHeaders1, Code: http.StatusForbidden,
   430  		})
   431  	})
   432  
   433  	authHeaders2 := map[string]string{"authorization": "bearer Bearer" + jwtToken}
   434  
   435  	t.Run("Request with bearer Bearer", func(t *testing.T) {
   436  		ts.Run(t, test.TestCase{
   437  			Headers: authHeaders2, Code: http.StatusForbidden,
   438  		})
   439  	})
   440  }
   441  
   442  // JWTSessionRSAWithRawSourceOnWithClientID
   443  
   444  func prepareJWTSessionRSAWithRawSourceOnWithClientID(isBench bool) string {
   445  	spec := BuildAndLoadAPI(func(spec *APISpec) {
   446  		spec.OrgID = "default"
   447  		spec.UseKeylessAccess = false
   448  		spec.EnableJWT = true
   449  		spec.JWTSigningMethod = RSASign
   450  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
   451  		spec.JWTIdentityBaseField = "user_id"
   452  		spec.JWTClientIDBaseField = "azp"
   453  		spec.Proxy.ListenPath = "/"
   454  	})[0]
   455  
   456  	policyID := CreatePolicy(func(p *user.Policy) {
   457  		p.OrgID = "default"
   458  		p.AccessRights = map[string]user.AccessDefinition{
   459  			spec.APIID: {
   460  				APIName:  spec.APIDefinition.Name,
   461  				APIID:    spec.APIID,
   462  				Versions: []string{"default"},
   463  			},
   464  		}
   465  	})
   466  
   467  	tokenID := ""
   468  	if isBench {
   469  		tokenID = uuid.New()
   470  	} else {
   471  		tokenID = "1234567891010101"
   472  	}
   473  	session := createJWTSessionWithRSAWithPolicy(policyID)
   474  
   475  	spec.SessionManager.ResetQuota(tokenID, session, false)
   476  	spec.SessionManager.UpdateSession(tokenID, session, 60, false)
   477  
   478  	jwtToken := CreateJWKToken(func(t *jwt.Token) {
   479  		t.Header["kid"] = "12345"
   480  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
   481  		t.Claims.(jwt.MapClaims)["user_id"] = "user"
   482  		t.Claims.(jwt.MapClaims)["azp"] = tokenID
   483  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
   484  	})
   485  
   486  	return jwtToken
   487  }
   488  
   489  func TestJWTSessionRSAWithRawSourceOnWithClientID(t *testing.T) {
   490  	ts := StartTest()
   491  	defer ts.Close()
   492  
   493  	jwtToken := prepareJWTSessionRSAWithRawSourceOnWithClientID(false)
   494  	authHeaders := map[string]string{"authorization": jwtToken}
   495  
   496  	t.Run("Initial request with no policy base field in JWT", func(t *testing.T) {
   497  		ts.Run(t, test.TestCase{
   498  			Headers: authHeaders, Code: http.StatusOK,
   499  		})
   500  	})
   501  }
   502  
   503  func BenchmarkJWTSessionRSAWithRawSourceOnWithClientID(b *testing.B) {
   504  	b.ReportAllocs()
   505  
   506  	ts := StartTest()
   507  	defer ts.Close()
   508  
   509  	jwtToken := prepareJWTSessionRSAWithRawSourceOnWithClientID(true)
   510  	authHeaders := map[string]string{"authorization": jwtToken}
   511  
   512  	for i := 0; i < b.N; i++ {
   513  		ts.Run(b, test.TestCase{
   514  			Headers: authHeaders, Code: http.StatusOK,
   515  		})
   516  	}
   517  }
   518  
   519  // JWTSessionRSAWithRawSource
   520  
   521  func prepareJWTSessionRSAWithRawSource() string {
   522  	BuildAndLoadAPI(func(spec *APISpec) {
   523  		spec.UseKeylessAccess = false
   524  		spec.EnableJWT = true
   525  		spec.JWTSigningMethod = RSASign
   526  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
   527  		spec.JWTIdentityBaseField = "user_id"
   528  		spec.JWTPolicyFieldName = "policy_id"
   529  		spec.Proxy.ListenPath = "/"
   530  	})
   531  
   532  	pID := CreatePolicy()
   533  
   534  	jwtToken := CreateJWKToken(func(t *jwt.Token) {
   535  		t.Header["kid"] = "12345"
   536  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
   537  		t.Claims.(jwt.MapClaims)["user_id"] = "user"
   538  		t.Claims.(jwt.MapClaims)["policy_id"] = pID
   539  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
   540  	})
   541  
   542  	return jwtToken
   543  }
   544  
   545  func TestJWTSessionRSAWithRawSource(t *testing.T) {
   546  	ts := StartTest()
   547  	defer ts.Close()
   548  
   549  	jwtToken := prepareJWTSessionRSAWithRawSource()
   550  
   551  	authHeaders := map[string]string{"authorization": jwtToken}
   552  	t.Run("Initial request with valid policy", func(t *testing.T) {
   553  		ts.Run(t, test.TestCase{
   554  			Headers: authHeaders, Code: http.StatusOK,
   555  		})
   556  	})
   557  }
   558  
   559  func BenchmarkJWTSessionRSAWithRawSource(b *testing.B) {
   560  	b.ReportAllocs()
   561  
   562  	ts := StartTest()
   563  	defer ts.Close()
   564  
   565  	jwtToken := prepareJWTSessionRSAWithRawSource()
   566  
   567  	authHeaders := map[string]string{"authorization": jwtToken}
   568  
   569  	for i := 0; i < b.N; i++ {
   570  		ts.Run(
   571  			b,
   572  			test.TestCase{
   573  				Headers: authHeaders,
   574  				Code:    http.StatusOK,
   575  			},
   576  		)
   577  	}
   578  }
   579  
   580  func TestJWTSessionRSAWithRawSourceInvalidPolicyID(t *testing.T) {
   581  	ts := StartTest()
   582  	defer ts.Close()
   583  
   584  	spec := BuildAPI(func(spec *APISpec) {
   585  		spec.UseKeylessAccess = false
   586  		spec.EnableJWT = true
   587  		spec.JWTSigningMethod = RSASign
   588  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
   589  		spec.JWTIdentityBaseField = "user_id"
   590  		spec.JWTPolicyFieldName = "policy_id"
   591  		spec.Proxy.ListenPath = "/"
   592  	})[0]
   593  
   594  	LoadAPI(spec)
   595  
   596  	CreatePolicy()
   597  
   598  	jwtToken := CreateJWKToken(func(t *jwt.Token) {
   599  		t.Header["kid"] = "12345"
   600  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
   601  		t.Claims.(jwt.MapClaims)["user_id"] = "user"
   602  		t.Claims.(jwt.MapClaims)["policy_id"] = "abcxyz"
   603  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
   604  	})
   605  
   606  	authHeaders := map[string]string{"authorization": jwtToken}
   607  	t.Run("Initial request with invalid policy", func(t *testing.T) {
   608  		ts.Run(t, test.TestCase{
   609  			Headers:   authHeaders,
   610  			Code:      http.StatusForbidden,
   611  			BodyMatch: "key not authorized: no matching policy",
   612  		})
   613  	})
   614  }
   615  
   616  func TestJWTSessionExpiresAtValidationConfigs(t *testing.T) {
   617  	ts := StartTest()
   618  	defer ts.Close()
   619  
   620  	pID := CreatePolicy()
   621  	jwtAuthHeaderGen := func(skew time.Duration) map[string]string {
   622  		jwtToken := CreateJWKToken(func(t *jwt.Token) {
   623  			t.Claims.(jwt.MapClaims)["policy_id"] = pID
   624  			t.Claims.(jwt.MapClaims)["user_id"] = "user123"
   625  			t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(skew).Unix()
   626  		})
   627  
   628  		return map[string]string{"authorization": jwtToken}
   629  	}
   630  
   631  	spec := BuildAPI(func(spec *APISpec) {
   632  		spec.UseKeylessAccess = false
   633  		spec.EnableJWT = true
   634  		spec.JWTSigningMethod = RSASign
   635  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
   636  		spec.JWTIdentityBaseField = "user_id"
   637  		spec.JWTPolicyFieldName = "policy_id"
   638  		spec.Proxy.ListenPath = "/"
   639  	})[0]
   640  
   641  	// This test is successful by definition
   642  	t.Run("Expiry_After_now--Valid_jwt", func(t *testing.T) {
   643  		spec.JWTExpiresAtValidationSkew = 0 //Default value
   644  		LoadAPI(spec)
   645  
   646  		ts.Run(t, test.TestCase{
   647  			Headers: jwtAuthHeaderGen(+time.Second), Code: http.StatusOK,
   648  		})
   649  	})
   650  
   651  	// This test is successful by definition, so it's true also with skew, but just to avoid confusion.
   652  	t.Run("Expiry_After_now-Add_skew--Valid_jwt", func(t *testing.T) {
   653  		spec.JWTExpiresAtValidationSkew = 1
   654  		LoadAPI(spec)
   655  
   656  		ts.Run(t, test.TestCase{
   657  			Headers: jwtAuthHeaderGen(+time.Second), Code: http.StatusOK,
   658  		})
   659  	})
   660  
   661  	t.Run("Expiry_Before_now--Invalid_jwt", func(t *testing.T) {
   662  		spec.JWTExpiresAtValidationSkew = 0 //Default value
   663  		LoadAPI(spec)
   664  
   665  		ts.Run(t, test.TestCase{
   666  			Headers:   jwtAuthHeaderGen(-time.Second),
   667  			Code:      http.StatusUnauthorized,
   668  			BodyMatch: "Key not authorized: token has expired",
   669  		})
   670  	})
   671  
   672  	t.Run("Expired_token-Before_now-Huge_skew--Valid_jwt", func(t *testing.T) {
   673  		spec.JWTExpiresAtValidationSkew = 1000 // This value doesn't matter since validation is disabled
   674  		LoadAPI(spec)
   675  
   676  		ts.Run(t, test.TestCase{
   677  			Headers: jwtAuthHeaderGen(-time.Second), Code: http.StatusOK,
   678  		})
   679  	})
   680  
   681  	t.Run("Expired_token-Before_now-Add_skew--Valid_jwt", func(t *testing.T) {
   682  		spec.JWTExpiresAtValidationSkew = 2
   683  		LoadAPI(spec)
   684  
   685  		ts.Run(t, test.TestCase{
   686  			Headers: jwtAuthHeaderGen(-time.Second), Code: http.StatusOK,
   687  		})
   688  	})
   689  }
   690  
   691  func TestJWTSessionIssueAtValidationConfigs(t *testing.T) {
   692  	ts := StartTest()
   693  	defer ts.Close()
   694  
   695  	pID := CreatePolicy()
   696  	jwtAuthHeaderGen := func(skew time.Duration) map[string]string {
   697  		jwtToken := CreateJWKToken(func(t *jwt.Token) {
   698  			t.Claims.(jwt.MapClaims)["policy_id"] = pID
   699  			t.Claims.(jwt.MapClaims)["user_id"] = "user123"
   700  			t.Claims.(jwt.MapClaims)["iat"] = time.Now().Add(skew).Unix()
   701  		})
   702  
   703  		return map[string]string{"authorization": jwtToken}
   704  	}
   705  
   706  	spec := BuildAPI(func(spec *APISpec) {
   707  		spec.UseKeylessAccess = false
   708  		spec.EnableJWT = true
   709  		spec.JWTSigningMethod = "rsa"
   710  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
   711  		spec.JWTIdentityBaseField = "user_id"
   712  		spec.JWTPolicyFieldName = "policy_id"
   713  		spec.Proxy.ListenPath = "/"
   714  	})[0]
   715  
   716  	// This test is successful by definition
   717  	t.Run("IssuedAt_Before_now-no_skew--Valid_jwt", func(t *testing.T) {
   718  		spec.JWTIssuedAtValidationSkew = 0
   719  
   720  		LoadAPI(spec)
   721  
   722  		ts.Run(t, test.TestCase{
   723  			Headers: jwtAuthHeaderGen(-time.Second), Code: http.StatusOK,
   724  		})
   725  	})
   726  
   727  	t.Run("Expiry_after_now--Invalid_jwt", func(t *testing.T) {
   728  		spec.JWTExpiresAtValidationSkew = 0 //Default value
   729  
   730  		LoadAPI(spec)
   731  
   732  		ts.Run(t, test.TestCase{
   733  			Headers: jwtAuthHeaderGen(-time.Second), Code: http.StatusOK,
   734  		})
   735  	})
   736  
   737  	t.Run("IssueAt-After_now-no_skew--Invalid_jwt", func(t *testing.T) {
   738  		spec.JWTIssuedAtValidationSkew = 0
   739  
   740  		LoadAPI(spec)
   741  
   742  		ts.Run(t, test.TestCase{
   743  			Headers:   jwtAuthHeaderGen(+time.Minute),
   744  			Code:      http.StatusUnauthorized,
   745  			BodyMatch: "Key not authorized: token used before issued",
   746  		})
   747  	})
   748  
   749  	t.Run("IssueAt--After_now-Huge_skew--valid_jwt", func(t *testing.T) {
   750  		spec.JWTIssuedAtValidationSkew = 1000 // This value doesn't matter since validation is disabled
   751  		LoadAPI(spec)
   752  
   753  		ts.Run(t, test.TestCase{
   754  			Headers: jwtAuthHeaderGen(+time.Second),
   755  			Code:    http.StatusOK,
   756  		})
   757  	})
   758  
   759  	// True by definition
   760  	t.Run("IssueAt-Before_now-Add_skew--not_valid_jwt", func(t *testing.T) {
   761  		spec.JWTIssuedAtValidationSkew = 2 // 2 seconds
   762  		LoadAPI(spec)
   763  
   764  		ts.Run(t, test.TestCase{
   765  			Headers: jwtAuthHeaderGen(-3 * time.Second), Code: http.StatusOK,
   766  		})
   767  	})
   768  
   769  	t.Run("IssueAt-After_now-Add_skew--Valid_jwt", func(t *testing.T) {
   770  		spec.JWTIssuedAtValidationSkew = 1
   771  
   772  		LoadAPI(spec)
   773  
   774  		ts.Run(t, test.TestCase{
   775  			Headers: jwtAuthHeaderGen(+time.Second), Code: http.StatusOK,
   776  		})
   777  	})
   778  }
   779  
   780  func TestJWTSessionNotBeforeValidationConfigs(t *testing.T) {
   781  	ts := StartTest()
   782  	defer ts.Close()
   783  
   784  	pID := CreatePolicy()
   785  	jwtAuthHeaderGen := func(skew time.Duration) map[string]string {
   786  		jwtToken := CreateJWKToken(func(t *jwt.Token) {
   787  			t.Claims.(jwt.MapClaims)["policy_id"] = pID
   788  			t.Claims.(jwt.MapClaims)["user_id"] = "user123"
   789  			t.Claims.(jwt.MapClaims)["nbf"] = time.Now().Add(skew).Unix()
   790  		})
   791  		return map[string]string{"authorization": jwtToken}
   792  	}
   793  
   794  	spec := BuildAPI(func(spec *APISpec) {
   795  		spec.UseKeylessAccess = false
   796  		spec.EnableJWT = true
   797  		spec.Proxy.ListenPath = "/"
   798  		spec.JWTSigningMethod = "rsa"
   799  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
   800  		spec.JWTIdentityBaseField = "user_id"
   801  		spec.JWTPolicyFieldName = "policy_id"
   802  	})[0]
   803  
   804  	// This test is successful by definition
   805  	t.Run("NotBefore_Before_now-Valid_jwt", func(t *testing.T) {
   806  		spec.JWTNotBeforeValidationSkew = 0
   807  
   808  		LoadAPI(spec)
   809  
   810  		ts.Run(t, test.TestCase{
   811  			Headers: jwtAuthHeaderGen(-time.Second), Code: http.StatusOK,
   812  		})
   813  	})
   814  
   815  	t.Run("NotBefore_After_now--Invalid_jwt", func(t *testing.T) {
   816  		spec.JWTNotBeforeValidationSkew = 0 //Default value
   817  
   818  		LoadAPI(spec)
   819  
   820  		ts.Run(t, test.TestCase{
   821  			Headers:   jwtAuthHeaderGen(+time.Second),
   822  			Code:      http.StatusUnauthorized,
   823  			BodyMatch: "Key not authorized: token is not valid yet",
   824  		})
   825  	})
   826  
   827  	t.Run("NotBefore_After_now-Add_skew--valid_jwt", func(t *testing.T) {
   828  		spec.JWTNotBeforeValidationSkew = 1
   829  
   830  		LoadAPI(spec)
   831  
   832  		ts.Run(t, test.TestCase{
   833  			Headers: jwtAuthHeaderGen(+time.Second), Code: http.StatusOK,
   834  		})
   835  	})
   836  
   837  	t.Run("NotBefore_After_now-Huge_skew--valid_jwt", func(t *testing.T) {
   838  		spec.JWTNotBeforeValidationSkew = 1000 // This value is so high that it's actually similar to disabling the claim.
   839  
   840  		LoadAPI(spec)
   841  
   842  		ts.Run(t, test.TestCase{
   843  			Headers: jwtAuthHeaderGen(+time.Second), Code: http.StatusOK,
   844  		})
   845  	})
   846  }
   847  
   848  func TestJWTExistingSessionRSAWithRawSourceInvalidPolicyID(t *testing.T) {
   849  	ts := StartTest()
   850  	defer ts.Close()
   851  
   852  	spec := BuildAPI(func(spec *APISpec) {
   853  		spec.UseKeylessAccess = false
   854  		spec.EnableJWT = true
   855  		spec.JWTSigningMethod = RSASign
   856  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
   857  		spec.JWTIdentityBaseField = "user_id"
   858  		spec.JWTPolicyFieldName = "policy_id"
   859  		spec.Proxy.ListenPath = "/"
   860  	})[0]
   861  
   862  	LoadAPI(spec)
   863  
   864  	p1ID := CreatePolicy()
   865  	user_id := uuid.New()
   866  
   867  	jwtToken := CreateJWKToken(func(t *jwt.Token) {
   868  		t.Header["kid"] = "12345"
   869  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
   870  		t.Claims.(jwt.MapClaims)["user_id"] = user_id
   871  		t.Claims.(jwt.MapClaims)["policy_id"] = p1ID
   872  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
   873  	})
   874  
   875  	authHeaders := map[string]string{"authorization": jwtToken}
   876  	t.Run("Initial request with valid policy", func(t *testing.T) {
   877  		ts.Run(t, test.TestCase{
   878  			Headers: authHeaders, Code: http.StatusOK,
   879  		})
   880  	})
   881  
   882  	// put in JWT invalid policy ID and do request again
   883  	jwtTokenInvalidPolicy := CreateJWKToken(func(t *jwt.Token) {
   884  		t.Header["kid"] = "12345"
   885  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
   886  		t.Claims.(jwt.MapClaims)["user_id"] = user_id
   887  		t.Claims.(jwt.MapClaims)["policy_id"] = "abcdef"
   888  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
   889  	})
   890  
   891  	authHeaders = map[string]string{"authorization": jwtTokenInvalidPolicy}
   892  	t.Run("Request with invalid policy in JWT", func(t *testing.T) {
   893  		ts.Run(t, test.TestCase{
   894  			Headers:   authHeaders,
   895  			BodyMatch: "key not authorized: no matching policy",
   896  			Code:      http.StatusForbidden,
   897  		})
   898  	})
   899  }
   900  
   901  func TestJWTScopeToPolicyMapping(t *testing.T) {
   902  	ts := StartTest()
   903  	defer ts.Close()
   904  
   905  	basePolicyID := CreatePolicy(func(p *user.Policy) {
   906  		p.ID = "base"
   907  		p.AccessRights = map[string]user.AccessDefinition{
   908  			"base-api": {
   909  				Limit: &user.APILimit{
   910  					Rate:     111,
   911  					Per:      3600,
   912  					QuotaMax: -1,
   913  				},
   914  			},
   915  		}
   916  		p.Partitions = user.PolicyPartitions{
   917  			PerAPI: true,
   918  		}
   919  	})
   920  
   921  	defaultPolicyID := CreatePolicy(func(p *user.Policy) {
   922  		p.ID = "default"
   923  		p.AccessRights = map[string]user.AccessDefinition{
   924  			"base-api": {
   925  				Limit: &user.APILimit{
   926  					QuotaMax: -1,
   927  				},
   928  			},
   929  		}
   930  	})
   931  
   932  	p1ID := CreatePolicy(func(p *user.Policy) {
   933  		p.ID = "p1"
   934  		p.AccessRights = map[string]user.AccessDefinition{
   935  			"api1": {
   936  				Limit: &user.APILimit{
   937  					Rate:     100,
   938  					Per:      60,
   939  					QuotaMax: -1,
   940  				},
   941  			},
   942  		}
   943  		p.Partitions = user.PolicyPartitions{
   944  			PerAPI: true,
   945  		}
   946  	})
   947  
   948  	p2ID := CreatePolicy(func(p *user.Policy) {
   949  		p.ID = "p2"
   950  		p.AccessRights = map[string]user.AccessDefinition{
   951  			"api2": {
   952  				Limit: &user.APILimit{
   953  					Rate:     500,
   954  					Per:      30,
   955  					QuotaMax: -1,
   956  				},
   957  			},
   958  		}
   959  		p.Partitions = user.PolicyPartitions{
   960  			PerAPI: true,
   961  		}
   962  	})
   963  
   964  	base := BuildAPI(func(spec *APISpec) {
   965  		spec.APIID = "base-api"
   966  		spec.UseKeylessAccess = false
   967  		spec.EnableJWT = true
   968  		spec.JWTSigningMethod = RSASign
   969  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
   970  		spec.JWTIdentityBaseField = "user_id"
   971  		spec.JWTPolicyFieldName = "policy_id"
   972  		spec.JWTDefaultPolicies = []string{defaultPolicyID}
   973  		spec.Proxy.ListenPath = "/base"
   974  		spec.JWTScopeToPolicyMapping = map[string]string{
   975  			"user:read":  p1ID,
   976  			"user:write": p2ID,
   977  		}
   978  		spec.OrgID = "default"
   979  	})[0]
   980  
   981  	spec1 := CloneAPI(base)
   982  	spec1.APIID = "api1"
   983  	spec1.Proxy.ListenPath = "/api1"
   984  
   985  	spec2 := CloneAPI(base)
   986  	spec2.APIID = "api2"
   987  	spec2.Proxy.ListenPath = "/api2"
   988  
   989  	spec3 := CloneAPI(base)
   990  	spec3.APIID = "api3"
   991  	spec3.Proxy.ListenPath = "/api3"
   992  
   993  	LoadAPI(base, spec1, spec2, spec3)
   994  
   995  	userID := "user-" + uuid.New()
   996  	user2ID := "user-" + uuid.New()
   997  	user3ID := "user-" + uuid.New()
   998  
   999  	jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1000  		t.Claims.(jwt.MapClaims)["user_id"] = userID
  1001  		t.Claims.(jwt.MapClaims)["policy_id"] = basePolicyID
  1002  		t.Claims.(jwt.MapClaims)["scope"] = "user:read user:write"
  1003  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1004  	})
  1005  
  1006  	jwtTokenWithoutBasePol := CreateJWKToken(func(t *jwt.Token) {
  1007  		t.Claims.(jwt.MapClaims)["user_id"] = user2ID
  1008  		t.Claims.(jwt.MapClaims)["scope"] = "user:read user:write"
  1009  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1010  	})
  1011  
  1012  	jwtTokenWithoutBasePolAndScopes := CreateJWKToken(func(t *jwt.Token) {
  1013  		t.Claims.(jwt.MapClaims)["user_id"] = user3ID
  1014  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1015  	})
  1016  
  1017  	authHeaders := map[string]string{"authorization": jwtToken}
  1018  	t.Run("Create JWT session with base and scopes", func(t *testing.T) {
  1019  		ts.Run(t,
  1020  			test.TestCase{
  1021  				Headers: authHeaders,
  1022  				Path:    "/base",
  1023  				Code:    http.StatusOK,
  1024  			})
  1025  	})
  1026  
  1027  	authHeaders = map[string]string{"authorization": jwtTokenWithoutBasePol}
  1028  	t.Run("Create JWT session without base and with scopes", func(t *testing.T) {
  1029  		ts.Run(t,
  1030  			test.TestCase{
  1031  				Headers: authHeaders,
  1032  				Path:    "/api1",
  1033  				Code:    http.StatusOK,
  1034  			})
  1035  	})
  1036  
  1037  	authHeaders = map[string]string{"authorization": jwtTokenWithoutBasePolAndScopes}
  1038  	t.Run("Create JWT session without base and without scopes", func(t *testing.T) {
  1039  		ts.Run(t,
  1040  			test.TestCase{
  1041  				Headers: authHeaders,
  1042  				Path:    "/base",
  1043  				Code:    http.StatusOK,
  1044  			})
  1045  	})
  1046  
  1047  	// check that key has right set of policies assigned - there should be all three - base one and two from scope
  1048  	sessionID := generateToken("default", fmt.Sprintf("%x", md5.Sum([]byte(userID))))
  1049  	t.Run("Request to check that session has got both based and scope policies", func(t *testing.T) {
  1050  		ts.Run(
  1051  			t,
  1052  			test.TestCase{
  1053  				Method:    http.MethodGet,
  1054  				Path:      "/tyk/keys/" + sessionID,
  1055  				AdminAuth: true,
  1056  				Code:      http.StatusOK,
  1057  				BodyMatchFunc: func(data []byte) bool {
  1058  					sessionData := user.SessionState{
  1059  						Mutex: &sync.RWMutex{},
  1060  					}
  1061  					json.Unmarshal(data, &sessionData)
  1062  
  1063  					expect := []string{basePolicyID, p1ID, p2ID}
  1064  					sort.Strings(sessionData.ApplyPolicies)
  1065  					sort.Strings(expect)
  1066  
  1067  					assert.Equal(t, sessionData.ApplyPolicies, expect)
  1068  					return true
  1069  				},
  1070  			},
  1071  		)
  1072  	})
  1073  
  1074  	// check that key has right set of policies assigned - there should be all three - base one and two from scope
  1075  	sessionID = generateToken("default", fmt.Sprintf("%x", md5.Sum([]byte(user2ID))))
  1076  	t.Run("If scopes present no default policy should be used", func(t *testing.T) {
  1077  		ts.Run(
  1078  			t,
  1079  			test.TestCase{
  1080  				Method:    http.MethodGet,
  1081  				Path:      "/tyk/keys/" + sessionID,
  1082  				AdminAuth: true,
  1083  				Code:      http.StatusOK,
  1084  				BodyMatchFunc: func(data []byte) bool {
  1085  					sessionData := user.SessionState{
  1086  						Mutex: &sync.RWMutex{},
  1087  					}
  1088  					json.Unmarshal(data, &sessionData)
  1089  
  1090  					assert.Equal(t, sessionData.ApplyPolicies, []string{p1ID, p2ID})
  1091  
  1092  					return true
  1093  				},
  1094  			},
  1095  		)
  1096  	})
  1097  
  1098  	// check that key has right set of policies assigned - there should be all three - base one and two from scope
  1099  	sessionID = generateToken("default", fmt.Sprintf("%x", md5.Sum([]byte(user3ID))))
  1100  	t.Run("Default policy should be applied if no scopes found", func(t *testing.T) {
  1101  		ts.Run(
  1102  			t,
  1103  			test.TestCase{
  1104  				Method:    http.MethodGet,
  1105  				Path:      "/tyk/keys/" + sessionID,
  1106  				AdminAuth: true,
  1107  				Code:      http.StatusOK,
  1108  				BodyMatchFunc: func(data []byte) bool {
  1109  					sessionData := user.SessionState{
  1110  						Mutex: &sync.RWMutex{},
  1111  					}
  1112  					json.Unmarshal(data, &sessionData)
  1113  
  1114  					assert.Equal(t, sessionData.ApplyPolicies, []string{defaultPolicyID})
  1115  
  1116  					return true
  1117  				},
  1118  			},
  1119  		)
  1120  	})
  1121  
  1122  	authHeaders = map[string]string{"authorization": jwtToken}
  1123  	sessionID = generateToken("default", fmt.Sprintf("%x", md5.Sum([]byte(userID))))
  1124  	// try to access api1 using JWT issued via base-api
  1125  	t.Run("Request to api1", func(t *testing.T) {
  1126  		ts.Run(
  1127  			t,
  1128  			test.TestCase{
  1129  				Headers: authHeaders,
  1130  				Method:  http.MethodGet,
  1131  				Path:    "/api1",
  1132  				Code:    http.StatusOK,
  1133  			},
  1134  		)
  1135  	})
  1136  
  1137  	// try to access api2 using JWT issued via base-api
  1138  	t.Run("Request to api2", func(t *testing.T) {
  1139  		ts.Run(
  1140  			t,
  1141  			test.TestCase{
  1142  				Headers: authHeaders,
  1143  				Method:  http.MethodGet,
  1144  				Path:    "/api2",
  1145  				Code:    http.StatusOK,
  1146  			},
  1147  		)
  1148  	})
  1149  
  1150  	// try to access api3 (which is not granted via base policy nor scope-policy mapping) using JWT issued via base-api
  1151  	t.Run("Request to api3", func(t *testing.T) {
  1152  		ts.Run(
  1153  			t,
  1154  			test.TestCase{
  1155  				Headers: authHeaders,
  1156  				Method:  http.MethodGet,
  1157  				Path:    "/api3",
  1158  				Code:    http.StatusForbidden,
  1159  			},
  1160  		)
  1161  	})
  1162  
  1163  	// try to change scope to policy mapping and request using existing session
  1164  	p3ID := CreatePolicy(func(p *user.Policy) {
  1165  		p.ID = "p3"
  1166  		p.AccessRights = map[string]user.AccessDefinition{
  1167  			spec3.APIID: {
  1168  				Limit: &user.APILimit{
  1169  					Rate:     500,
  1170  					Per:      30,
  1171  					QuotaMax: -1,
  1172  				},
  1173  			},
  1174  		}
  1175  		p.Partitions = user.PolicyPartitions{
  1176  			PerAPI: true,
  1177  		}
  1178  	})
  1179  
  1180  	base.JWTScopeToPolicyMapping = map[string]string{
  1181  		"user:read": p3ID,
  1182  	}
  1183  
  1184  	LoadAPI(base)
  1185  
  1186  	t.Run("Request with changed scope in JWT and key with existing session", func(t *testing.T) {
  1187  		ts.Run(t,
  1188  			test.TestCase{
  1189  				Headers: authHeaders,
  1190  				Path:    "/base",
  1191  				Code:    http.StatusOK,
  1192  			})
  1193  	})
  1194  
  1195  	// check that key has right set of policies assigned - there should be updated list (base one and one from scope)
  1196  	t.Run("Request to check that session has got changed apply_policies value", func(t *testing.T) {
  1197  		ts.Run(
  1198  			t,
  1199  			test.TestCase{
  1200  				Method:    http.MethodGet,
  1201  				Path:      "/tyk/keys/" + sessionID,
  1202  				AdminAuth: true,
  1203  				Code:      http.StatusOK,
  1204  				BodyMatchFunc: func(data []byte) bool {
  1205  					sessionData := user.SessionState{
  1206  						Mutex: &sync.RWMutex{},
  1207  					}
  1208  					json.Unmarshal(data, &sessionData)
  1209  
  1210  					assert.Equal(t, sessionData.ApplyPolicies, []string{basePolicyID, p3ID})
  1211  
  1212  					return true
  1213  				},
  1214  			},
  1215  		)
  1216  	})
  1217  
  1218  }
  1219  
  1220  func TestJWTExistingSessionRSAWithRawSourcePolicyIDChanged(t *testing.T) {
  1221  	ts := StartTest()
  1222  	defer ts.Close()
  1223  
  1224  	spec := BuildAPI(func(spec *APISpec) {
  1225  		spec.UseKeylessAccess = false
  1226  		spec.EnableJWT = true
  1227  		spec.JWTSigningMethod = RSASign
  1228  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
  1229  		spec.JWTIdentityBaseField = "user_id"
  1230  		spec.JWTPolicyFieldName = "policy_id"
  1231  		spec.Proxy.ListenPath = "/"
  1232  		spec.OrgID = "default"
  1233  	})[0]
  1234  
  1235  	LoadAPI(spec)
  1236  
  1237  	p1ID := CreatePolicy(func(p *user.Policy) {
  1238  		p.QuotaMax = 111
  1239  	})
  1240  	p2ID := CreatePolicy(func(p *user.Policy) {
  1241  		p.QuotaMax = 999
  1242  	})
  1243  	user_id := uuid.New()
  1244  
  1245  	jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1246  		t.Header["kid"] = "12345"
  1247  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1248  		t.Claims.(jwt.MapClaims)["user_id"] = user_id
  1249  		t.Claims.(jwt.MapClaims)["policy_id"] = p1ID
  1250  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1251  	})
  1252  
  1253  	sessionID := generateToken("default", fmt.Sprintf("%x", md5.Sum([]byte(user_id))))
  1254  
  1255  	authHeaders := map[string]string{"authorization": jwtToken}
  1256  	t.Run("Initial request with 1st policy", func(t *testing.T) {
  1257  		ts.Run(
  1258  			t,
  1259  			test.TestCase{
  1260  				Headers: authHeaders, Code: http.StatusOK,
  1261  			},
  1262  			test.TestCase{
  1263  				Method:    http.MethodGet,
  1264  				Path:      "/tyk/keys/" + sessionID,
  1265  				AdminAuth: true,
  1266  				Code:      http.StatusOK,
  1267  				BodyMatch: `"quota_max":111`,
  1268  			},
  1269  		)
  1270  	})
  1271  
  1272  	// check key/session quota
  1273  
  1274  	// put in JWT another valid policy ID and do request again
  1275  	jwtTokenAnotherPolicy := CreateJWKToken(func(t *jwt.Token) {
  1276  		t.Header["kid"] = "12345"
  1277  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1278  		t.Claims.(jwt.MapClaims)["user_id"] = user_id
  1279  		t.Claims.(jwt.MapClaims)["policy_id"] = p2ID
  1280  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1281  	})
  1282  
  1283  	authHeaders = map[string]string{"authorization": jwtTokenAnotherPolicy}
  1284  	t.Run("Request with new valid policy in JWT", func(t *testing.T) {
  1285  		ts.Run(t,
  1286  			test.TestCase{
  1287  				Headers: authHeaders, Code: http.StatusOK,
  1288  			},
  1289  			test.TestCase{
  1290  				Method:    http.MethodGet,
  1291  				Path:      "/tyk/keys/" + sessionID,
  1292  				AdminAuth: true,
  1293  				Code:      http.StatusOK,
  1294  				BodyMatch: `"quota_max":999`,
  1295  			},
  1296  		)
  1297  	})
  1298  }
  1299  
  1300  // JWTSessionRSAWithJWK
  1301  
  1302  func prepareJWTSessionRSAWithJWK() string {
  1303  	BuildAndLoadAPI(func(spec *APISpec) {
  1304  		spec.UseKeylessAccess = false
  1305  		spec.EnableJWT = true
  1306  		spec.JWTSigningMethod = RSASign
  1307  		spec.JWTSource = testHttpJWK
  1308  		spec.JWTIdentityBaseField = "user_id"
  1309  		spec.JWTPolicyFieldName = "policy_id"
  1310  		spec.Proxy.ListenPath = "/"
  1311  	})
  1312  
  1313  	pID := CreatePolicy()
  1314  	jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1315  		t.Header["kid"] = "12345"
  1316  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1317  		t.Claims.(jwt.MapClaims)["user_id"] = "user"
  1318  		t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1319  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1320  	})
  1321  
  1322  	return jwtToken
  1323  }
  1324  
  1325  func TestJWTSessionRSAWithJWK(t *testing.T) {
  1326  	ts := StartTest()
  1327  	defer ts.Close()
  1328  
  1329  	jwtToken := prepareJWTSessionRSAWithJWK()
  1330  	authHeaders := map[string]string{"authorization": jwtToken}
  1331  
  1332  	t.Run("JWTSessionRSAWithJWK", func(t *testing.T) {
  1333  		ts.Run(t, test.TestCase{
  1334  			Headers: authHeaders, Code: http.StatusOK,
  1335  		})
  1336  	})
  1337  }
  1338  
  1339  func BenchmarkJWTSessionRSAWithJWK(b *testing.B) {
  1340  	b.ReportAllocs()
  1341  
  1342  	ts := StartTest()
  1343  	defer ts.Close()
  1344  
  1345  	jwtToken := prepareJWTSessionRSAWithJWK()
  1346  	authHeaders := map[string]string{"authorization": jwtToken}
  1347  
  1348  	for i := 0; i < b.N; i++ {
  1349  		ts.Run(
  1350  			b,
  1351  			test.TestCase{
  1352  				Headers: authHeaders,
  1353  				Code:    http.StatusOK,
  1354  			},
  1355  		)
  1356  	}
  1357  }
  1358  
  1359  // JWTSessionRSAWithEncodedJWK
  1360  
  1361  func prepareJWTSessionRSAWithEncodedJWK() (*APISpec, string) {
  1362  	spec := BuildAPI(func(spec *APISpec) {
  1363  		spec.UseKeylessAccess = false
  1364  		spec.EnableJWT = true
  1365  		spec.JWTSigningMethod = RSASign
  1366  		spec.JWTIdentityBaseField = "user_id"
  1367  		spec.JWTPolicyFieldName = "policy_id"
  1368  		spec.Proxy.ListenPath = "/"
  1369  	})[0]
  1370  
  1371  	pID := CreatePolicy()
  1372  	jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1373  		t.Header["kid"] = "12345"
  1374  		// Set some claims
  1375  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1376  		t.Claims.(jwt.MapClaims)["user_id"] = "user"
  1377  		t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1378  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1379  	})
  1380  
  1381  	return spec, jwtToken
  1382  }
  1383  
  1384  func TestJWTSessionRSAWithEncodedJWK(t *testing.T) {
  1385  	ts := StartTest()
  1386  	defer ts.Close()
  1387  
  1388  	spec, jwtToken := prepareJWTSessionRSAWithEncodedJWK()
  1389  
  1390  	authHeaders := map[string]string{"authorization": jwtToken}
  1391  
  1392  	t.Run("Direct JWK URL", func(t *testing.T) {
  1393  		spec.JWTSource = testHttpJWK
  1394  		LoadAPI(spec)
  1395  
  1396  		ts.Run(t, test.TestCase{
  1397  			Headers: authHeaders, Code: http.StatusOK,
  1398  		})
  1399  	})
  1400  
  1401  	t.Run("Base64 JWK URL", func(t *testing.T) {
  1402  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(testHttpJWK))
  1403  		LoadAPI(spec)
  1404  
  1405  		ts.Run(t, test.TestCase{
  1406  			Headers: authHeaders, Code: http.StatusOK,
  1407  		})
  1408  	})
  1409  	t.Run("Direct JWK URL with der encoding", func(t *testing.T) {
  1410  		spec.JWTSource = testHttpJWKDER
  1411  		LoadAPI(spec)
  1412  
  1413  		ts.Run(t, test.TestCase{
  1414  			Headers: authHeaders, Code: http.StatusOK,
  1415  		})
  1416  	})
  1417  
  1418  	t.Run("Base64 JWK URL with der encoding", func(t *testing.T) {
  1419  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(testHttpJWKDER))
  1420  		LoadAPI(spec)
  1421  
  1422  		ts.Run(t, test.TestCase{
  1423  			Headers: authHeaders, Code: http.StatusOK,
  1424  		})
  1425  	})
  1426  }
  1427  
  1428  func TestParseRSAKeyFromJWK(t *testing.T) {
  1429  	sample := `MIIC9jCCAd6gAwIBAgIJIgAUUdWegHDtMA0GCSqGSIb3DQEBCwUAMCIxIDAeBgNVBAMTF3B1cGlsLXRlc3QuZXUuYXV0aDAuY29tMB4XDTE3MDMxMDE1MTUyMFoXDTMwMTExNzE1MTUyMFowIjEgMB4GA1UEAxMXcHVwaWwtdGVzdC5ldS5hdXRoMC5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDWW+2PEt6nWK7cTxpkiXYTOsAWi+CCGZzDZNtwqIiLDTIkBb+Hrb70hSMRNXjPckw9+FxYC/egluGEmcEidZbj260Qp63xYpvC8XNXrlvovJqvPLk8ETPolVqYNaWM1UoJsqBPIlmFlwVH+ExCjUL37Kay3gwRXTHVRiPfPCZanqWqMu8CbC+pby1sUaiTIW1bE15v5pdgTZUH94uuMfYTdnWY6DSPWKrgwQUxmn3TJN66DynPgRjMaZaCr6FiDItm1gqE74rkbRcE3nZGM3F+fxUNTsSKjvLBBBV9aDCO408zfCycR7J+HSO2bqBxnewYhweOx23U46A0WNKW5raxAgMBAAGjLzAtMAwGA1UdEwQFMAMBAf8wHQYDVR0OBBYEFCR9T3F1LtZa3AX+LjXX9av8m/2kMA0GCSqGSIb3DQEBCwUAA4IBAQBxot91iXDzJfQVaGV+KoCDuJmOrSLTolKbJOxVoilyY72LnIcQOLgHI5JN7X17GnESTsvMC7OiUcC0RYimfrc9pchWairU/Uky6t4XmOLHQsIKjXkqwkNn3vOkRZB9wsveFQpHVLBpBUZLcPYr+8ZQYegueJpW6zSOEkswOM1U+CzERZaY6dkD8nI8TzozQ6ZLV3iypW/gx/lLT8cQb0EMzLNKSOobT+NEnhhtpy1BnfpAwV8rGENYtyUpq2FTa3kQjBCrR5cBt/07yezyeX8Amcdst3PnLaZMn5k+Elj57FKKDRV+L9rYGeceLbKKJ0uSKuhR9LIVrFaa/pzUKekC`
  1430  	b, err := base64.StdEncoding.DecodeString(sample)
  1431  	if err != nil {
  1432  		t.Fatal(err)
  1433  	}
  1434  	_, err = jwt.ParseRSAPublicKeyFromPEM(b)
  1435  	if err == nil {
  1436  		t.Error("expected an error")
  1437  	}
  1438  	_, err = ParseRSAPublicKey(b)
  1439  	if err != nil {
  1440  		t.Error("decoding as default ", err)
  1441  	}
  1442  }
  1443  
  1444  func BenchmarkJWTSessionRSAWithEncodedJWK(b *testing.B) {
  1445  	b.ReportAllocs()
  1446  
  1447  	ts := StartTest()
  1448  	defer ts.Close()
  1449  
  1450  	spec, jwtToken := prepareJWTSessionRSAWithEncodedJWK()
  1451  	spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(testHttpJWK))
  1452  
  1453  	LoadAPI(spec)
  1454  
  1455  	authHeaders := map[string]string{"authorization": jwtToken}
  1456  
  1457  	for i := 0; i < b.N; i++ {
  1458  		ts.Run(
  1459  			b,
  1460  			test.TestCase{
  1461  				Headers: authHeaders,
  1462  				Code:    http.StatusOK,
  1463  			},
  1464  		)
  1465  	}
  1466  }
  1467  
  1468  func TestJWTHMACIdNewClaim(t *testing.T) {
  1469  	ts := StartTest()
  1470  	defer ts.Close()
  1471  
  1472  	//If we skip the check then the Id will be taken from SUB and the call will succeed
  1473  	_, jwtToken := prepareGenericJWTSession(t.Name(), HMACSign, "user-id", true)
  1474  	defer ResetTestConfig()
  1475  	authHeaders := map[string]string{"authorization": jwtToken}
  1476  	t.Run("Request with valid JWT/HMAC signature/id in user-id claim", func(t *testing.T) {
  1477  		ts.Run(t, test.TestCase{
  1478  			Headers: authHeaders, Code: http.StatusOK,
  1479  		})
  1480  	})
  1481  }
  1482  
  1483  func TestJWTRSAIdInClaimsWithBaseField(t *testing.T) {
  1484  	ts := StartTest()
  1485  	defer ts.Close()
  1486  
  1487  	BuildAndLoadAPI(func(spec *APISpec) {
  1488  		spec.UseKeylessAccess = false
  1489  		spec.EnableJWT = true
  1490  		spec.JWTSigningMethod = RSASign
  1491  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
  1492  		spec.JWTIdentityBaseField = "user_id"
  1493  		spec.JWTPolicyFieldName = "policy_id"
  1494  		spec.Proxy.ListenPath = "/"
  1495  	})
  1496  
  1497  	pID := CreatePolicy()
  1498  
  1499  	//First test - user id in the configured base field 'user_id'
  1500  	jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1501  		t.Header["kid"] = "12345"
  1502  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1503  		t.Claims.(jwt.MapClaims)["user_id"] = "user123@test.com"
  1504  		t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1505  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1506  	})
  1507  	authHeaders := map[string]string{"authorization": jwtToken}
  1508  	t.Run("Request with valid JWT/RSA signature/user id in user_id claim", func(t *testing.T) {
  1509  		ts.Run(t, test.TestCase{
  1510  			Headers: authHeaders, Code: http.StatusOK,
  1511  		})
  1512  	})
  1513  
  1514  	//user-id claim configured but it's empty - returning an error
  1515  	jwtToken = CreateJWKToken(func(t *jwt.Token) {
  1516  		t.Header["kid"] = "12345"
  1517  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1518  		t.Claims.(jwt.MapClaims)["user_id"] = ""
  1519  		t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1520  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1521  	})
  1522  	authHeaders = map[string]string{"authorization": jwtToken}
  1523  	t.Run("Request with valid JWT/RSA signature/empty user_id claim", func(t *testing.T) {
  1524  		ts.Run(t, test.TestCase{
  1525  			Headers:   authHeaders,
  1526  			Code:      http.StatusForbidden,
  1527  			BodyMatch: "found an empty user ID in predefined base field claim user_id",
  1528  		})
  1529  	})
  1530  
  1531  	//user-id claim configured but not found fallback to sub
  1532  	jwtToken = CreateJWKToken(func(t *jwt.Token) {
  1533  		t.Header["kid"] = "12345"
  1534  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1535  		t.Claims.(jwt.MapClaims)["sub"] = "user123@test.com"
  1536  		t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1537  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1538  	})
  1539  	authHeaders = map[string]string{"authorization": jwtToken}
  1540  	t.Run("Request with valid JWT/RSA signature/user id in sub claim", func(t *testing.T) {
  1541  		ts.Run(t, test.TestCase{
  1542  			Headers: authHeaders, Code: http.StatusOK,
  1543  		})
  1544  	})
  1545  
  1546  	//user-id claim not found fallback to sub that is empty
  1547  	jwtToken = CreateJWKToken(func(t *jwt.Token) {
  1548  		t.Header["kid"] = "12345"
  1549  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1550  		t.Claims.(jwt.MapClaims)["sub"] = ""
  1551  		t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1552  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1553  	})
  1554  	authHeaders = map[string]string{"authorization": jwtToken}
  1555  	t.Run("Request with valid JWT/RSA signature/empty sub claim", func(t *testing.T) {
  1556  		ts.Run(t, test.TestCase{
  1557  			Headers:   authHeaders,
  1558  			Code:      http.StatusForbidden,
  1559  			BodyMatch: "found an empty user ID in sub claim",
  1560  		})
  1561  	})
  1562  
  1563  	//user-id and sub claims not found
  1564  	jwtToken = CreateJWKToken(func(t *jwt.Token) {
  1565  		t.Header["kid"] = "12345"
  1566  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1567  		t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1568  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1569  	})
  1570  	authHeaders = map[string]string{"authorization": jwtToken}
  1571  	t.Run("Request with valid JWT/RSA signature/no base field or sub claims", func(t *testing.T) {
  1572  		ts.Run(t, test.TestCase{
  1573  			Headers:   authHeaders,
  1574  			Code:      http.StatusForbidden,
  1575  			BodyMatch: "no suitable claims for user ID were found",
  1576  		})
  1577  	})
  1578  }
  1579  
  1580  func TestJWTRSAIdInClaimsWithoutBaseField(t *testing.T) {
  1581  	ts := StartTest()
  1582  	defer ts.Close()
  1583  
  1584  	BuildAndLoadAPI(func(spec *APISpec) {
  1585  		spec.UseKeylessAccess = false
  1586  		spec.EnableJWT = true
  1587  		spec.JWTSigningMethod = RSASign
  1588  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
  1589  		spec.JWTIdentityBaseField = ""
  1590  		spec.JWTPolicyFieldName = "policy_id"
  1591  		spec.Proxy.ListenPath = "/"
  1592  	})
  1593  
  1594  	pID := CreatePolicy()
  1595  
  1596  	jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1597  		t.Header["kid"] = "12345"
  1598  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1599  		t.Claims.(jwt.MapClaims)["sub"] = "user123@test.com" //is ignored
  1600  		t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1601  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1602  	})
  1603  	authHeaders := map[string]string{"authorization": jwtToken}
  1604  	t.Run("Request with valid JWT/RSA signature/id found in default sub", func(t *testing.T) {
  1605  		ts.Run(t, test.TestCase{
  1606  			Headers: authHeaders, Code: http.StatusOK,
  1607  		})
  1608  	})
  1609  
  1610  	//Id is not found since there's no sub claim and user_id has't been set in the api def (spec.JWTIdentityBaseField)
  1611  	jwtToken = CreateJWKToken(func(t *jwt.Token) {
  1612  		t.Header["kid"] = "12345"
  1613  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1614  		t.Claims.(jwt.MapClaims)["user_id"] = "user123@test.com" //is ignored
  1615  		t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1616  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1617  	})
  1618  	authHeaders = map[string]string{"authorization": jwtToken}
  1619  	t.Run("Request with valid JWT/RSA signature/no id claims", func(t *testing.T) {
  1620  		ts.Run(t, test.TestCase{
  1621  			Headers:   authHeaders,
  1622  			Code:      http.StatusForbidden,
  1623  			BodyMatch: "no suitable claims for user ID were found",
  1624  		})
  1625  	})
  1626  }
  1627  
  1628  func TestJWTDefaultPolicies(t *testing.T) {
  1629  	const apiID = "testapid"
  1630  	const identitySource = "user_id"
  1631  	const policyFieldName = "policy_id"
  1632  
  1633  	ts := StartTest()
  1634  	defer ts.Close()
  1635  
  1636  	defPol1 := CreatePolicy(func(p *user.Policy) {
  1637  		p.AccessRights = map[string]user.AccessDefinition{
  1638  			apiID: {},
  1639  		}
  1640  		p.Partitions = user.PolicyPartitions{
  1641  			Quota: true,
  1642  		}
  1643  	})
  1644  
  1645  	defPol2 := CreatePolicy(func(p *user.Policy) {
  1646  		p.AccessRights = map[string]user.AccessDefinition{
  1647  			apiID: {},
  1648  		}
  1649  		p.Partitions = user.PolicyPartitions{
  1650  			RateLimit: true,
  1651  		}
  1652  	})
  1653  
  1654  	tokenPol := CreatePolicy(func(p *user.Policy) {
  1655  		p.AccessRights = map[string]user.AccessDefinition{
  1656  			apiID: {},
  1657  		}
  1658  		p.Partitions = user.PolicyPartitions{
  1659  			Acl: true,
  1660  		}
  1661  	})
  1662  
  1663  	spec := BuildAPI(func(spec *APISpec) {
  1664  		spec.APIID = apiID
  1665  		spec.UseKeylessAccess = false
  1666  		spec.EnableJWT = true
  1667  		spec.JWTSigningMethod = RSASign
  1668  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
  1669  		spec.JWTIdentityBaseField = identitySource
  1670  		spec.JWTDefaultPolicies = []string{
  1671  			defPol1,
  1672  			defPol2,
  1673  		}
  1674  		spec.Proxy.ListenPath = "/"
  1675  	})[0]
  1676  
  1677  	data := []byte("dummy")
  1678  	keyID := fmt.Sprintf("%x", md5.Sum(data))
  1679  	sessionID := generateToken(spec.OrgID, keyID)
  1680  
  1681  	assert := func(t *testing.T, expected []string) {
  1682  		session, _ := FallbackKeySesionManager.SessionDetail(sessionID, false)
  1683  		actual := session.GetPolicyIDs()
  1684  		if !reflect.DeepEqual(expected, actual) {
  1685  			t.Fatalf("Expected %v, actaul %v", expected, actual)
  1686  		}
  1687  	}
  1688  
  1689  	t.Run("Policy field name empty", func(t *testing.T) {
  1690  		jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1691  			t.Claims.(jwt.MapClaims)[identitySource] = "dummy"
  1692  			t.Claims.(jwt.MapClaims)[policyFieldName] = tokenPol
  1693  		})
  1694  
  1695  		authHeaders := map[string]string{"authorization": jwtToken}
  1696  
  1697  		// Default
  1698  		LoadAPI(spec)
  1699  		_, _ = ts.Run(t, test.TestCase{
  1700  			Headers: authHeaders, Code: http.StatusOK,
  1701  		})
  1702  		assert(t, []string{defPol1, defPol2})
  1703  
  1704  		// Same to check stored correctly
  1705  		_, _ = ts.Run(t, test.TestCase{
  1706  			Headers: authHeaders, Code: http.StatusOK,
  1707  		})
  1708  		assert(t, []string{defPol1, defPol2})
  1709  
  1710  		// Remove one of default policies
  1711  		spec.JWTDefaultPolicies = []string{defPol1}
  1712  		LoadAPI(spec)
  1713  		_, _ = ts.Run(t, test.TestCase{
  1714  			Headers: authHeaders, Code: http.StatusOK,
  1715  		})
  1716  		assert(t, []string{defPol1})
  1717  
  1718  		// Add a default policy
  1719  		spec.JWTDefaultPolicies = []string{defPol1, defPol2}
  1720  		LoadAPI(spec)
  1721  		_, _ = ts.Run(t, test.TestCase{
  1722  			Headers: authHeaders, Code: http.StatusOK,
  1723  		})
  1724  		assert(t, []string{defPol1, defPol2})
  1725  	})
  1726  
  1727  	t.Run("Policy field name nonempty but empty claim", func(t *testing.T) {
  1728  		jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1729  			t.Claims.(jwt.MapClaims)[identitySource] = "dummy"
  1730  			t.Claims.(jwt.MapClaims)[policyFieldName] = ""
  1731  		})
  1732  
  1733  		authHeaders := map[string]string{"authorization": jwtToken}
  1734  
  1735  		// Default
  1736  		LoadAPI(spec)
  1737  		_, _ = ts.Run(t, test.TestCase{
  1738  			Headers: authHeaders, Code: http.StatusOK,
  1739  		})
  1740  		assert(t, []string{defPol1, defPol2})
  1741  
  1742  		// Same to check stored correctly
  1743  		_, _ = ts.Run(t, test.TestCase{
  1744  			Headers: authHeaders, Code: http.StatusOK,
  1745  		})
  1746  		assert(t, []string{defPol1, defPol2})
  1747  
  1748  		// Remove one of default policies
  1749  		spec.JWTDefaultPolicies = []string{defPol1}
  1750  		LoadAPI(spec)
  1751  		_, _ = ts.Run(t, test.TestCase{
  1752  			Headers: authHeaders, Code: http.StatusOK,
  1753  		})
  1754  		assert(t, []string{defPol1})
  1755  
  1756  		// Add a default policy
  1757  		spec.JWTDefaultPolicies = []string{defPol1, defPol2}
  1758  		LoadAPI(spec)
  1759  		_, _ = ts.Run(t, test.TestCase{
  1760  			Headers: authHeaders, Code: http.StatusOK,
  1761  		})
  1762  		assert(t, []string{defPol1, defPol2})
  1763  	})
  1764  
  1765  	t.Run("Policy field name nonempty invalid policy ID in claim", func(t *testing.T) {
  1766  		spec.JWTPolicyFieldName = policyFieldName
  1767  		LoadAPI(spec)
  1768  
  1769  		jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1770  			t.Claims.(jwt.MapClaims)[identitySource] = "dummy"
  1771  			t.Claims.(jwt.MapClaims)[policyFieldName] = "invalid"
  1772  		})
  1773  
  1774  		authHeaders := map[string]string{"authorization": jwtToken}
  1775  
  1776  		_, _ = ts.Run(t, []test.TestCase{
  1777  			{Headers: authHeaders, Code: http.StatusForbidden},
  1778  			{Headers: authHeaders, Code: http.StatusForbidden},
  1779  		}...)
  1780  
  1781  		// Reset
  1782  		spec.JWTPolicyFieldName = ""
  1783  	})
  1784  
  1785  	t.Run("Default to Claim transition", func(t *testing.T) {
  1786  		jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1787  			t.Claims.(jwt.MapClaims)[identitySource] = "dummy"
  1788  			t.Claims.(jwt.MapClaims)[policyFieldName] = tokenPol
  1789  		})
  1790  
  1791  		authHeaders := map[string]string{"authorization": jwtToken}
  1792  
  1793  		// Default
  1794  		LoadAPI(spec)
  1795  		_, _ = ts.Run(t, test.TestCase{
  1796  			Headers: authHeaders, Code: http.StatusOK,
  1797  		})
  1798  		assert(t, []string{defPol1, defPol2})
  1799  
  1800  		// Same to check stored correctly
  1801  		LoadAPI(spec)
  1802  		_, _ = ts.Run(t, test.TestCase{
  1803  			Headers: authHeaders, Code: http.StatusOK,
  1804  		})
  1805  		assert(t, []string{defPol1, defPol2})
  1806  
  1807  		// Claim
  1808  		spec.JWTPolicyFieldName = policyFieldName
  1809  		LoadAPI(spec)
  1810  		_, _ = ts.Run(t, test.TestCase{
  1811  			Headers: authHeaders, Code: http.StatusOK,
  1812  		})
  1813  		assert(t, []string{tokenPol})
  1814  	})
  1815  
  1816  	t.Run("Claim to Default transition", func(t *testing.T) {
  1817  		jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1818  			t.Claims.(jwt.MapClaims)[identitySource] = "dummy"
  1819  			t.Claims.(jwt.MapClaims)[policyFieldName] = tokenPol
  1820  		})
  1821  
  1822  		authHeaders := map[string]string{"authorization": jwtToken}
  1823  
  1824  		// Claim
  1825  		spec.JWTPolicyFieldName = policyFieldName
  1826  		LoadAPI(spec)
  1827  		_, _ = ts.Run(t, test.TestCase{
  1828  			Headers: authHeaders, Code: http.StatusOK,
  1829  		})
  1830  		assert(t, []string{tokenPol})
  1831  
  1832  		// Same to check stored correctly
  1833  		_, _ = ts.Run(t, test.TestCase{
  1834  			Headers: authHeaders, Code: http.StatusOK,
  1835  		})
  1836  		assert(t, []string{tokenPol})
  1837  
  1838  		// Default
  1839  		spec.JWTPolicyFieldName = ""
  1840  		LoadAPI(spec)
  1841  		_, _ = ts.Run(t, test.TestCase{
  1842  			Headers: authHeaders, Code: http.StatusOK,
  1843  		})
  1844  		assert(t, []string{defPol1, defPol2})
  1845  	})
  1846  }
  1847  
  1848  func TestJWTECDSASign(t *testing.T) {
  1849  	ts := StartTest()
  1850  	defer ts.Close()
  1851  
  1852  	//If we skip the check then the Id will be taken from SUB and the call will succeed
  1853  	_, jwtToken := prepareGenericJWTSession(t.Name(), ECDSASign, KID, false)
  1854  	defer ResetTestConfig()
  1855  	authHeaders := map[string]string{"authorization": jwtToken}
  1856  	t.Run("Request with valid JWT/ECDSA", func(t *testing.T) {
  1857  		ts.Run(t, test.TestCase{
  1858  			Headers: authHeaders, Code: http.StatusOK,
  1859  		})
  1860  	})
  1861  }
  1862  
  1863  func TestJWTUnknownSign(t *testing.T) {
  1864  	ts := StartTest()
  1865  	defer ts.Close()
  1866  
  1867  	//If we skip the check then the Id will be taken from SUB and the call will succeed
  1868  	_, jwtToken := prepareGenericJWTSession(t.Name(), "bla", KID, false)
  1869  	defer ResetTestConfig()
  1870  	authHeaders := map[string]string{"authorization": jwtToken}
  1871  	t.Run("Request with valid JWT/ECDSA signature needs a test. currently defaults to HMAC", func(t *testing.T) {
  1872  		ts.Run(t, test.TestCase{
  1873  			Headers: authHeaders, Code: http.StatusOK,
  1874  		})
  1875  	})
  1876  }
  1877  
  1878  func TestJWTRSAInvalidPublickKey(t *testing.T) {
  1879  	ts := StartTest()
  1880  	defer ts.Close()
  1881  
  1882  	BuildAndLoadAPI(func(spec *APISpec) {
  1883  		spec.UseKeylessAccess = false
  1884  		spec.EnableJWT = true
  1885  		spec.JWTSigningMethod = RSASign
  1886  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKeyinvalid))
  1887  		spec.JWTPolicyFieldName = "policy_id"
  1888  		spec.Proxy.ListenPath = "/"
  1889  	})
  1890  
  1891  	pID := CreatePolicy()
  1892  
  1893  	jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1894  		t.Header["kid"] = "12345"
  1895  		t.Claims.(jwt.MapClaims)["foo"] = "bar"
  1896  		t.Claims.(jwt.MapClaims)["sub"] = "user123@test.com" //is ignored
  1897  		t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1898  		t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Hour * 72).Unix()
  1899  	})
  1900  	authHeaders := map[string]string{"authorization": jwtToken}
  1901  	t.Run("Request with valid JWT/RSA signature/invalid public key", func(t *testing.T) {
  1902  		ts.Run(t, test.TestCase{
  1903  			Headers:   authHeaders,
  1904  			Code:      http.StatusForbidden,
  1905  			BodyMatch: "Key not authorized:asn1: structure error:",
  1906  		})
  1907  	})
  1908  }
  1909  
  1910  func createExpiringPolicy(pGen ...func(p *user.Policy)) string {
  1911  	pID := keyGen.GenerateAuthKey("")
  1912  	pol := CreateStandardPolicy()
  1913  	pol.ID = pID
  1914  	pol.KeyExpiresIn = 1
  1915  
  1916  	if len(pGen) > 0 {
  1917  		pGen[0](pol)
  1918  	}
  1919  
  1920  	policiesMu.Lock()
  1921  	policiesByID[pID] = *pol
  1922  	policiesMu.Unlock()
  1923  
  1924  	return pID
  1925  }
  1926  
  1927  func TestJWTExpOverride(t *testing.T) {
  1928  	ts := StartTest()
  1929  	defer ts.Close()
  1930  
  1931  	BuildAndLoadAPI(func(spec *APISpec) {
  1932  		spec.UseKeylessAccess = false
  1933  		spec.EnableJWT = true
  1934  		spec.JWTSigningMethod = RSASign
  1935  		spec.JWTSource = base64.StdEncoding.EncodeToString([]byte(jwtRSAPubKey))
  1936  		spec.JWTPolicyFieldName = "policy_id"
  1937  		spec.Proxy.ListenPath = "/"
  1938  	})
  1939  
  1940  	t.Run("JWT expiration bigger then policy", func(t *testing.T) {
  1941  		//create policy which sets keys to have expiry in one second
  1942  		pID := CreatePolicy(func(p *user.Policy) {
  1943  			p.KeyExpiresIn = 1
  1944  		})
  1945  
  1946  		jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1947  			t.Claims.(jwt.MapClaims)["sub"] = uuid.New()
  1948  			t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1949  			t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Second * 72).Unix()
  1950  		})
  1951  
  1952  		authHeaders := map[string]string{"authorization": jwtToken}
  1953  
  1954  		//JWT expiry overrides internal token which gets expiry from policy so second request will pass
  1955  		ts.Run(t, []test.TestCase{
  1956  			{Headers: authHeaders, Code: http.StatusOK, Delay: 1100 * time.Millisecond},
  1957  			{Headers: authHeaders, Code: http.StatusOK},
  1958  		}...)
  1959  	})
  1960  
  1961  	t.Run("JWT expiration smaller then policy", func(t *testing.T) {
  1962  		pID := CreatePolicy(func(p *user.Policy) {
  1963  			p.KeyExpiresIn = 5
  1964  		})
  1965  
  1966  		jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1967  			t.Claims.(jwt.MapClaims)["sub"] = uuid.New()
  1968  			t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1969  			t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(-time.Second).Unix()
  1970  		})
  1971  
  1972  		authHeaders := map[string]string{"authorization": jwtToken}
  1973  
  1974  		// Should not allow expired JWTs
  1975  		ts.Run(t, []test.TestCase{
  1976  			{Headers: authHeaders, Code: http.StatusUnauthorized},
  1977  		}...)
  1978  	})
  1979  
  1980  	t.Run("JWT expired but renewed, policy without expiration", func(t *testing.T) {
  1981  		pID := CreatePolicy(func(p *user.Policy) {
  1982  			p.KeyExpiresIn = 0
  1983  		})
  1984  
  1985  		userID := uuid.New()
  1986  
  1987  		jwtToken := CreateJWKToken(func(t *jwt.Token) {
  1988  			t.Claims.(jwt.MapClaims)["sub"] = userID
  1989  			t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1990  			t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(time.Second).Unix()
  1991  		})
  1992  
  1993  		newJwtToken := CreateJWKToken(func(t *jwt.Token) {
  1994  			t.Claims.(jwt.MapClaims)["sub"] = userID
  1995  			t.Claims.(jwt.MapClaims)["policy_id"] = pID
  1996  			t.Claims.(jwt.MapClaims)["exp"] = time.Now().Add(5 * time.Second).Unix()
  1997  		})
  1998  
  1999  		authHeaders := map[string]string{"authorization": jwtToken}
  2000  		newAuthHeaders := map[string]string{"authorization": newJwtToken}
  2001  
  2002  		// Should not allow expired JWTs
  2003  		ts.Run(t, []test.TestCase{
  2004  			{Headers: authHeaders, Code: http.StatusOK, Delay: 1100 * time.Millisecond},
  2005  			{Headers: authHeaders, Code: http.StatusUnauthorized},
  2006  			{Headers: newAuthHeaders, Code: http.StatusOK},
  2007  		}...)
  2008  	})
  2009  
  2010  }