github.com/lingyao2333/mo-zero@v1.4.1/rest/server_test.go (about)

     1  package rest
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"os"
    10  	"strings"
    11  	"sync/atomic"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/lingyao2333/mo-zero/core/conf"
    16  	"github.com/lingyao2333/mo-zero/core/logx"
    17  	"github.com/lingyao2333/mo-zero/rest/chain"
    18  	"github.com/lingyao2333/mo-zero/rest/httpx"
    19  	"github.com/lingyao2333/mo-zero/rest/internal/cors"
    20  	"github.com/lingyao2333/mo-zero/rest/router"
    21  	"github.com/stretchr/testify/assert"
    22  )
    23  
    24  func TestNewServer(t *testing.T) {
    25  	writer := logx.Reset()
    26  	defer logx.SetWriter(writer)
    27  	logx.SetWriter(logx.NewWriter(io.Discard))
    28  
    29  	const configYaml = `
    30  Name: foo
    31  Port: 54321
    32  `
    33  	var cnf RestConf
    34  	assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
    35  
    36  	tests := []struct {
    37  		c    RestConf
    38  		opts []RunOption
    39  		fail bool
    40  	}{
    41  		{
    42  			c:    RestConf{},
    43  			opts: []RunOption{WithRouter(mockedRouter{}), WithCors()},
    44  		},
    45  		{
    46  			c:    cnf,
    47  			opts: []RunOption{WithRouter(mockedRouter{})},
    48  		},
    49  		{
    50  			c:    cnf,
    51  			opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)},
    52  		},
    53  		{
    54  			c:    cnf,
    55  			opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})},
    56  		},
    57  		{
    58  			c:    cnf,
    59  			opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})},
    60  		},
    61  		{
    62  			c:    cnf,
    63  			opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})},
    64  		},
    65  	}
    66  
    67  	for _, test := range tests {
    68  		var svr *Server
    69  		var err error
    70  		if test.fail {
    71  			_, err = NewServer(test.c, test.opts...)
    72  			assert.NotNil(t, err)
    73  			continue
    74  		} else {
    75  			svr = MustNewServer(test.c, test.opts...)
    76  		}
    77  
    78  		svr.Use(ToMiddleware(func(next http.Handler) http.Handler {
    79  			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    80  				next.ServeHTTP(w, r)
    81  			})
    82  		}))
    83  		svr.AddRoute(Route{
    84  			Method:  http.MethodGet,
    85  			Path:    "/",
    86  			Handler: nil,
    87  		}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
    88  			WithJwtTransition("preivous", "thenewone"))
    89  
    90  		func() {
    91  			defer func() {
    92  				p := recover()
    93  				switch v := p.(type) {
    94  				case error:
    95  					assert.Equal(t, "foo", v.Error())
    96  				default:
    97  					t.Fail()
    98  				}
    99  			}()
   100  
   101  			svr.Start()
   102  			svr.Stop()
   103  		}()
   104  	}
   105  }
   106  
   107  func TestWithMaxBytes(t *testing.T) {
   108  	const maxBytes = 1000
   109  	var fr featuredRoutes
   110  	WithMaxBytes(maxBytes)(&fr)
   111  	assert.Equal(t, int64(maxBytes), fr.maxBytes)
   112  }
   113  
   114  func TestWithMiddleware(t *testing.T) {
   115  	m := make(map[string]string)
   116  	rt := router.NewRouter()
   117  	handler := func(w http.ResponseWriter, r *http.Request) {
   118  		var v struct {
   119  			Nickname string `form:"nickname"`
   120  			Zipcode  int64  `form:"zipcode"`
   121  		}
   122  
   123  		err := httpx.Parse(r, &v)
   124  		assert.Nil(t, err)
   125  		_, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode))
   126  		assert.Nil(t, err)
   127  	}
   128  	rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc {
   129  		return func(w http.ResponseWriter, r *http.Request) {
   130  			var v struct {
   131  				Name string `path:"name"`
   132  				Year string `path:"year"`
   133  			}
   134  			assert.Nil(t, httpx.ParsePath(r, &v))
   135  			m[v.Name] = v.Year
   136  			next.ServeHTTP(w, r)
   137  		}
   138  	}, Route{
   139  		Method:  http.MethodGet,
   140  		Path:    "/first/:name/:year",
   141  		Handler: handler,
   142  	}, Route{
   143  		Method:  http.MethodGet,
   144  		Path:    "/second/:name/:year",
   145  		Handler: handler,
   146  	})
   147  
   148  	urls := []string{
   149  		"http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
   150  		"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
   151  	}
   152  	for _, route := range rs {
   153  		assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
   154  	}
   155  	for _, url := range urls {
   156  		r, err := http.NewRequest(http.MethodGet, url, nil)
   157  		assert.Nil(t, err)
   158  
   159  		rr := httptest.NewRecorder()
   160  		rt.ServeHTTP(rr, r)
   161  
   162  		assert.Equal(t, "whatever:200000", rr.Body.String())
   163  	}
   164  
   165  	assert.EqualValues(t, map[string]string{
   166  		"kevin": "2017",
   167  		"wan":   "2020",
   168  	}, m)
   169  }
   170  
   171  func TestMultiMiddlewares(t *testing.T) {
   172  	m := make(map[string]string)
   173  	rt := router.NewRouter()
   174  	handler := func(w http.ResponseWriter, r *http.Request) {
   175  		var v struct {
   176  			Nickname string `form:"nickname"`
   177  			Zipcode  int64  `form:"zipcode"`
   178  		}
   179  
   180  		err := httpx.Parse(r, &v)
   181  		assert.Nil(t, err)
   182  		_, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
   183  		assert.Nil(t, err)
   184  	}
   185  	rs := WithMiddlewares([]Middleware{
   186  		func(next http.HandlerFunc) http.HandlerFunc {
   187  			return func(w http.ResponseWriter, r *http.Request) {
   188  				var v struct {
   189  					Name string `path:"name"`
   190  					Year string `path:"year"`
   191  				}
   192  				assert.Nil(t, httpx.ParsePath(r, &v))
   193  				m[v.Name] = v.Year
   194  				next.ServeHTTP(w, r)
   195  			}
   196  		},
   197  		func(next http.HandlerFunc) http.HandlerFunc {
   198  			return func(w http.ResponseWriter, r *http.Request) {
   199  				var v struct {
   200  					Name    string `form:"nickname"`
   201  					Zipcode string `form:"zipcode"`
   202  				}
   203  				assert.Nil(t, httpx.ParseForm(r, &v))
   204  				assert.NotEmpty(t, m)
   205  				m[v.Name] = v.Zipcode + v.Zipcode
   206  				next.ServeHTTP(w, r)
   207  			}
   208  		},
   209  		ToMiddleware(func(next http.Handler) http.Handler {
   210  			return next
   211  		}),
   212  	}, Route{
   213  		Method:  http.MethodGet,
   214  		Path:    "/first/:name/:year",
   215  		Handler: handler,
   216  	}, Route{
   217  		Method:  http.MethodGet,
   218  		Path:    "/second/:name/:year",
   219  		Handler: handler,
   220  	})
   221  
   222  	urls := []string{
   223  		"http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
   224  		"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
   225  	}
   226  	for _, route := range rs {
   227  		assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
   228  	}
   229  	for _, url := range urls {
   230  		r, err := http.NewRequest(http.MethodGet, url, nil)
   231  		assert.Nil(t, err)
   232  
   233  		rr := httptest.NewRecorder()
   234  		rt.ServeHTTP(rr, r)
   235  
   236  		assert.Equal(t, "whatever:200000200000", rr.Body.String())
   237  	}
   238  
   239  	assert.EqualValues(t, map[string]string{
   240  		"kevin":    "2017",
   241  		"wan":      "2020",
   242  		"whatever": "200000200000",
   243  	}, m)
   244  }
   245  
   246  func TestWithPrefix(t *testing.T) {
   247  	fr := featuredRoutes{
   248  		routes: []Route{
   249  			{
   250  				Path: "/hello",
   251  			},
   252  			{
   253  				Path: "/world",
   254  			},
   255  		},
   256  	}
   257  	WithPrefix("/api")(&fr)
   258  	vals := make([]string, 0, len(fr.routes))
   259  	for _, r := range fr.routes {
   260  		vals = append(vals, r.Path)
   261  	}
   262  	assert.EqualValues(t, []string{"/api/hello", "/api/world"}, vals)
   263  }
   264  
   265  func TestWithPriority(t *testing.T) {
   266  	var fr featuredRoutes
   267  	WithPriority()(&fr)
   268  	assert.True(t, fr.priority)
   269  }
   270  
   271  func TestWithTimeout(t *testing.T) {
   272  	var fr featuredRoutes
   273  	WithTimeout(time.Hour)(&fr)
   274  	assert.Equal(t, time.Hour, fr.timeout)
   275  }
   276  
   277  func TestWithTLSConfig(t *testing.T) {
   278  	const configYaml = `
   279  Name: foo
   280  Port: 54321
   281  `
   282  	var cnf RestConf
   283  	assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
   284  
   285  	testConfig := &tls.Config{
   286  		CipherSuites: []uint16{
   287  			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
   288  		},
   289  	}
   290  
   291  	testCases := []struct {
   292  		c    RestConf
   293  		opts []RunOption
   294  		res  *tls.Config
   295  	}{
   296  		{
   297  			c:    cnf,
   298  			opts: []RunOption{WithTLSConfig(testConfig)},
   299  			res:  testConfig,
   300  		},
   301  		{
   302  			c:    cnf,
   303  			opts: []RunOption{WithUnsignedCallback(nil)},
   304  			res:  nil,
   305  		},
   306  	}
   307  
   308  	for _, testCase := range testCases {
   309  		svr, err := NewServer(testCase.c, testCase.opts...)
   310  		assert.Nil(t, err)
   311  		assert.Equal(t, svr.ngin.tlsConfig, testCase.res)
   312  	}
   313  }
   314  
   315  func TestWithCors(t *testing.T) {
   316  	const configYaml = `
   317  Name: foo
   318  Port: 54321
   319  `
   320  	var cnf RestConf
   321  	assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
   322  	rt := router.NewRouter()
   323  	svr, err := NewServer(cnf, WithRouter(rt))
   324  	assert.Nil(t, err)
   325  	defer svr.Stop()
   326  
   327  	opt := WithCors("local")
   328  	opt(svr)
   329  }
   330  
   331  func TestWithCustomCors(t *testing.T) {
   332  	const configYaml = `
   333  Name: foo
   334  Port: 54321
   335  `
   336  	var cnf RestConf
   337  	assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
   338  	rt := router.NewRouter()
   339  	svr, err := NewServer(cnf, WithRouter(rt))
   340  	assert.Nil(t, err)
   341  
   342  	opt := WithCustomCors(func(header http.Header) {
   343  		header.Set("foo", "bar")
   344  	}, func(w http.ResponseWriter) {
   345  		w.WriteHeader(http.StatusOK)
   346  	}, "local")
   347  	opt(svr)
   348  }
   349  
   350  func TestServer_PrintRoutes(t *testing.T) {
   351  	const (
   352  		configYaml = `
   353  Name: foo
   354  Port: 54321
   355  `
   356  		expect = `Routes:
   357    GET /bar
   358    GET /foo
   359    GET /foo/:bar
   360    GET /foo/:bar/baz
   361  `
   362  	)
   363  
   364  	var cnf RestConf
   365  	assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
   366  
   367  	svr, err := NewServer(cnf)
   368  	assert.Nil(t, err)
   369  
   370  	svr.AddRoutes([]Route{
   371  		{
   372  			Method:  http.MethodGet,
   373  			Path:    "/foo",
   374  			Handler: http.NotFound,
   375  		},
   376  		{
   377  			Method:  http.MethodGet,
   378  			Path:    "/bar",
   379  			Handler: http.NotFound,
   380  		},
   381  		{
   382  			Method:  http.MethodGet,
   383  			Path:    "/foo/:bar",
   384  			Handler: http.NotFound,
   385  		},
   386  		{
   387  			Method:  http.MethodGet,
   388  			Path:    "/foo/:bar/baz",
   389  			Handler: http.NotFound,
   390  		},
   391  	})
   392  
   393  	old := os.Stdout
   394  	r, w, err := os.Pipe()
   395  	assert.Nil(t, err)
   396  	os.Stdout = w
   397  	defer func() {
   398  		os.Stdout = old
   399  	}()
   400  
   401  	svr.PrintRoutes()
   402  	ch := make(chan string)
   403  
   404  	go func() {
   405  		var buf strings.Builder
   406  		io.Copy(&buf, r)
   407  		ch <- buf.String()
   408  	}()
   409  
   410  	w.Close()
   411  	out := <-ch
   412  	assert.Equal(t, expect, out)
   413  }
   414  
   415  func TestServer_Routes(t *testing.T) {
   416  	const (
   417  		configYaml = `
   418  Name: foo
   419  Port: 54321
   420  `
   421  		expect = `GET /foo GET /bar GET /foo/:bar GET /foo/:bar/baz`
   422  	)
   423  
   424  	var cnf RestConf
   425  	assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
   426  
   427  	svr, err := NewServer(cnf)
   428  	assert.Nil(t, err)
   429  
   430  	svr.AddRoutes([]Route{
   431  		{
   432  			Method:  http.MethodGet,
   433  			Path:    "/foo",
   434  			Handler: http.NotFound,
   435  		},
   436  		{
   437  			Method:  http.MethodGet,
   438  			Path:    "/bar",
   439  			Handler: http.NotFound,
   440  		},
   441  		{
   442  			Method:  http.MethodGet,
   443  			Path:    "/foo/:bar",
   444  			Handler: http.NotFound,
   445  		},
   446  		{
   447  			Method:  http.MethodGet,
   448  			Path:    "/foo/:bar/baz",
   449  			Handler: http.NotFound,
   450  		},
   451  	})
   452  
   453  	routes := svr.Routes()
   454  	var buf strings.Builder
   455  	for i := 0; i < len(routes); i++ {
   456  		buf.WriteString(routes[i].Method)
   457  		buf.WriteString(" ")
   458  		buf.WriteString(routes[i].Path)
   459  		buf.WriteString(" ")
   460  	}
   461  
   462  	assert.Equal(t, expect, strings.Trim(buf.String(), " "))
   463  }
   464  
   465  func TestHandleError(t *testing.T) {
   466  	assert.NotPanics(t, func() {
   467  		handleError(nil)
   468  		handleError(http.ErrServerClosed)
   469  	})
   470  }
   471  
   472  func TestValidateSecret(t *testing.T) {
   473  	assert.Panics(t, func() {
   474  		validateSecret("short")
   475  	})
   476  }
   477  
   478  func TestServer_WithChain(t *testing.T) {
   479  	var called int32
   480  	middleware1 := func() func(http.Handler) http.Handler {
   481  		return func(next http.Handler) http.Handler {
   482  			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   483  				atomic.AddInt32(&called, 1)
   484  				next.ServeHTTP(w, r)
   485  				atomic.AddInt32(&called, 1)
   486  			})
   487  		}
   488  	}
   489  	middleware2 := func() func(http.Handler) http.Handler {
   490  		return func(next http.Handler) http.Handler {
   491  			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   492  				atomic.AddInt32(&called, 1)
   493  				next.ServeHTTP(w, r)
   494  				atomic.AddInt32(&called, 1)
   495  			})
   496  		}
   497  	}
   498  
   499  	server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2())))
   500  	server.AddRoutes(
   501  		[]Route{
   502  			{
   503  				Method: http.MethodGet,
   504  				Path:   "/",
   505  				Handler: func(_ http.ResponseWriter, _ *http.Request) {
   506  					atomic.AddInt32(&called, 1)
   507  				},
   508  			},
   509  		},
   510  	)
   511  	rt := router.NewRouter()
   512  	assert.Nil(t, server.ngin.bindRoutes(rt))
   513  	req, err := http.NewRequest(http.MethodGet, "/", http.NoBody)
   514  	assert.Nil(t, err)
   515  	rt.ServeHTTP(httptest.NewRecorder(), req)
   516  	assert.Equal(t, int32(5), atomic.LoadInt32(&called))
   517  }
   518  
   519  func TestServer_WithCors(t *testing.T) {
   520  	var called int32
   521  	middleware := func(next http.Handler) http.Handler {
   522  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   523  			atomic.AddInt32(&called, 1)
   524  			next.ServeHTTP(w, r)
   525  		})
   526  	}
   527  	r := router.NewRouter()
   528  	assert.Nil(t, r.Handle(http.MethodOptions, "/", middleware(http.NotFoundHandler())))
   529  
   530  	cr := &corsRouter{
   531  		Router:     r,
   532  		middleware: cors.Middleware(nil, "*"),
   533  	}
   534  	req := httptest.NewRequest(http.MethodOptions, "/", http.NoBody)
   535  	cr.ServeHTTP(httptest.NewRecorder(), req)
   536  	assert.Equal(t, int32(0), atomic.LoadInt32(&called))
   537  }