github.com/Tyktechnologies/tyk@v2.9.5+incompatible/gateway/mw_request_signing_test.go (about)

     1  package gateway
     2  
     3  import (
     4  	"crypto/x509"
     5  	"encoding/pem"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"net/url"
     9  	"strings"
    10  	"sync"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/stretchr/testify/assert"
    15  
    16  	"github.com/TykTechnologies/tyk/apidef"
    17  	"github.com/TykTechnologies/tyk/test"
    18  	"github.com/TykTechnologies/tyk/user"
    19  	"github.com/justinas/alice"
    20  )
    21  
    22  var algoList = [4]string{"hmac-sha1", "hmac-sha256", "hmac-sha384", "hmac-sha512"}
    23  
    24  func getMiddlewareChain(spec *APISpec) http.Handler {
    25  	remote, _ := url.Parse(TestHttpAny)
    26  	proxy := TykNewSingleHostReverseProxy(remote, spec, nil)
    27  	proxyHandler := ProxyHandler(proxy, spec)
    28  	baseMid := BaseMiddleware{Spec: spec, Proxy: proxy}
    29  	chain := alice.New(mwList(
    30  		&IPWhiteListMiddleware{baseMid},
    31  		&IPBlackListMiddleware{BaseMiddleware: baseMid},
    32  		&RequestSigning{BaseMiddleware: baseMid},
    33  		&HTTPSignatureValidationMiddleware{BaseMiddleware: baseMid},
    34  		&VersionCheck{BaseMiddleware: baseMid},
    35  	)...).Then(proxyHandler)
    36  	return chain
    37  }
    38  
    39  func generateSession(algo, data string) string {
    40  	sessionKey := CreateSession(func(s *user.SessionState) {
    41  		if strings.HasPrefix(algo, "rsa") {
    42  			s.RSACertificateId = data
    43  			s.EnableHTTPSignatureValidation = true
    44  		} else {
    45  			s.HmacSecret = data
    46  			s.HMACEnabled = true
    47  		}
    48  		s.Mutex = &sync.RWMutex{}
    49  	})
    50  
    51  	return sessionKey
    52  }
    53  
    54  func generateSpec(algo string, data string, sessionKey string, headerList []string) (specs []*APISpec) {
    55  	return BuildAndLoadAPI(func(spec *APISpec) {
    56  		spec.Proxy.ListenPath = "/test"
    57  		spec.UseKeylessAccess = true
    58  		spec.EnableSignatureChecking = true
    59  		spec.RequestSigning.IsEnabled = true
    60  		spec.RequestSigning.KeyId = sessionKey
    61  		spec.HmacAllowedClockSkew = 5000
    62  
    63  		if strings.HasPrefix(algo, "rsa") {
    64  			spec.RequestSigning.CertificateId = data
    65  		} else {
    66  			spec.RequestSigning.Secret = data
    67  		}
    68  		spec.RequestSigning.Algorithm = algo
    69  		if headerList != nil {
    70  			spec.RequestSigning.HeaderList = headerList
    71  		}
    72  
    73  	})
    74  }
    75  
    76  func TestHMACRequestSigning(t *testing.T) {
    77  	ts := StartTest()
    78  	defer ts.Close()
    79  	secret := "9879879878787878"
    80  
    81  	for _, algo := range algoList {
    82  		name := "Test with " + algo
    83  		t.Run(name, func(t *testing.T) {
    84  			sessionKey := generateSession(algo, secret)
    85  			specs := generateSpec(algo, secret, sessionKey, nil)
    86  
    87  			req := TestReq(t, "get", "/test/get", nil)
    88  			recorder := httptest.NewRecorder()
    89  			chain := getMiddlewareChain(specs[0])
    90  			chain.ServeHTTP(recorder, req)
    91  
    92  			if recorder.Code != 200 {
    93  				t.Error("HMAC request signing failed with error:", recorder.Body.String())
    94  			}
    95  		})
    96  	}
    97  
    98  	t.Run("Empty secret", func(t *testing.T) {
    99  		algo := "hmac-sha256"
   100  		secret := ""
   101  
   102  		sessionKey := generateSession(algo, secret)
   103  		specs := generateSpec(algo, secret, sessionKey, nil)
   104  
   105  		recorder := httptest.NewRecorder()
   106  		chain := getMiddlewareChain(specs[0])
   107  
   108  		req := TestReq(t, "get", "/test/get", nil)
   109  		chain.ServeHTTP(recorder, req)
   110  
   111  		if recorder.Code != 500 {
   112  			t.Error("Expected status code 500 got ", recorder.Code)
   113  		}
   114  	})
   115  
   116  	t.Run("Invalid secret", func(t *testing.T) {
   117  		algo := "hmac-sha256"
   118  		secret := "12345"
   119  
   120  		sessionKey := generateSession(algo, secret)
   121  		specs := generateSpec(algo, "789", sessionKey, nil)
   122  
   123  		recorder := httptest.NewRecorder()
   124  		chain := getMiddlewareChain(specs[0])
   125  
   126  		req := TestReq(t, "get", "/test/get", nil)
   127  		chain.ServeHTTP(recorder, req)
   128  
   129  		if recorder.Code != 400 {
   130  			t.Error("Expected status code 400 got ", recorder.Code)
   131  		}
   132  	})
   133  
   134  	t.Run("Valid Custom headerList", func(t *testing.T) {
   135  		algo := "hmac-sha1"
   136  		headerList := []string{"foo", "date"}
   137  
   138  		sessionKey := generateSession(algo, secret)
   139  		specs := generateSpec(algo, secret, sessionKey, headerList)
   140  
   141  		recorder := httptest.NewRecorder()
   142  		chain := getMiddlewareChain(specs[0])
   143  
   144  		req := TestReq(t, "get", "/test/get", nil)
   145  		refDate := "Mon, 02 Jan 2006 15:04:05 MST"
   146  		tim := time.Now().Format(refDate)
   147  		req.Header.Add("foo", "bar")
   148  		req.Header.Add("date", tim)
   149  		chain.ServeHTTP(recorder, req)
   150  
   151  		if recorder.Code != 200 {
   152  			t.Error("HMAC request signing failed with error:", recorder.Body.String())
   153  		}
   154  	})
   155  
   156  	t.Run("Invalid Custom headerList", func(t *testing.T) {
   157  		algo := "hmac-sha1"
   158  		headerList := []string{"foo"}
   159  
   160  		sessionKey := generateSession(algo, secret)
   161  		specs := generateSpec(algo, secret, sessionKey, headerList)
   162  
   163  		req := TestReq(t, "get", "/test/get", nil)
   164  		recorder := httptest.NewRecorder()
   165  		chain := getMiddlewareChain(specs[0])
   166  		chain.ServeHTTP(recorder, req)
   167  
   168  		if recorder.Code != 200 {
   169  			t.Error("HMAC request signing failed with error:", recorder.Body.String())
   170  		}
   171  	})
   172  
   173  	t.Run("Invalid algorithm", func(t *testing.T) {
   174  		algo := "hmac-123"
   175  		sessionKey := generateSession(algo, secret)
   176  		specs := generateSpec(algo, secret, sessionKey, nil)
   177  
   178  		req := TestReq(t, "get", "/test/get", nil)
   179  		recorder := httptest.NewRecorder()
   180  		chain := getMiddlewareChain(specs[0])
   181  		chain.ServeHTTP(recorder, req)
   182  
   183  		if recorder.Code != 500 {
   184  			t.Error("Expected status code 500 got ", recorder.Code)
   185  		}
   186  	})
   187  
   188  	t.Run("Invalid Date field", func(t *testing.T) {
   189  		algo := "hmac-sha1"
   190  		sessionKey := generateSession(algo, secret)
   191  		specs := generateSpec(algo, secret, sessionKey, nil)
   192  
   193  		req := TestReq(t, "get", "/test/get", nil)
   194  		// invalid date
   195  		req.Header.Add("date", "Mon, 02 Jan 2006 15:04:05 GMT")
   196  
   197  		recorder := httptest.NewRecorder()
   198  		chain := getMiddlewareChain(specs[0])
   199  		chain.ServeHTTP(recorder, req)
   200  
   201  		if recorder.Code != 400 {
   202  			t.Error("Expected status code 400 got ", recorder.Code)
   203  		}
   204  	})
   205  
   206  	t.Run("Custom Signature header", func(t *testing.T) {
   207  		algo := "hmac-sha256"
   208  
   209  		sessionKey := generateSession(algo, secret)
   210  		specs := generateSpec(algo, secret, sessionKey, nil)
   211  		api := specs[0]
   212  
   213  		api.AuthConfigs = make(map[string]apidef.AuthConfig)
   214  		api.AuthConfigs["hmac"] = apidef.AuthConfig{
   215  			AuthHeaderName: "something",
   216  		}
   217  
   218  		api.RequestSigning.SignatureHeader = "something"
   219  
   220  		recorder := httptest.NewRecorder()
   221  		chain := getMiddlewareChain(api)
   222  
   223  		req := TestReq(t, "get", "/test/get", nil)
   224  		chain.ServeHTTP(recorder, req)
   225  
   226  		if recorder.Code != 200 {
   227  			t.Error("HMAC request signing failed with error:", recorder.Body.String())
   228  		}
   229  	})
   230  }
   231  
   232  func TestRSARequestSigning(t *testing.T) {
   233  	ts := StartTest()
   234  	defer ts.Close()
   235  
   236  	_, _, combinedPem, cert := genServerCertificate()
   237  	privCertId, _ := CertificateManager.Add(combinedPem, "")
   238  	defer CertificateManager.Delete(privCertId, "")
   239  
   240  	x509Cert, _ := x509.ParseCertificate(cert.Certificate[0])
   241  	pubDer, _ := x509.MarshalPKIXPublicKey(x509Cert.PublicKey)
   242  	pubPem := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDer})
   243  	pubCertId, _ := CertificateManager.Add(pubPem, "")
   244  	defer CertificateManager.Delete(pubCertId, "")
   245  
   246  	name := "Test with rsa-sha256"
   247  	t.Run(name, func(t *testing.T) {
   248  		algo := "rsa-sha256"
   249  		sessionKey := generateSession(algo, pubCertId)
   250  		specs := generateSpec(algo, privCertId, sessionKey, nil)
   251  
   252  		req := TestReq(t, "get", "/test/get", nil)
   253  		recorder := httptest.NewRecorder()
   254  		chain := getMiddlewareChain(specs[0])
   255  		chain.ServeHTTP(recorder, req)
   256  
   257  		if recorder.Code != 200 {
   258  			t.Error("RSA request signing failed with error:", recorder.Body.String())
   259  		}
   260  	})
   261  
   262  	t.Run("Invalid certificate id", func(t *testing.T) {
   263  		algo := "rsa-sha256"
   264  		sessionKey := generateSession(algo, pubCertId)
   265  		specs := generateSpec(algo, "12345", sessionKey, nil)
   266  
   267  		req := TestReq(t, "get", "/test/get", nil)
   268  		recorder := httptest.NewRecorder()
   269  		chain := getMiddlewareChain(specs[0])
   270  		chain.ServeHTTP(recorder, req)
   271  
   272  		if recorder.Code != 500 {
   273  			t.Error("Expected status code 500 got ", recorder.Code)
   274  		}
   275  	})
   276  
   277  	t.Run("empty certificate id", func(t *testing.T) {
   278  		algo := "rsa-sha256"
   279  		sessionKey := generateSession(algo, pubCertId)
   280  		specs := generateSpec(algo, "", sessionKey, nil)
   281  
   282  		req := TestReq(t, "get", "/test/get", nil)
   283  		recorder := httptest.NewRecorder()
   284  		chain := getMiddlewareChain(specs[0])
   285  		chain.ServeHTTP(recorder, req)
   286  
   287  		if recorder.Code != 500 {
   288  			t.Error("Expected status code 500 got ", recorder.Code)
   289  		}
   290  	})
   291  
   292  	t.Run("Invalid algorithm", func(t *testing.T) {
   293  		algo := "rsa-123"
   294  		sessionKey := generateSession(algo, pubCertId)
   295  		specs := generateSpec(algo, privCertId, sessionKey, nil)
   296  
   297  		req := TestReq(t, "get", "/test/get", nil)
   298  		recorder := httptest.NewRecorder()
   299  		chain := getMiddlewareChain(specs[0])
   300  		chain.ServeHTTP(recorder, req)
   301  
   302  		if recorder.Code != 500 {
   303  			t.Error("Expected status code 500 got ", recorder.Code)
   304  		}
   305  	})
   306  
   307  	t.Run("Invalid Date field", func(t *testing.T) {
   308  		algo := "rsa-sha256"
   309  		sessionKey := generateSession(algo, pubCertId)
   310  		specs := generateSpec(algo, privCertId, sessionKey, nil)
   311  
   312  		req := TestReq(t, "get", "/test/get", nil)
   313  		req.Header.Add("date", "Mon, 02 Jan 2006 15:04:05 GMT")
   314  		recorder := httptest.NewRecorder()
   315  		chain := getMiddlewareChain(specs[0])
   316  		chain.ServeHTTP(recorder, req)
   317  
   318  		if recorder.Code != 400 {
   319  			t.Error("Expected status code 400 got ", recorder.Code)
   320  		}
   321  	})
   322  
   323  	t.Run("Custom headerList", func(t *testing.T) {
   324  		algo := "rsa-sha256"
   325  		headerList := []string{"foo", "date"}
   326  
   327  		sessionKey := generateSession(algo, pubCertId)
   328  		specs := generateSpec(algo, privCertId, sessionKey, headerList)
   329  
   330  		req := TestReq(t, "get", "/test/get", nil)
   331  
   332  		refDate := "Mon, 02 Jan 2006 15:04:05 MST"
   333  		tim := time.Now().Format(refDate)
   334  		req.Header.Add("foo", "bar")
   335  		req.Header.Add("date", tim)
   336  
   337  		recorder := httptest.NewRecorder()
   338  		chain := getMiddlewareChain(specs[0])
   339  		chain.ServeHTTP(recorder, req)
   340  
   341  		if recorder.Code != 200 {
   342  			t.Error("RSA request signing failed with error ", recorder.Body.String())
   343  		}
   344  	})
   345  
   346  	t.Run("Non-existing Custom headers", func(t *testing.T) {
   347  		algo := "rsa-sha256"
   348  		headerList := []string{"foo"}
   349  
   350  		sessionKey := generateSession(algo, pubCertId)
   351  		specs := generateSpec(algo, privCertId, sessionKey, headerList)
   352  
   353  		req := TestReq(t, "get", "/test/get", nil)
   354  
   355  		recorder := httptest.NewRecorder()
   356  		chain := getMiddlewareChain(specs[0])
   357  		chain.ServeHTTP(recorder, req)
   358  
   359  		if recorder.Code != 200 {
   360  			t.Error("RSA request signing failed with error ", recorder.Body.String())
   361  		}
   362  	})
   363  
   364  	t.Run("Custom Signature header", func(t *testing.T) {
   365  		algo := "rsa-sha256"
   366  
   367  		sessionKey := generateSession(algo, pubCertId)
   368  		specs := generateSpec(algo, privCertId, sessionKey, nil)
   369  
   370  		api := specs[0]
   371  		api.AuthConfigs = make(map[string]apidef.AuthConfig)
   372  		api.AuthConfigs["hmac"] = apidef.AuthConfig{
   373  			AuthHeaderName: "something",
   374  		}
   375  
   376  		api.RequestSigning.SignatureHeader = "something"
   377  
   378  		req := TestReq(t, "get", "/test/get", nil)
   379  
   380  		recorder := httptest.NewRecorder()
   381  		chain := getMiddlewareChain(specs[0])
   382  		chain.ServeHTTP(recorder, req)
   383  
   384  		if recorder.Code != 200 {
   385  			t.Error("RSA request signing failed with error ", recorder.Body.String())
   386  		}
   387  	})
   388  }
   389  
   390  func TestStripListenPath(t *testing.T) {
   391  	ts := StartTest()
   392  	defer ts.Close()
   393  
   394  	algo := "hmac-sha256"
   395  	secret := "12345"
   396  	sessionKey := generateSession(algo, secret)
   397  
   398  	t.Run("Off", func(t *testing.T) {
   399  		specs := generateSpec(algo, secret, sessionKey, nil)
   400  		req := TestReq(t, "get", "/test/get", nil)
   401  
   402  		recorder := httptest.NewRecorder()
   403  		chain := getMiddlewareChain(specs[0])
   404  		chain.ServeHTTP(recorder, req)
   405  
   406  		if recorder.Code != 200 {
   407  			t.Error("Expected status code 200 got ", recorder.Code)
   408  		}
   409  	})
   410  
   411  	t.Run("On", func(t *testing.T) {
   412  		BuildAndLoadAPI(func(spec *APISpec) {
   413  			spec.APIID = "protected"
   414  			spec.Proxy.ListenPath = "/protected"
   415  			spec.EnableSignatureChecking = true
   416  			spec.UseKeylessAccess = false
   417  			spec.Proxy.StripListenPath = true
   418  		}, func(spec *APISpec) {
   419  			spec.APIID = "trailingSlash"
   420  			spec.Proxy.ListenPath = "/trailingSlash/"
   421  			spec.Proxy.StripListenPath = true
   422  			spec.RequestSigning.IsEnabled = true
   423  			spec.RequestSigning.Secret = secret
   424  			spec.RequestSigning.KeyId = sessionKey
   425  			spec.RequestSigning.Algorithm = algo
   426  			spec.Proxy.TargetURL = ts.URL
   427  		}, func(spec *APISpec) {
   428  			spec.APIID = "withoutTrailingSlash"
   429  			spec.Proxy.ListenPath = "/withoutTrailingSlash"
   430  			spec.Proxy.StripListenPath = true
   431  			spec.RequestSigning.IsEnabled = true
   432  			spec.RequestSigning.Secret = secret
   433  			spec.RequestSigning.KeyId = sessionKey
   434  			spec.RequestSigning.Algorithm = algo
   435  			spec.Proxy.TargetURL = ts.URL
   436  		})
   437  
   438  		ts.Run(t, []test.TestCase{
   439  			{Path: "/trailingSlash/protected/get", Method: http.MethodGet, Code: http.StatusOK},
   440  			{Path: "/withoutTrailingSlash/protected/get", Method: http.MethodGet, Code: http.StatusOK},
   441  		}...)
   442  	})
   443  }
   444  
   445  func TestWithURLRewrite(t *testing.T) {
   446  	ts := StartTest()
   447  	defer ts.Close()
   448  
   449  	algo := "hmac-sha256"
   450  	secret := "12345"
   451  
   452  	sessionKey := CreateSession(func(session *user.SessionState) {
   453  		session.EnableHTTPSignatureValidation = true
   454  		session.HmacSecret = secret
   455  		session.Mutex = &sync.RWMutex{}
   456  	})
   457  
   458  	t.Run("looping", func(t *testing.T) {
   459  		BuildAndLoadAPI(func(spec *APISpec) {
   460  			spec.APIID = "protected"
   461  			spec.Proxy.ListenPath = "/protected"
   462  			spec.EnableSignatureChecking = true
   463  			spec.UseKeylessAccess = false
   464  		}, func(spec *APISpec) {
   465  			spec.APIID = "test"
   466  			spec.Proxy.ListenPath = "/test"
   467  			spec.Proxy.StripListenPath = true
   468  			spec.RequestSigning.IsEnabled = true
   469  			spec.RequestSigning.Secret = secret
   470  			spec.RequestSigning.KeyId = sessionKey
   471  			spec.RequestSigning.Algorithm = algo
   472  
   473  			version := spec.VersionData.Versions["v1"]
   474  			version.UseExtendedPaths = true
   475  			version.ExtendedPaths.URLRewrite = []apidef.URLRewriteMeta{
   476  				{
   477  					Path:         "/get",
   478  					Method:       "GET",
   479  					MatchPattern: "/get",
   480  					RewriteTo:    "tyk://protected/get",
   481  				},
   482  				{
   483  					Path:         "/self",
   484  					Method:       "GET",
   485  					MatchPattern: "/self",
   486  					RewriteTo:    "tyk://protected/test/get",
   487  				},
   488  			}
   489  
   490  			spec.VersionData.Versions["v1"] = version
   491  		})
   492  
   493  		ts.Run(t, []test.TestCase{
   494  			{Path: "/test/get", Method: http.MethodGet, Code: http.StatusOK},
   495  			// ensure listen path is not stripped in case url rewrite
   496  			{Path: "/test/self", Method: http.MethodGet, Code: http.StatusOK},
   497  		}...)
   498  	})
   499  
   500  	t.Run("external", func(t *testing.T) {
   501  		BuildAndLoadAPI(func(spec *APISpec) {
   502  			spec.APIID = "protected"
   503  			spec.Proxy.ListenPath = "/protected"
   504  			spec.EnableSignatureChecking = true
   505  		}, func(spec *APISpec) {
   506  			spec.APIID = "test"
   507  			spec.Proxy.ListenPath = "/test/"
   508  			spec.Proxy.StripListenPath = true
   509  			spec.RequestSigning.IsEnabled = true
   510  			spec.RequestSigning.Secret = secret
   511  			spec.RequestSigning.KeyId = sessionKey
   512  			spec.RequestSigning.Algorithm = algo
   513  
   514  			version := spec.VersionData.Versions["v1"]
   515  			version.UseExtendedPaths = true
   516  			version.ExtendedPaths.URLRewrite = []apidef.URLRewriteMeta{
   517  				{
   518  					Path:         "get",
   519  					Method:       "GET",
   520  					MatchPattern: "get",
   521  					RewriteTo:    ts.URL + "/protected/get",
   522  				},
   523  				{
   524  					Path:         "self",
   525  					Method:       "GET",
   526  					MatchPattern: "self",
   527  					RewriteTo:    ts.URL + "/protected/test/get",
   528  				},
   529  			}
   530  
   531  			spec.VersionData.Versions["v1"] = version
   532  		})
   533  
   534  		ts.Run(t, []test.TestCase{
   535  			{Path: "/test/get", Method: http.MethodGet, Code: http.StatusOK},
   536  			// ensure listen path is not stripped in case url rewrite
   537  			{Path: "/test/self", Method: http.MethodGet, Code: http.StatusOK},
   538  		}...)
   539  	})
   540  
   541  }
   542  
   543  func TestRequestSigning_getRequestPath(t *testing.T) {
   544  	api := BuildAPI(func(spec *APISpec) {
   545  		spec.Proxy.ListenPath = "/test/"
   546  		spec.Proxy.StripListenPath = false
   547  	})[0]
   548  
   549  	rs := RequestSigning{BaseMiddleware{Spec: api}}
   550  
   551  	req, _ := http.NewRequest(http.MethodGet, "http://example.com/test/get?param1=value1", nil)
   552  
   553  	t.Run("StripListenPath=true", func(t *testing.T) {
   554  		api.Proxy.StripListenPath = true
   555  		assert.Equal(t, "/get?param1=value1", rs.getRequestPath(req))
   556  
   557  		t.Run("path is empty", func(t *testing.T) {
   558  			reqWithEmptyPath, _ := http.NewRequest(http.MethodGet, "http://example.com/test/", nil)
   559  			assert.Equal(t, "/", rs.getRequestPath(reqWithEmptyPath))
   560  		})
   561  
   562  		api.Proxy.StripListenPath = false
   563  	})
   564  
   565  	t.Run("URL rewrite", func(t *testing.T) {
   566  		rewrittenURL := &url.URL{Path: "/test/rewritten", RawQuery: "param1=value1"}
   567  		ctxSetURLRewriteTarget(req, rewrittenURL)
   568  		assert.Equal(t, "/test/rewritten?param1=value1", rs.getRequestPath(req))
   569  		ctxSetURLRewriteTarget(req, nil)
   570  	})
   571  }