github.com/wangkui503/aero@v1.0.0/Context_test.go (about) 1 package aero_test 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "io/ioutil" 8 "net/http" 9 "net/http/httptest" 10 "strconv" 11 "strings" 12 "testing" 13 "time" 14 15 "github.com/aerogo/session" 16 jsoniter "github.com/json-iterator/go" 17 18 "github.com/aerogo/aero" 19 "github.com/stretchr/testify/assert" 20 ) 21 22 func TestContextResponseHeader(t *testing.T) { 23 app := aero.New() 24 25 // Register route 26 app.Get("/", func(ctx *aero.Context) string { 27 ctx.Response().Header().Set("X-Custom", "42") 28 return ctx.Text(helloWorld) 29 }) 30 31 // Get response 32 response := getResponse(app, "/") 33 34 // Verify response 35 assert.Equal(t, http.StatusOK, response.Code) 36 assert.Equal(t, helloWorld, response.Body.String()) 37 assert.Equal(t, "42", response.Header().Get("X-Custom")) 38 } 39 40 func TestContextError(t *testing.T) { 41 app := aero.New() 42 43 // Register route 44 app.Get("/", func(ctx *aero.Context) string { 45 return ctx.Error(http.StatusUnauthorized, "Not authorized", errors.New("Not logged in")) 46 }) 47 48 app.Get("/explanation-only", func(ctx *aero.Context) string { 49 return ctx.Error(http.StatusUnauthorized, "Not authorized", nil) 50 }) 51 52 app.Get("/unknown-error", func(ctx *aero.Context) string { 53 return ctx.Error(http.StatusUnauthorized) 54 }) 55 56 // Verify response with known error 57 response := getResponse(app, "/") 58 assert.Equal(t, http.StatusUnauthorized, response.Code) 59 assert.Contains(t, response.Body.String(), "Not logged in") 60 61 // Verify response with explanation only 62 response = getResponse(app, "/explanation-only") 63 assert.Equal(t, http.StatusUnauthorized, response.Code) 64 assert.Contains(t, response.Body.String(), "Not authorized") 65 66 // Verify response with unknown error 67 response = getResponse(app, "/unknown-error") 68 assert.Equal(t, http.StatusUnauthorized, response.Code) 69 assert.Contains(t, response.Body.String(), "Unknown error") 70 } 71 72 func TestContextURI(t *testing.T) { 73 app := aero.New() 74 75 // Register route 76 app.Get("/uri", func(ctx *aero.Context) string { 77 return ctx.URI() 78 }) 79 80 app.Get("/set-uri", func(ctx *aero.Context) string { 81 ctx.SetURI("/hello") 82 return ctx.URI() 83 }) 84 85 // Verify response with read-only URI 86 response := getResponse(app, "/uri") 87 assert.Equal(t, http.StatusOK, response.Code) 88 assert.Contains(t, response.Body.String(), "/uri") 89 90 // Verify response with modified URI 91 response = getResponse(app, "/set-uri") 92 assert.Equal(t, http.StatusOK, response.Code) 93 assert.Contains(t, response.Body.String(), "/hello") 94 } 95 96 func TestContextRealIP(t *testing.T) { 97 app := aero.New() 98 99 // Register route 100 app.Get("/ip", func(ctx *aero.Context) string { 101 return ctx.RealIP() 102 }) 103 104 // Get response 105 response := getResponse(app, "/ip") 106 107 // Verify response 108 assert.Equal(t, http.StatusOK, response.Code) 109 assert.Contains(t, response.Body.String(), "") 110 } 111 112 func TestContextSession(t *testing.T) { 113 app := aero.New() 114 115 // Register route 116 app.Get("/", func(ctx *aero.Context) string { 117 assert.Equal(t, false, ctx.HasSession()) 118 ctx.Session().Set("custom", helloWorld) 119 assert.Equal(t, true, ctx.HasSession()) 120 121 return ctx.Text(ctx.Session().GetString("custom")) 122 }) 123 124 // Get response 125 response := getResponse(app, "/") 126 127 // Verify response 128 assert.Equal(t, http.StatusOK, response.Code) 129 assert.Equal(t, helloWorld, response.Body.String()) 130 } 131 132 func TestContextSessionInvalidCookie(t *testing.T) { 133 app := aero.New() 134 135 // Register route 136 app.Get("/", func(ctx *aero.Context) string { 137 assert.Equal(t, false, ctx.HasSession()) 138 ctx.Session().Set("custom", helloWorld) 139 assert.Equal(t, true, ctx.HasSession()) 140 141 return ctx.Text(ctx.Session().GetString("custom")) 142 }) 143 144 // Create request 145 request, _ := http.NewRequest("GET", "/", nil) 146 request.Header.Set("Accept-Encoding", "gzip") 147 request.Header.Set("Cookie", "sid=invalid") 148 149 // Get response 150 response := httptest.NewRecorder() 151 app.Handler().ServeHTTP(response, request) 152 153 // Verify response 154 assert.Equal(t, http.StatusOK, response.Code) 155 assert.Equal(t, helloWorld, response.Body.String()) 156 } 157 158 func TestContextSessionValidCookie(t *testing.T) { 159 app := aero.New() 160 161 // Register routes 162 app.Get("/1", func(ctx *aero.Context) string { 163 assert.Equal(t, false, ctx.HasSession()) 164 ctx.Session().Set("custom", helloWorld) 165 assert.Equal(t, true, ctx.HasSession()) 166 assert.Equal(t, ctx.Session().GetString("sid"), ctx.Session().ID()) 167 168 return ctx.Text(ctx.Session().GetString("custom")) 169 }) 170 171 app.Get("/2", func(ctx *aero.Context) string { 172 assert.Equal(t, true, ctx.HasSession()) 173 assert.Equal(t, ctx.Session().GetString("sid"), ctx.Session().ID()) 174 175 return ctx.Text(ctx.Session().GetString("custom")) 176 }) 177 178 app.Get("/3", func(ctx *aero.Context) string { 179 assert.Equal(t, ctx.Session().GetString("sid"), ctx.Session().ID()) 180 181 return ctx.Text(ctx.Session().GetString("custom")) 182 }) 183 184 // Create request 1 185 request1, _ := http.NewRequest("GET", "/1", nil) 186 187 // Get response 1 188 response1 := httptest.NewRecorder() 189 app.Handler().ServeHTTP(response1, request1) 190 191 // Verify response 1 192 assert.Equal(t, http.StatusOK, response1.Code) 193 assert.Equal(t, helloWorld, response1.Body.String()) 194 195 setCookie := response1.Header().Get("Set-Cookie") 196 assert.NotEmpty(t, setCookie) 197 assert.Contains(t, setCookie, "sid=") 198 199 cookieParts := strings.Split(setCookie, ";") 200 sidLine := strings.TrimSpace(cookieParts[0]) 201 sidParts := strings.Split(sidLine, "=") 202 sid := sidParts[1] 203 assert.True(t, session.IsValidID(sid)) 204 205 // Create request 2 206 request2, _ := http.NewRequest("GET", "/2", nil) 207 request2.AddCookie(&http.Cookie{ 208 Name: "sid", 209 Value: sid, 210 }) 211 212 // Get response 2 213 response2 := httptest.NewRecorder() 214 app.Handler().ServeHTTP(response2, request2) 215 216 // Verify response 2 217 assert.Equal(t, http.StatusOK, response2.Code) 218 assert.Equal(t, helloWorld, response2.Body.String()) 219 220 // Create request 3 221 request3, _ := http.NewRequest("GET", "/3", nil) 222 request3.AddCookie(&http.Cookie{ 223 Name: "sid", 224 Value: sid, 225 }) 226 227 // Get response 3 228 response3 := httptest.NewRecorder() 229 app.Handler().ServeHTTP(response3, request3) 230 231 // Verify response 3 232 assert.Equal(t, http.StatusOK, response3.Code) 233 assert.Equal(t, helloWorld, response3.Body.String()) 234 } 235 236 func TestContextContentTypes(t *testing.T) { 237 app := aero.New() 238 239 // Register routes 240 app.Get("/json", func(ctx *aero.Context) string { 241 return ctx.JSON(app.Config) 242 }) 243 244 app.Get("/jsonld", func(ctx *aero.Context) string { 245 return ctx.JSONLinkedData(app.Config) 246 }) 247 248 app.Get("/html", func(ctx *aero.Context) string { 249 return ctx.HTML("<html></html>") 250 }) 251 252 app.Get("/css", func(ctx *aero.Context) string { 253 return ctx.CSS("body{}") 254 }) 255 256 app.Get("/js", func(ctx *aero.Context) string { 257 return ctx.JavaScript("console.log(42)") 258 }) 259 260 app.Get("/files/*file", func(ctx *aero.Context) string { 261 return ctx.File(ctx.Get("file")) 262 }) 263 264 // Get responses 265 responseJSON := getResponse(app, "/json") 266 responseJSONLD := getResponse(app, "/jsonld") 267 responseHTML := getResponse(app, "/html") 268 responseCSS := getResponse(app, "/css") 269 responseJS := getResponse(app, "/js") 270 responseFile := getResponse(app, "/files/Application.go") 271 responseMediaFile := getResponse(app, "/files/docs/usage.gif") 272 273 // Verify JSON response 274 json, _ := jsoniter.Marshal(app.Config) 275 assert.Equal(t, http.StatusOK, responseJSON.Code) 276 assert.Equal(t, json, responseJSON.Body.Bytes()) 277 assert.Contains(t, responseJSON.Header().Get("Content-Type"), "application/json") 278 279 // Verify JSON+LD response 280 assert.Equal(t, http.StatusOK, responseJSONLD.Code) 281 assert.Equal(t, json, responseJSONLD.Body.Bytes()) 282 assert.Contains(t, responseJSONLD.Header().Get("Content-Type"), "application/ld+json") 283 284 // Verify HTML response 285 assert.Equal(t, http.StatusOK, responseHTML.Code) 286 assert.Equal(t, "<html></html>", responseHTML.Body.String()) 287 assert.Contains(t, responseHTML.Header().Get("Content-Type"), "text/html") 288 289 // Verify CSS response 290 assert.Equal(t, http.StatusOK, responseCSS.Code) 291 assert.Equal(t, "body{}", responseCSS.Body.String()) 292 assert.Contains(t, responseCSS.Header().Get("Content-Type"), "text/css") 293 294 // Verify JS response 295 assert.Equal(t, http.StatusOK, responseJS.Code) 296 assert.Equal(t, "console.log(42)", responseJS.Body.String()) 297 assert.Contains(t, responseJS.Header().Get("Content-Type"), "application/javascript") 298 299 // Verify file response 300 appSourceCode, _ := ioutil.ReadFile("Application.go") 301 assert.Equal(t, http.StatusOK, responseFile.Code) 302 assert.Equal(t, appSourceCode, responseFile.Body.Bytes()) 303 assert.Contains(t, responseFile.Header().Get("Content-Type"), "text/plain") 304 305 // Verify media file response 306 imageData, _ := ioutil.ReadFile("docs/usage.gif") 307 assert.Equal(t, http.StatusOK, responseMediaFile.Code) 308 assert.Equal(t, imageData, responseMediaFile.Body.Bytes()) 309 assert.Contains(t, responseMediaFile.Header().Get("Content-Type"), "image/gif") 310 } 311 312 func TestContextReader(t *testing.T) { 313 app := aero.New() 314 config, _ := jsoniter.MarshalToString(app.Config) 315 316 // ReadAll 317 app.Get("/readall", func(ctx *aero.Context) string { 318 reader, writer := io.Pipe() 319 320 go func() { 321 defer writer.Close() 322 encoder := jsoniter.NewEncoder(writer) 323 encoder.Encode(app.Config) 324 }() 325 326 return ctx.ReadAll(reader) 327 }) 328 329 // Reader 330 app.Get("/reader", func(ctx *aero.Context) string { 331 reader, writer := io.Pipe() 332 333 go func() { 334 defer writer.Close() 335 encoder := jsoniter.NewEncoder(writer) 336 encoder.Encode(app.Config) 337 }() 338 339 return ctx.Reader(reader) 340 }) 341 342 // ReadSeeker 343 app.Get("/readseeker", func(ctx *aero.Context) string { 344 return ctx.ReadSeeker(strings.NewReader(config)) 345 }) 346 347 routes := []string{ 348 "/readall", 349 "/reader", 350 "/readseeker", 351 } 352 353 for _, route := range routes { 354 // Verify response 355 response := getResponse(app, route) 356 assert.Equal(t, http.StatusOK, response.Code) 357 assert.Equal(t, config, strings.TrimSpace(response.Body.String())) 358 } 359 } 360 361 func TestContextHTTP2Push(t *testing.T) { 362 app := aero.New() 363 app.Config.Push = append(app.Config.Push, "/pushed.css") 364 365 // Register routes 366 app.Get("/", func(ctx *aero.Context) string { 367 return ctx.HTML("<html></html>") 368 }) 369 370 app.Get("/pushed.css", func(ctx *aero.Context) string { 371 return ctx.CSS("body{}") 372 }) 373 374 // Add no-op push condition 375 app.AddPushCondition(func(ctx *aero.Context) bool { 376 return true 377 }) 378 379 // Get response 380 response := getResponse(app, "/") 381 382 // Verify response 383 assert.Equal(t, http.StatusOK, response.Code) 384 assert.Equal(t, "<html></html>", response.Body.String()) 385 } 386 387 func TestContextGetInt(t *testing.T) { 388 app := aero.New() 389 390 // Register route 391 app.Get("/:number", func(ctx *aero.Context) string { 392 number, err := ctx.GetInt("number") 393 assert.NoError(t, err) 394 assert.NotZero(t, number) 395 396 return ctx.Text(strconv.Itoa(number * 2)) 397 }) 398 399 // Get response 400 response := getResponse(app, "/21") 401 402 // Verify response 403 assert.Equal(t, http.StatusOK, response.Code) 404 assert.Equal(t, "42", response.Body.String()) 405 } 406 407 func TestContextUserAgent(t *testing.T) { 408 app := aero.New() 409 agent := "Luke Skywalker" 410 411 // Register route 412 app.Get("/", func(ctx *aero.Context) string { 413 userAgent := ctx.UserAgent() 414 return ctx.Text(userAgent) 415 }) 416 417 // Create request 418 request, _ := http.NewRequest("GET", "/", nil) 419 request.Header.Set("User-Agent", agent) 420 421 // Get response 422 response := httptest.NewRecorder() 423 app.Handler().ServeHTTP(response, request) 424 425 // Verify response 426 assert.Equal(t, http.StatusOK, response.Code) 427 assert.Equal(t, agent, response.Body.String()) 428 } 429 430 func TestContextRedirect(t *testing.T) { 431 app := aero.New() 432 433 // Register routes 434 app.Get("/permanent", func(ctx *aero.Context) string { 435 return ctx.RedirectPermanently("/target") 436 }) 437 438 app.Get("/temporary", func(ctx *aero.Context) string { 439 return ctx.Redirect("/target") 440 }) 441 442 // Get temporary response 443 response := getResponse(app, "/temporary") 444 445 // Verify response 446 assert.Equal(t, http.StatusFound, response.Code) 447 assert.Equal(t, "", response.Body.String()) 448 449 // Get permanent response 450 response = getResponse(app, "/permanent") 451 452 // Verify response 453 assert.Equal(t, http.StatusMovedPermanently, response.Code) 454 assert.Equal(t, "", response.Body.String()) 455 } 456 457 func TestContextQuery(t *testing.T) { 458 app := aero.New() 459 search := "Luke Skywalker" 460 461 // Register route 462 app.Get("/", func(ctx *aero.Context) string { 463 search := ctx.Query("search") 464 return ctx.Text(search) 465 }) 466 467 // Create request 468 request, _ := http.NewRequest("GET", "/?search="+search, nil) 469 470 // Get response 471 response := httptest.NewRecorder() 472 app.Handler().ServeHTTP(response, request) 473 474 // Verify response 475 assert.Equal(t, http.StatusOK, response.Code) 476 assert.Equal(t, search, response.Body.String()) 477 } 478 479 func TestContextEventStream(t *testing.T) { 480 app := aero.New() 481 482 // Register route 483 app.Get("/", func(ctx *aero.Context) string { 484 stream := aero.NewEventStream() 485 486 go func() { 487 for { 488 select { 489 case <-stream.Closed: 490 close(stream.Events) 491 return 492 493 case <-time.After(10 * time.Millisecond): 494 stream.Events <- &aero.Event{ 495 Name: "ping", 496 Data: "{}", 497 } 498 499 stream.Events <- &aero.Event{ 500 Name: "ping", 501 Data: []byte("{}"), 502 } 503 504 stream.Events <- &aero.Event{ 505 Name: "ping", 506 Data: struct { 507 Message string `json:"message"` 508 }{ 509 Message: "Hello", 510 }, 511 } 512 513 stream.Events <- &aero.Event{ 514 Name: "ping", 515 Data: nil, 516 } 517 } 518 } 519 }() 520 521 return ctx.EventStream(stream) 522 }) 523 524 // Create request 525 request, _ := http.NewRequest("GET", "/", nil) 526 ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond) 527 defer cancel() 528 request = request.WithContext(ctx) 529 530 // Get response 531 response := httptest.NewRecorder() 532 app.Handler().ServeHTTP(response, request) 533 534 // Verify response 535 assert.Equal(t, http.StatusOK, response.Code) 536 } 537 538 func TestBigResponse(t *testing.T) { 539 text := strings.Repeat("Hello World", 1000000) 540 app := aero.New() 541 542 // Make sure GZip is enabled 543 assert.Equal(t, true, app.Config.GZip) 544 545 // Register route 546 app.Get("/", func(ctx *aero.Context) string { 547 return ctx.Text(text) 548 }) 549 550 // Get response 551 response := getResponse(app, "/") 552 553 // Verify the response 554 assert.Equal(t, http.StatusOK, response.Code) 555 assert.Equal(t, "gzip", response.Header().Get("Content-Encoding")) 556 } 557 558 func TestBigResponseNoGzip(t *testing.T) { 559 text := strings.Repeat("Hello World", 1000000) 560 app := aero.New() 561 562 // Register route 563 app.Get("/", func(ctx *aero.Context) string { 564 return ctx.Text(text) 565 }) 566 567 // Create request and record response 568 request, _ := http.NewRequest("GET", "/", nil) 569 response := httptest.NewRecorder() 570 app.Handler().ServeHTTP(response, request) 571 572 // Verify the response 573 assert.Equal(t, http.StatusOK, response.Code) 574 assert.Equal(t, "", response.Header().Get("Content-Encoding")) 575 } 576 577 func TestBigResponse304(t *testing.T) { 578 text := strings.Repeat("Hello World", 1000000) 579 app := aero.New() 580 581 // Register route 582 app.Get("/", func(ctx *aero.Context) string { 583 return ctx.Text(text) 584 }) 585 586 // Create request and record response 587 request, _ := http.NewRequest("GET", "/", nil) 588 response := httptest.NewRecorder() 589 app.Handler().ServeHTTP(response, request) 590 etag := response.Header().Get("ETag") 591 592 // Verify the response 593 assert.Equal(t, http.StatusOK, response.Code) 594 assert.NotEmpty(t, response.Body.String()) 595 596 // Set if-none-match to the etag we just received 597 request, _ = http.NewRequest("GET", "/", nil) 598 request.Header.Set("If-None-Match", etag) 599 response = httptest.NewRecorder() 600 app.Handler().ServeHTTP(response, request) 601 602 // Verify the response 603 assert.Equal(t, 304, response.Code) 604 assert.Empty(t, response.Body.String()) 605 }