github.com/lingyao2333/mo-zero@v1.4.1/rest/engine_test.go (about)

     1  package rest
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"sync/atomic"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/lingyao2333/mo-zero/core/conf"
    13  	"github.com/lingyao2333/mo-zero/core/logx"
    14  	"github.com/stretchr/testify/assert"
    15  )
    16  
    17  func TestNewEngine(t *testing.T) {
    18  	yamls := []string{
    19  		`Name: foo
    20  Port: 54321
    21  `,
    22  		`Name: foo
    23  Port: 54321
    24  CpuThreshold: 500
    25  `,
    26  		`Name: foo
    27  Port: 54321
    28  CpuThreshold: 500
    29  Verbose: true
    30  `,
    31  	}
    32  
    33  	routes := []featuredRoutes{
    34  		{
    35  			jwt:       jwtSetting{},
    36  			signature: signatureSetting{},
    37  			routes: []Route{{
    38  				Method:  http.MethodGet,
    39  				Path:    "/",
    40  				Handler: func(w http.ResponseWriter, r *http.Request) {},
    41  			}},
    42  		},
    43  		{
    44  			priority:  true,
    45  			jwt:       jwtSetting{},
    46  			signature: signatureSetting{},
    47  			routes: []Route{{
    48  				Method:  http.MethodGet,
    49  				Path:    "/",
    50  				Handler: func(w http.ResponseWriter, r *http.Request) {},
    51  			}},
    52  		},
    53  		{
    54  			priority: true,
    55  			jwt: jwtSetting{
    56  				enabled: true,
    57  			},
    58  			signature: signatureSetting{},
    59  			routes: []Route{{
    60  				Method:  http.MethodGet,
    61  				Path:    "/",
    62  				Handler: func(w http.ResponseWriter, r *http.Request) {},
    63  			}},
    64  		},
    65  		{
    66  			priority: true,
    67  			jwt: jwtSetting{
    68  				enabled:    true,
    69  				prevSecret: "thesecret",
    70  			},
    71  			signature: signatureSetting{},
    72  			routes: []Route{{
    73  				Method:  http.MethodGet,
    74  				Path:    "/",
    75  				Handler: func(w http.ResponseWriter, r *http.Request) {},
    76  			}},
    77  		},
    78  		{
    79  			priority: true,
    80  			jwt: jwtSetting{
    81  				enabled: true,
    82  			},
    83  			signature: signatureSetting{},
    84  			routes: []Route{{
    85  				Method:  http.MethodGet,
    86  				Path:    "/",
    87  				Handler: func(w http.ResponseWriter, r *http.Request) {},
    88  			}},
    89  		},
    90  		{
    91  			priority: true,
    92  			jwt: jwtSetting{
    93  				enabled: true,
    94  			},
    95  			signature: signatureSetting{
    96  				enabled: true,
    97  			},
    98  			routes: []Route{{
    99  				Method:  http.MethodGet,
   100  				Path:    "/",
   101  				Handler: func(w http.ResponseWriter, r *http.Request) {},
   102  			}},
   103  		},
   104  		{
   105  			priority: true,
   106  			jwt: jwtSetting{
   107  				enabled: true,
   108  			},
   109  			signature: signatureSetting{
   110  				enabled: true,
   111  				SignatureConf: SignatureConf{
   112  					Strict: true,
   113  				},
   114  			},
   115  			routes: []Route{{
   116  				Method:  http.MethodGet,
   117  				Path:    "/",
   118  				Handler: func(w http.ResponseWriter, r *http.Request) {},
   119  			}},
   120  		},
   121  		{
   122  			priority: true,
   123  			jwt: jwtSetting{
   124  				enabled: true,
   125  			},
   126  			signature: signatureSetting{
   127  				enabled: true,
   128  				SignatureConf: SignatureConf{
   129  					Strict: true,
   130  					PrivateKeys: []PrivateKeyConf{
   131  						{
   132  							Fingerprint: "a",
   133  							KeyFile:     "b",
   134  						},
   135  					},
   136  				},
   137  			},
   138  			routes: []Route{{
   139  				Method:  http.MethodGet,
   140  				Path:    "/",
   141  				Handler: func(w http.ResponseWriter, r *http.Request) {},
   142  			}},
   143  		},
   144  	}
   145  
   146  	for _, yaml := range yamls {
   147  		for _, route := range routes {
   148  			var cnf RestConf
   149  			assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
   150  			ng := newEngine(cnf)
   151  			ng.addRoutes(route)
   152  			ng.use(func(next http.HandlerFunc) http.HandlerFunc {
   153  				return func(w http.ResponseWriter, r *http.Request) {
   154  					next.ServeHTTP(w, r)
   155  				}
   156  			})
   157  			assert.NotNil(t, ng.start(mockedRouter{}))
   158  		}
   159  	}
   160  }
   161  
   162  func TestEngine_checkedTimeout(t *testing.T) {
   163  	tests := []struct {
   164  		name    string
   165  		timeout time.Duration
   166  		expect  time.Duration
   167  	}{
   168  		{
   169  			name:   "not set",
   170  			expect: time.Second,
   171  		},
   172  		{
   173  			name:    "less",
   174  			timeout: time.Millisecond * 500,
   175  			expect:  time.Millisecond * 500,
   176  		},
   177  		{
   178  			name:    "equal",
   179  			timeout: time.Second,
   180  			expect:  time.Second,
   181  		},
   182  		{
   183  			name:    "more",
   184  			timeout: time.Millisecond * 1500,
   185  			expect:  time.Millisecond * 1500,
   186  		},
   187  	}
   188  
   189  	ng := newEngine(RestConf{
   190  		Timeout: 1000,
   191  	})
   192  	for _, test := range tests {
   193  		assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout))
   194  	}
   195  }
   196  
   197  func TestEngine_checkedMaxBytes(t *testing.T) {
   198  	tests := []struct {
   199  		name     string
   200  		maxBytes int64
   201  		expect   int64
   202  	}{
   203  		{
   204  			name:   "not set",
   205  			expect: 1000,
   206  		},
   207  		{
   208  			name:     "less",
   209  			maxBytes: 500,
   210  			expect:   500,
   211  		},
   212  		{
   213  			name:     "equal",
   214  			maxBytes: 1000,
   215  			expect:   1000,
   216  		},
   217  		{
   218  			name:     "more",
   219  			maxBytes: 1500,
   220  			expect:   1500,
   221  		},
   222  	}
   223  
   224  	ng := newEngine(RestConf{
   225  		MaxBytes: 1000,
   226  	})
   227  	for _, test := range tests {
   228  		assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
   229  	}
   230  }
   231  
   232  func TestEngine_notFoundHandler(t *testing.T) {
   233  	logx.Disable()
   234  
   235  	ng := newEngine(RestConf{})
   236  	ts := httptest.NewServer(ng.notFoundHandler(nil))
   237  	defer ts.Close()
   238  
   239  	client := ts.Client()
   240  	err := func(_ context.Context) error {
   241  		req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
   242  		assert.Nil(t, err)
   243  		res, err := client.Do(req)
   244  		assert.Nil(t, err)
   245  		assert.Equal(t, http.StatusNotFound, res.StatusCode)
   246  		return res.Body.Close()
   247  	}(context.Background())
   248  
   249  	assert.Nil(t, err)
   250  }
   251  
   252  func TestEngine_notFoundHandlerNotNil(t *testing.T) {
   253  	logx.Disable()
   254  
   255  	ng := newEngine(RestConf{})
   256  	var called int32
   257  	ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   258  		atomic.AddInt32(&called, 1)
   259  	})))
   260  	defer ts.Close()
   261  
   262  	client := ts.Client()
   263  	err := func(_ context.Context) error {
   264  		req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
   265  		assert.Nil(t, err)
   266  		res, err := client.Do(req)
   267  		assert.Nil(t, err)
   268  		assert.Equal(t, http.StatusNotFound, res.StatusCode)
   269  		return res.Body.Close()
   270  	}(context.Background())
   271  
   272  	assert.Nil(t, err)
   273  	assert.Equal(t, int32(1), atomic.LoadInt32(&called))
   274  }
   275  
   276  func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
   277  	logx.Disable()
   278  
   279  	ng := newEngine(RestConf{})
   280  	var called int32
   281  	ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   282  		atomic.AddInt32(&called, 1)
   283  		w.WriteHeader(http.StatusExpectationFailed)
   284  	})))
   285  	defer ts.Close()
   286  
   287  	client := ts.Client()
   288  	err := func(_ context.Context) error {
   289  		req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
   290  		assert.Nil(t, err)
   291  		res, err := client.Do(req)
   292  		assert.Nil(t, err)
   293  		assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
   294  		return res.Body.Close()
   295  	}(context.Background())
   296  
   297  	assert.Nil(t, err)
   298  	assert.Equal(t, int32(1), atomic.LoadInt32(&called))
   299  }
   300  
   301  func TestEngine_withTimeout(t *testing.T) {
   302  	logx.Disable()
   303  
   304  	tests := []struct {
   305  		name    string
   306  		timeout int64
   307  	}{
   308  		{
   309  			name: "not set",
   310  		},
   311  		{
   312  			name:    "set",
   313  			timeout: 1000,
   314  		},
   315  	}
   316  
   317  	for _, test := range tests {
   318  		test := test
   319  		t.Run(test.name, func(t *testing.T) {
   320  			ng := newEngine(RestConf{Timeout: test.timeout})
   321  			svr := &http.Server{}
   322  			ng.withTimeout()(svr)
   323  
   324  			assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
   325  			assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
   326  			assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*9/10, svr.WriteTimeout)
   327  			assert.Equal(t, time.Duration(0), svr.IdleTimeout)
   328  		})
   329  	}
   330  }
   331  
   332  type mockedRouter struct{}
   333  
   334  func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
   335  }
   336  
   337  func (m mockedRouter) Handle(_, _ string, handler http.Handler) error {
   338  	return errors.New("foo")
   339  }
   340  
   341  func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
   342  }
   343  
   344  func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
   345  }