github.com/lingyao2333/mo-zero@v1.4.1/rest/engine_test.go (about) 1 package rest 2 3 import ( 4 "context" 5 "errors" 6 "net/http" 7 "net/http/httptest" 8 "sync/atomic" 9 "testing" 10 "time" 11 12 "github.com/lingyao2333/mo-zero/core/conf" 13 "github.com/lingyao2333/mo-zero/core/logx" 14 "github.com/stretchr/testify/assert" 15 ) 16 17 func TestNewEngine(t *testing.T) { 18 yamls := []string{ 19 `Name: foo 20 Port: 54321 21 `, 22 `Name: foo 23 Port: 54321 24 CpuThreshold: 500 25 `, 26 `Name: foo 27 Port: 54321 28 CpuThreshold: 500 29 Verbose: true 30 `, 31 } 32 33 routes := []featuredRoutes{ 34 { 35 jwt: jwtSetting{}, 36 signature: signatureSetting{}, 37 routes: []Route{{ 38 Method: http.MethodGet, 39 Path: "/", 40 Handler: func(w http.ResponseWriter, r *http.Request) {}, 41 }}, 42 }, 43 { 44 priority: true, 45 jwt: jwtSetting{}, 46 signature: signatureSetting{}, 47 routes: []Route{{ 48 Method: http.MethodGet, 49 Path: "/", 50 Handler: func(w http.ResponseWriter, r *http.Request) {}, 51 }}, 52 }, 53 { 54 priority: true, 55 jwt: jwtSetting{ 56 enabled: true, 57 }, 58 signature: signatureSetting{}, 59 routes: []Route{{ 60 Method: http.MethodGet, 61 Path: "/", 62 Handler: func(w http.ResponseWriter, r *http.Request) {}, 63 }}, 64 }, 65 { 66 priority: true, 67 jwt: jwtSetting{ 68 enabled: true, 69 prevSecret: "thesecret", 70 }, 71 signature: signatureSetting{}, 72 routes: []Route{{ 73 Method: http.MethodGet, 74 Path: "/", 75 Handler: func(w http.ResponseWriter, r *http.Request) {}, 76 }}, 77 }, 78 { 79 priority: true, 80 jwt: jwtSetting{ 81 enabled: true, 82 }, 83 signature: signatureSetting{}, 84 routes: []Route{{ 85 Method: http.MethodGet, 86 Path: "/", 87 Handler: func(w http.ResponseWriter, r *http.Request) {}, 88 }}, 89 }, 90 { 91 priority: true, 92 jwt: jwtSetting{ 93 enabled: true, 94 }, 95 signature: signatureSetting{ 96 enabled: true, 97 }, 98 routes: []Route{{ 99 Method: http.MethodGet, 100 Path: "/", 101 Handler: func(w http.ResponseWriter, r *http.Request) {}, 102 }}, 103 }, 104 { 105 priority: true, 106 jwt: jwtSetting{ 107 enabled: true, 108 }, 109 signature: signatureSetting{ 110 enabled: true, 111 SignatureConf: SignatureConf{ 112 Strict: true, 113 }, 114 }, 115 routes: []Route{{ 116 Method: http.MethodGet, 117 Path: "/", 118 Handler: func(w http.ResponseWriter, r *http.Request) {}, 119 }}, 120 }, 121 { 122 priority: true, 123 jwt: jwtSetting{ 124 enabled: true, 125 }, 126 signature: signatureSetting{ 127 enabled: true, 128 SignatureConf: SignatureConf{ 129 Strict: true, 130 PrivateKeys: []PrivateKeyConf{ 131 { 132 Fingerprint: "a", 133 KeyFile: "b", 134 }, 135 }, 136 }, 137 }, 138 routes: []Route{{ 139 Method: http.MethodGet, 140 Path: "/", 141 Handler: func(w http.ResponseWriter, r *http.Request) {}, 142 }}, 143 }, 144 } 145 146 for _, yaml := range yamls { 147 for _, route := range routes { 148 var cnf RestConf 149 assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf)) 150 ng := newEngine(cnf) 151 ng.addRoutes(route) 152 ng.use(func(next http.HandlerFunc) http.HandlerFunc { 153 return func(w http.ResponseWriter, r *http.Request) { 154 next.ServeHTTP(w, r) 155 } 156 }) 157 assert.NotNil(t, ng.start(mockedRouter{})) 158 } 159 } 160 } 161 162 func TestEngine_checkedTimeout(t *testing.T) { 163 tests := []struct { 164 name string 165 timeout time.Duration 166 expect time.Duration 167 }{ 168 { 169 name: "not set", 170 expect: time.Second, 171 }, 172 { 173 name: "less", 174 timeout: time.Millisecond * 500, 175 expect: time.Millisecond * 500, 176 }, 177 { 178 name: "equal", 179 timeout: time.Second, 180 expect: time.Second, 181 }, 182 { 183 name: "more", 184 timeout: time.Millisecond * 1500, 185 expect: time.Millisecond * 1500, 186 }, 187 } 188 189 ng := newEngine(RestConf{ 190 Timeout: 1000, 191 }) 192 for _, test := range tests { 193 assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout)) 194 } 195 } 196 197 func TestEngine_checkedMaxBytes(t *testing.T) { 198 tests := []struct { 199 name string 200 maxBytes int64 201 expect int64 202 }{ 203 { 204 name: "not set", 205 expect: 1000, 206 }, 207 { 208 name: "less", 209 maxBytes: 500, 210 expect: 500, 211 }, 212 { 213 name: "equal", 214 maxBytes: 1000, 215 expect: 1000, 216 }, 217 { 218 name: "more", 219 maxBytes: 1500, 220 expect: 1500, 221 }, 222 } 223 224 ng := newEngine(RestConf{ 225 MaxBytes: 1000, 226 }) 227 for _, test := range tests { 228 assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes)) 229 } 230 } 231 232 func TestEngine_notFoundHandler(t *testing.T) { 233 logx.Disable() 234 235 ng := newEngine(RestConf{}) 236 ts := httptest.NewServer(ng.notFoundHandler(nil)) 237 defer ts.Close() 238 239 client := ts.Client() 240 err := func(_ context.Context) error { 241 req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody) 242 assert.Nil(t, err) 243 res, err := client.Do(req) 244 assert.Nil(t, err) 245 assert.Equal(t, http.StatusNotFound, res.StatusCode) 246 return res.Body.Close() 247 }(context.Background()) 248 249 assert.Nil(t, err) 250 } 251 252 func TestEngine_notFoundHandlerNotNil(t *testing.T) { 253 logx.Disable() 254 255 ng := newEngine(RestConf{}) 256 var called int32 257 ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 258 atomic.AddInt32(&called, 1) 259 }))) 260 defer ts.Close() 261 262 client := ts.Client() 263 err := func(_ context.Context) error { 264 req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody) 265 assert.Nil(t, err) 266 res, err := client.Do(req) 267 assert.Nil(t, err) 268 assert.Equal(t, http.StatusNotFound, res.StatusCode) 269 return res.Body.Close() 270 }(context.Background()) 271 272 assert.Nil(t, err) 273 assert.Equal(t, int32(1), atomic.LoadInt32(&called)) 274 } 275 276 func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) { 277 logx.Disable() 278 279 ng := newEngine(RestConf{}) 280 var called int32 281 ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 282 atomic.AddInt32(&called, 1) 283 w.WriteHeader(http.StatusExpectationFailed) 284 }))) 285 defer ts.Close() 286 287 client := ts.Client() 288 err := func(_ context.Context) error { 289 req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody) 290 assert.Nil(t, err) 291 res, err := client.Do(req) 292 assert.Nil(t, err) 293 assert.Equal(t, http.StatusExpectationFailed, res.StatusCode) 294 return res.Body.Close() 295 }(context.Background()) 296 297 assert.Nil(t, err) 298 assert.Equal(t, int32(1), atomic.LoadInt32(&called)) 299 } 300 301 func TestEngine_withTimeout(t *testing.T) { 302 logx.Disable() 303 304 tests := []struct { 305 name string 306 timeout int64 307 }{ 308 { 309 name: "not set", 310 }, 311 { 312 name: "set", 313 timeout: 1000, 314 }, 315 } 316 317 for _, test := range tests { 318 test := test 319 t.Run(test.name, func(t *testing.T) { 320 ng := newEngine(RestConf{Timeout: test.timeout}) 321 svr := &http.Server{} 322 ng.withTimeout()(svr) 323 324 assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout) 325 assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout) 326 assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*9/10, svr.WriteTimeout) 327 assert.Equal(t, time.Duration(0), svr.IdleTimeout) 328 }) 329 } 330 } 331 332 type mockedRouter struct{} 333 334 func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { 335 } 336 337 func (m mockedRouter) Handle(_, _ string, handler http.Handler) error { 338 return errors.New("foo") 339 } 340 341 func (m mockedRouter) SetNotFoundHandler(_ http.Handler) { 342 } 343 344 func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) { 345 }