github.com/arthur-befumo/witchcraft-go-server@v1.12.0/wrouter/router_test.go (about)

     1  // Copyright (c) 2018 Palantir Technologies. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package wrouter_test
    16  
    17  import (
    18  	"fmt"
    19  	"net/http"
    20  	"net/http/httptest"
    21  	"sort"
    22  	"testing"
    23  
    24  	// underscore import to use zap implementation
    25  	_ "github.com/palantir/witchcraft-go-logging/wlog-zap"
    26  	"github.com/palantir/witchcraft-go-server/wrouter"
    27  	"github.com/palantir/witchcraft-go-server/wrouter/wgorillamux"
    28  	"github.com/palantir/witchcraft-go-server/wrouter/whttprouter"
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  )
    32  
    33  func TestRouterImpls(t *testing.T) {
    34  	for i, tc := range []struct {
    35  		name string
    36  		impl wrouter.RouterImpl
    37  	}{
    38  		{"wgorillamux", wgorillamux.New()},
    39  		{"whttprouter", whttprouter.New()},
    40  	} {
    41  		// create router
    42  		r := wrouter.New(tc.impl, nil)
    43  
    44  		// register routes
    45  		matched := make(map[string]bool)
    46  		err := r.Register("GET", "/foo", mustMatchHandler(t, "GET", "/foo", nil, matched))
    47  		require.NoError(t, err, "Case %d: %s", i, tc.name)
    48  		err = r.Register("GET", "/datasets/{rid}", mustMatchHandler(t, "GET", "/datasets/id-500", map[string]string{
    49  			"rid": "id-500",
    50  		}, matched))
    51  		require.NoError(t, err, "Case %d: %s", i, tc.name)
    52  
    53  		subrouter := r.Subrouter("/foo")
    54  		err = subrouter.Register("GET", "/{id}", mustMatchHandler(t, "GET", "/foo/13", map[string]string{
    55  			"id": "13",
    56  		}, matched))
    57  		require.NoError(t, err, "Case %d: %s", i, tc.name)
    58  
    59  		err = r.Register("GET", "/file/{path*}", mustMatchHandler(t, "GET", "/file/var/data/my-file.txt", map[string]string{
    60  			"path": "var/data/my-file.txt",
    61  		}, matched))
    62  		require.NoError(t, err, "Case %d: %s", i, tc.name)
    63  
    64  		wantRoutes := []wrouter.RouteSpec{
    65  			{
    66  				Method:       "GET",
    67  				PathTemplate: "/datasets/{rid}",
    68  			},
    69  			{
    70  				Method:       "GET",
    71  				PathTemplate: "/file/{path*}",
    72  			},
    73  			{
    74  				Method:       "GET",
    75  				PathTemplate: "/foo",
    76  			},
    77  			{
    78  				Method:       "GET",
    79  				PathTemplate: "/foo/{id}",
    80  			},
    81  		}
    82  		assert.Equal(t, wantRoutes, r.RegisteredRoutes(), "Case %d: %s", i, tc.name)
    83  
    84  		server := httptest.NewServer(r)
    85  		defer server.Close()
    86  
    87  		_, err = http.Get(server.URL + "/datasets/id-500")
    88  		require.NoError(t, err, "Case %d: %s", i, tc.name)
    89  		_, err = http.Get(server.URL + "/file/var/data/my-file.txt")
    90  		require.NoError(t, err, "Case %d: %s", i, tc.name)
    91  		_, err = http.Get(server.URL + "/foo")
    92  		require.NoError(t, err, "Case %d: %s", i, tc.name)
    93  		_, err = http.Get(server.URL + "/foo/13")
    94  		require.NoError(t, err, "Case %d: %s", i, tc.name)
    95  
    96  		var sortedKeys []string
    97  		for k := range matched {
    98  			sortedKeys = append(sortedKeys, k)
    99  		}
   100  		sort.Strings(sortedKeys)
   101  		for _, k := range sortedKeys {
   102  			assert.True(t, matched[k], "Case %d: %s\nMatcher not called for %s", i, tc.name, k)
   103  		}
   104  	}
   105  }
   106  
   107  // Smoke test that tests registering routes, using subrouters, listing routes, etc. on all router implementations.
   108  func TestRouterImplSmoke(t *testing.T) {
   109  	for i, tc := range []struct {
   110  		name string
   111  		impl wrouter.RouterImpl
   112  	}{
   113  		{"wgorillamux", wgorillamux.New()},
   114  		{"whttprouter", whttprouter.New()},
   115  	} {
   116  		func() {
   117  			// create router
   118  			r := wrouter.New(tc.impl, nil)
   119  
   120  			// register routes
   121  			matched := make(map[string]bool)
   122  			err := r.Register("GET", "/foo", mustMatchHandler(t, "GET", "/foo", nil, matched))
   123  			require.NoError(t, err, "Case %d: %s", i, tc.name)
   124  			err = r.Register("GET", "/datasets/{rid}", mustMatchHandler(t, "GET", "/datasets/id-500", map[string]string{
   125  				"rid": "id-500",
   126  			}, matched))
   127  			require.NoError(t, err, "Case %d: %s", i, tc.name)
   128  
   129  			subrouter := r.Subrouter("/foo")
   130  			err = subrouter.Register("GET", "/{id}", mustMatchHandler(t, "GET", "/foo/13", map[string]string{
   131  				"id": "13",
   132  			}, matched))
   133  			require.NoError(t, err, "Case %d: %s", i, tc.name)
   134  
   135  			err = r.Register("GET", "/file/{path*}", mustMatchHandler(t, "GET", "/file/var/data/my-file.txt", map[string]string{
   136  				"path": "var/data/my-file.txt",
   137  			}, matched))
   138  			require.NoError(t, err, "Case %d: %s", i, tc.name)
   139  
   140  			wantRoutes := []wrouter.RouteSpec{
   141  				{
   142  					Method:       "GET",
   143  					PathTemplate: "/datasets/{rid}",
   144  				},
   145  				{
   146  					Method:       "GET",
   147  					PathTemplate: "/file/{path*}",
   148  				},
   149  				{
   150  					Method:       "GET",
   151  					PathTemplate: "/foo",
   152  				},
   153  				{
   154  					Method:       "GET",
   155  					PathTemplate: "/foo/{id}",
   156  				},
   157  			}
   158  			assert.Equal(t, wantRoutes, r.RegisteredRoutes(), "Case %d: %s", i, tc.name)
   159  
   160  			server := httptest.NewServer(r)
   161  			defer server.Close()
   162  
   163  			_, err = http.Get(server.URL + "/datasets/id-500")
   164  			require.NoError(t, err, "Case %d: %s", i, tc.name)
   165  			_, err = http.Get(server.URL + "/file/var/data/my-file.txt")
   166  			require.NoError(t, err, "Case %d: %s", i, tc.name)
   167  			_, err = http.Get(server.URL + "/foo")
   168  			require.NoError(t, err, "Case %d: %s", i, tc.name)
   169  			_, err = http.Get(server.URL + "/foo/13")
   170  			require.NoError(t, err, "Case %d: %s", i, tc.name)
   171  
   172  			var sortedKeys []string
   173  			for k := range matched {
   174  				sortedKeys = append(sortedKeys, k)
   175  			}
   176  			sort.Strings(sortedKeys)
   177  			for _, k := range sortedKeys {
   178  				assert.True(t, matched[k], "Case %d: %s\nMatcher not called for %s", i, tc.name, k)
   179  			}
   180  		}()
   181  	}
   182  }
   183  
   184  // Tests specific cases for registering routes and calling endpoints that hit them.
   185  func TestRouterImplRouteHandling(t *testing.T) {
   186  	for i, tc := range []struct {
   187  		name   string
   188  		routes []singleRouteTest
   189  	}{
   190  		{
   191  			name: "empty path",
   192  			routes: []singleRouteTest{
   193  				{
   194  					method: "GET",
   195  					path:   "/",
   196  					reqURL: "/",
   197  				},
   198  			},
   199  		},
   200  		{
   201  			name: "paths with common starting variables",
   202  			routes: []singleRouteTest{
   203  				{
   204  					method: "GET",
   205  					path:   "/{entityId}",
   206  					reqURL: "/latest",
   207  					wantPathParams: map[string]string{
   208  						"entityId": "latest",
   209  					},
   210  				},
   211  				{
   212  					method: "GET",
   213  					path:   "/{entityId}/{date}",
   214  					reqURL: "/test-id/test-date",
   215  					wantPathParams: map[string]string{
   216  						"entityId": "test-id",
   217  						"date":     "test-date",
   218  					},
   219  				},
   220  			},
   221  		},
   222  		{
   223  			name: "paths with same path expression with only method differing",
   224  			routes: []singleRouteTest{
   225  				{
   226  					method: "GET",
   227  					path:   "/{getId*}",
   228  					reqURL: "/latest",
   229  					wantPathParams: map[string]string{
   230  						"getId": "latest",
   231  					},
   232  				},
   233  				{
   234  					method: "POST",
   235  					path:   "/{postId*}",
   236  					reqURL: "/test-id",
   237  					wantPathParams: map[string]string{
   238  						"postId": "test-id",
   239  					},
   240  				},
   241  			},
   242  		},
   243  		// Following test does not work for zhttprouter due to https://github.com/julienschmidt/httprouter/issues/183.
   244  		// The preceding test ("paths with common starting variables") demonstrates a work-around for this behavior --
   245  		// one can register a path where the shorter segment shares the path parameter variable and then only execute
   246  		// the handler if the variable value matches the desired literal path (in this case, "latest").
   247  		//{
   248  		//	name: "path with literal and variable at same level",
   249  		//	routes: []singleRouteTest{
   250  		//		{
   251  		//			method: "GET",
   252  		//			path:   "/latest",
   253  		//			reqURL: "/latest",
   254  		//		},
   255  		//		{
   256  		//			method: "GET",
   257  		//			path:   "/{entityId}/{date}",
   258  		//			reqURL: "/test-id/test-date",
   259  		//			wantPathParams: map[string]string{
   260  		//				"entityId": "test-id",
   261  		//				"date": "test-date",
   262  		//			},
   263  		//		},
   264  		//	},
   265  		//},
   266  	} {
   267  		for _, routerImpl := range []struct {
   268  			name string
   269  			impl wrouter.RouterImpl
   270  		}{
   271  			{"wgorillamux", wgorillamux.New()},
   272  			{"whttprouter", whttprouter.New()},
   273  		} {
   274  			func() {
   275  				// create router
   276  				r := wrouter.New(routerImpl.impl, nil)
   277  
   278  				// register routes
   279  				matched := make(map[string]bool)
   280  				for _, currRoute := range tc.routes {
   281  					err := r.Register(currRoute.method, currRoute.path, mustMatchHandler(t, currRoute.method, currRoute.reqURL, currRoute.wantPathParams, matched))
   282  					require.NoError(t, err, "Case %d: %s %s", i, tc.name, routerImpl.name)
   283  				}
   284  
   285  				// start server
   286  				server := httptest.NewServer(r)
   287  				defer server.Close()
   288  
   289  				// make HTTP calls
   290  				for _, currRoute := range tc.routes {
   291  					req, err := http.NewRequest(
   292  						currRoute.method,
   293  						server.URL+currRoute.reqURL,
   294  						nil,
   295  					)
   296  					require.NoError(t, err, "Case %d: %s %s", i, tc.name, routerImpl.name)
   297  					_, err = http.DefaultClient.Do(req)
   298  					require.NoError(t, err, "Case %d: %s %s", i, tc.name, routerImpl.name)
   299  				}
   300  
   301  				// verify results
   302  				var sortedKeys []string
   303  				for k := range matched {
   304  					sortedKeys = append(sortedKeys, k)
   305  				}
   306  				sort.Strings(sortedKeys)
   307  				for _, k := range sortedKeys {
   308  					assert.True(t, matched[k], "Case %d: %s %s\nMatcher not called for %s", i, tc.name, routerImpl.name, k)
   309  				}
   310  			}()
   311  		}
   312  	}
   313  }
   314  
   315  type singleRouteTest struct {
   316  	method         string
   317  	path           string
   318  	reqURL         string
   319  	wantPathParams map[string]string
   320  }
   321  
   322  func mustMatchHandler(t *testing.T, method, path string, pathVars map[string]string, matched map[string]bool) http.Handler {
   323  	matched[fmt.Sprintf("[%s] %s", method, path)] = false
   324  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   325  		assert.Equal(t, method, r.Method, "Method did not match expected for path %s", path)
   326  		assert.Equal(t, path, r.URL.Path, "Path did not match expected for path %s", path)
   327  		assert.Equal(t, pathVars, wrouter.PathParams(r), "Path params did not match expected for path %s", path)
   328  		matched[fmt.Sprintf("[%s] %s", method, path)] = true
   329  	})
   330  }