github.com/cloudwego/hertz@v0.9.3/pkg/app/context_test.go (about) 1 /* 2 * Copyright 2022 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package app 18 19 import ( 20 "bytes" 21 "context" 22 "encoding/xml" 23 "errors" 24 "fmt" 25 "html/template" 26 "io/ioutil" 27 "net" 28 "os" 29 "reflect" 30 "strings" 31 "testing" 32 "time" 33 34 "github.com/cloudwego/hertz/internal/bytesconv" 35 "github.com/cloudwego/hertz/internal/bytestr" 36 "github.com/cloudwego/hertz/pkg/app/server/binding" 37 "github.com/cloudwego/hertz/pkg/app/server/render" 38 errs "github.com/cloudwego/hertz/pkg/common/errors" 39 "github.com/cloudwego/hertz/pkg/common/test/assert" 40 "github.com/cloudwego/hertz/pkg/common/test/mock" 41 "github.com/cloudwego/hertz/pkg/common/testdata/proto" 42 "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" 43 "github.com/cloudwego/hertz/pkg/common/utils" 44 "github.com/cloudwego/hertz/pkg/network" 45 "github.com/cloudwego/hertz/pkg/protocol" 46 "github.com/cloudwego/hertz/pkg/protocol/consts" 47 "github.com/cloudwego/hertz/pkg/protocol/http1/req" 48 "github.com/cloudwego/hertz/pkg/protocol/http1/resp" 49 con "github.com/cloudwego/hertz/pkg/route/consts" 50 "github.com/cloudwego/hertz/pkg/route/param" 51 ) 52 53 func TestProtobuf(t *testing.T) { 54 ctx := NewContext(0) 55 body := proto.TestStruct{Body: []byte("Hello World")} 56 ctx.ProtoBuf(consts.StatusOK, &body) 57 58 assert.DeepEqual(t, string(ctx.Response.Body()), "\n\vHello World") 59 } 60 61 func TestPureJson(t *testing.T) { 62 ctx := NewContext(0) 63 ctx.PureJSON(consts.StatusOK, utils.H{ 64 "html": "<b>Hello World</b>", 65 }) 66 if string(ctx.Response.Body()) != "{\"html\":\"<b>Hello World</b>\"}\n" { 67 t.Fatalf("unexpected purejson: %#v, expected: %#v", string(ctx.Response.Body()), "<b>Hello World</b>") 68 } 69 } 70 71 func TestIndentedJSON(t *testing.T) { 72 ctx := NewContext(0) 73 ctx.IndentedJSON(consts.StatusOK, utils.H{ 74 "foo": "bar", 75 "html": "h1", 76 }) 77 if string(ctx.Response.Body()) != "{\n \"foo\": \"bar\",\n \"html\": \"h1\"\n}" { 78 t.Fatalf("unexpected purejson: %#v, expected: %#v", string(ctx.Response.Body()), "{\n \"foo\": \"bar\",\n \"html\": \"<b>\"\n}") 79 } 80 } 81 82 func TestContext(t *testing.T) { 83 reqContext := NewContext(0) 84 reqContext.Set("testContextKey", "testValue") 85 ctx := reqContext 86 if ctx.Value("testContextKey") != "testValue" { 87 t.Fatalf("unexpected value: %#v, expected: %#v", ctx.Value("testContextKey"), "testValue") 88 } 89 } 90 91 func TestValue(t *testing.T) { 92 ctx := NewContext(0) 93 94 v := ctx.Value("testContextKey") 95 assert.Nil(t, v) 96 97 ctx.Set("testContextKey", "testValue") 98 v = ctx.Value("testContextKey") 99 assert.DeepEqual(t, "testValue", v) 100 } 101 102 func TestContextNotModified(t *testing.T) { 103 reqContext := NewContext(0) 104 reqContext.Response.SetStatusCode(consts.StatusOK) 105 if reqContext.Response.StatusCode() != consts.StatusOK { 106 t.Fatalf("unexpected status code: %#v, expected: %#v", reqContext.Response.StatusCode(), consts.StatusOK) 107 } 108 reqContext.NotModified() 109 if reqContext.Response.StatusCode() != consts.StatusNotModified { 110 t.Fatalf("unexpected status code: %#v, expected: %#v", reqContext.Response.StatusCode(), consts.StatusNotModified) 111 } 112 } 113 114 func TestIfModifiedSince(t *testing.T) { 115 ctx := NewContext(0) 116 var req protocol.Request 117 req.Header.Set(string(bytestr.StrIfModifiedSince), "Mon, 02 Jan 2006 15:04:05 MST") 118 req.CopyTo(&ctx.Request) 119 if !ctx.IfModifiedSince(time.Now()) { 120 t.Fatalf("ifModifiedSice error, expected false, but get true") 121 } 122 tt, _ := time.Parse(time.RFC3339, "2004-11-12T11:45:26.371Z") 123 if ctx.IfModifiedSince(tt) { 124 t.Fatalf("ifModifiedSice error, expected true, but get false") 125 } 126 } 127 128 func TestWrite(t *testing.T) { 129 ctx := NewContext(0) 130 l, err := ctx.Write([]byte("test body")) 131 if err != nil { 132 t.Fatalf("unexpected error: %#v", err.Error()) 133 } 134 if l != 9 { 135 t.Fatalf("unexpected len: %#v, expected: %#v", l, 9) 136 } 137 if string(ctx.Response.BodyBytes()) != "test body" { 138 t.Fatalf("unexpected body: %#v, expected: %#v", string(ctx.Response.BodyBytes()), "test body") 139 } 140 } 141 142 func TestSetConnectionClose(t *testing.T) { 143 ctx := NewContext(0) 144 ctx.SetConnectionClose() 145 if !ctx.Response.Header.ConnectionClose() { 146 t.Fatalf("expected close connection, but not") 147 } 148 } 149 150 func TestNotFound(t *testing.T) { 151 ctx := NewContext(0) 152 ctx.NotFound() 153 if ctx.Response.StatusCode() != consts.StatusNotFound || string(ctx.Response.BodyBytes()) != "404 Page not found" { 154 t.Fatalf("unexpected status code or body") 155 } 156 } 157 158 func TestRedirect(t *testing.T) { 159 ctx := NewContext(0) 160 ctx.Redirect(consts.StatusFound, []byte("/hello")) 161 assert.DeepEqual(t, consts.StatusFound, ctx.Response.StatusCode()) 162 163 ctx.redirect([]byte("/hello"), consts.StatusMovedPermanently) 164 assert.DeepEqual(t, consts.StatusMovedPermanently, ctx.Response.StatusCode()) 165 } 166 167 func TestGetRedirectStatusCode(t *testing.T) { 168 val := getRedirectStatusCode(consts.StatusMovedPermanently) 169 assert.DeepEqual(t, consts.StatusMovedPermanently, val) 170 171 val = getRedirectStatusCode(consts.StatusNotFound) 172 assert.DeepEqual(t, consts.StatusFound, val) 173 } 174 175 func TestCookie(t *testing.T) { 176 ctx := NewContext(0) 177 ctx.Request.Header.SetCookie("cookie", "test cookie") 178 if string(ctx.Cookie("cookie")) != "test cookie" { 179 t.Fatalf("unexpected cookie: %#v, expected get: %#v", string(ctx.Cookie("cookie")), "test cookie") 180 } 181 } 182 183 func TestUserAgent(t *testing.T) { 184 ctx := NewContext(0) 185 ctx.Request.Header.SetUserAgentBytes([]byte("user agent")) 186 if string(ctx.UserAgent()) != "user agent" { 187 t.Fatalf("unexpected user agent: %#v, expected get: %#v", string(ctx.UserAgent()), "user agent") 188 } 189 } 190 191 func TestStatus(t *testing.T) { 192 ctx := NewContext(0) 193 ctx.Status(consts.StatusOK) 194 if ctx.Response.StatusCode() != consts.StatusOK { 195 t.Fatalf("expected get consts.StatusOK, but not") 196 } 197 } 198 199 func TestPost(t *testing.T) { 200 ctx := NewContext(0) 201 ctx.Request.Header.SetMethod(consts.MethodPost) 202 if !ctx.IsPost() { 203 t.Fatalf("expected post method , but get: %#v", ctx.Method()) 204 } 205 206 if string(ctx.Method()) != consts.MethodPost { 207 t.Fatalf("expected post method , but get: %#v", ctx.Method()) 208 } 209 } 210 211 func TestGet(t *testing.T) { 212 ctx := NewContext(0) 213 ctx.Request.Header.SetMethod(consts.MethodPost) 214 assert.False(t, ctx.IsGet()) 215 216 ctx.Request.Header.SetMethod(consts.MethodGet) 217 assert.True(t, ctx.IsGet()) 218 } 219 220 func TestCopy(t *testing.T) { 221 t.Parallel() 222 ctx := NewContext(0) 223 ctx.fullPath = "full_path" 224 ctx.Request.Header.Add("header_a", "header_value_a") 225 ctx.Response.Header.Add("header_b", "header_value_b") 226 ctx.Params = param.Params{ 227 {Key: "key_a", Value: "value_a"}, 228 {Key: "key_b", Value: "value_b"}, 229 {Key: "key_c", Value: "value_b"}, 230 {Key: "key_d", Value: "value_b"}, 231 {Key: "key_e", Value: "value_b"}, 232 {Key: "key_f", Value: "value_b"}, 233 {Key: "key_g", Value: "value_b"}, 234 {Key: "key_h", Value: "value_b"}, 235 {Key: "key_i", Value: "value_b"}, 236 } 237 ctx.Set("map_key_a", "map_value_a") 238 ctx.Set("map_key_b", "map_value_b") 239 for i := 0; i <= 10000; i++ { 240 c := ctx.Copy() 241 go func(context *RequestContext) { 242 str, _ := context.Params.Get("key_a") 243 if str != "value_a" { 244 t.Errorf("unexpected value: %#v, expected: %#v", str, "value_a") 245 return 246 } 247 248 if c.fullPath != "full_path" { 249 t.Errorf("unexpected value: %#v, expected: %#v", c.fullPath, "full_path") 250 return 251 } 252 253 reqHeaderStr := context.Request.Header.Get("header_a") 254 if reqHeaderStr != "header_value_a" { 255 t.Errorf("unexpected value: %#v, expected: %#v", reqHeaderStr, "header_value_a") 256 return 257 } 258 259 respHeaderStr := context.Response.Header.Get("header_b") 260 if respHeaderStr != "header_value_b" { 261 t.Errorf("unexpected value: %#v, expected: %#v", respHeaderStr, "header_value_b") 262 return 263 } 264 265 iStr := ctx.Value("map_key_a") 266 if iStr.(string) != "map_value_a" { 267 t.Errorf("unexpected value: %#v, expected: %#v", iStr.(string), "map_value_a") 268 return 269 } 270 271 context.Params = context.Params[0:0] 272 context.Params = append(context.Params, param.Param{Key: "key_a", Value: "value_a_"}) 273 274 context.Request.Header.Reset() 275 context.Request.Header.Add("header_a", "header_value_a_") 276 context.Response.Header.Reset() 277 context.Response.Header.Add("header_b", "header_value_b_") 278 context.Keys = nil 279 context.Keys = make(map[string]interface{}) 280 context.Set("header_value_a", "map_value_a_") 281 }(c) 282 } 283 } 284 285 func TestQuery(t *testing.T) { 286 var r protocol.Request 287 ctx := NewContext(0) 288 s := "POST /foo?name=menu&value= HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3 \r\nabc\r\n0\r\n\r\n" 289 zr := mock.NewZeroCopyReader(s) 290 err := req.Read(&r, zr) 291 if err != nil { 292 t.Fatalf("Unexpected error when reading chunked request: %s", err) 293 } 294 r.CopyTo(&ctx.Request) 295 if ctx.Query("name") != "menu" { 296 t.Fatalf("unexpected query: %#v, expected menu", ctx.Query("name")) 297 } 298 299 if ctx.DefaultQuery("name", "default value") != "menu" { 300 t.Fatalf("unexpected query: %#v, expected menu", ctx.Query("name")) 301 } 302 303 if ctx.DefaultQuery("defaultQuery", "default value") != "default value" { 304 t.Fatalf("unexpected query: %#v, expected `default value`", ctx.Query("defaultQuery")) 305 } 306 } 307 308 func TestMethod(t *testing.T) { 309 ctx := NewContext(0) 310 ctx.Status(consts.StatusOK) 311 if ctx.Response.StatusCode() != consts.StatusOK { 312 t.Fatalf("expected get consts.StatusOK, but not") 313 } 314 } 315 316 func makeCtxByReqString(t *testing.T, s string) *RequestContext { 317 ctx := NewContext(0) 318 319 mr := mock.NewZeroCopyReader(s) 320 if err := req.Read(&ctx.Request, mr); err != nil { 321 t.Fatalf("unexpected error: %s", err) 322 } 323 return ctx 324 } 325 326 func TestPostForm(t *testing.T) { 327 t.Parallel() 328 329 ctx := makeCtxByReqString(t, `POST /upload HTTP/1.1 330 Host: localhost:10000 331 Content-Length: 521 332 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg 333 334 ------WebKitFormBoundaryJwfATyF8tmxSJnLg 335 Content-Disposition: form-data; name="f1" 336 337 value1 338 ------WebKitFormBoundaryJwfATyF8tmxSJnLg 339 Content-Disposition: form-data; name="fileaaa"; filename="TODO" 340 Content-Type: application/octet-stream 341 342 - SessionClient with referer and cookies support. 343 - Client with requests' pipelining support. 344 - ProxyHandler similar to FSHandler. 345 - WebSockets. See https://tools.ietf.org/html/rfc6455 . 346 - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . 347 348 ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- 349 `) 350 351 if ctx.PostForm("f1") != "value1" { 352 t.Fatalf("PostForm get Multipart Form data failed") 353 } 354 if ctx.PostForm("fileaaa") != "" { 355 t.Fatalf("PostForm should not get file") 356 } 357 358 ctx = makeCtxByReqString(t, `POST /upload HTTP/1.1 359 Host: localhost:10000 360 Content-Length: 11 361 Content-Type: application/x-www-form-urlencoded 362 363 hello=world`) 364 365 if ctx.PostForm("hello") != "world" { 366 t.Fatalf("PostForm get form failed") 367 } 368 } 369 370 func TestPostFormArray(t *testing.T) { 371 t.Parallel() 372 373 ctx := makeCtxByReqString(t, `POST /upload HTTP/1.1 374 Host: localhost:10000 375 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg 376 Content-Length: 521 377 378 ------WebKitFormBoundaryJwfATyF8tmxSJnLg 379 Content-Disposition: form-data; name="tag" 380 381 red 382 ------WebKitFormBoundaryJwfATyF8tmxSJnLg 383 Content-Disposition: form-data; name="tag" 384 385 green 386 ------WebKitFormBoundaryJwfATyF8tmxSJnLg 387 Content-Disposition: form-data; name="tag" 388 389 blue 390 ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- 391 `) 392 assert.DeepEqual(t, []string{"red", "green", "blue"}, ctx.PostFormArray("tag")) 393 394 ctx = makeCtxByReqString(t, `POST /upload HTTP/1.1 395 Host: localhost:10000 396 Content-Type: application/x-www-form-urlencoded; charset=UTF-8 397 Content-Length: 26 398 399 tag=red&tag=green&tag=blue 400 `) 401 assert.DeepEqual(t, []string{"red", "green", "blue"}, ctx.PostFormArray("tag")) 402 } 403 404 func TestDefaultPostForm(t *testing.T) { 405 ctx := makeCtxByReqString(t, `POST /upload HTTP/1.1 406 Host: localhost:10000 407 Content-Length: 521 408 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg 409 410 ------WebKitFormBoundaryJwfATyF8tmxSJnLg 411 Content-Disposition: form-data; name="f1" 412 413 value1 414 ------WebKitFormBoundaryJwfATyF8tmxSJnLg 415 Content-Disposition: form-data; name="fileaaa"; filename="TODO" 416 Content-Type: application/octet-stream 417 418 - SessionClient with referer and cookies support. 419 - Client with requests' pipelining support. 420 - ProxyHandler similar to FSHandler. 421 - WebSockets. See https://tools.ietf.org/html/rfc6455 . 422 - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . 423 424 ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- 425 `) 426 427 val := ctx.DefaultPostForm("f1", "no val") 428 assert.DeepEqual(t, "value1", val) 429 430 val = ctx.DefaultPostForm("f99", "no val") 431 assert.DeepEqual(t, "no val", val) 432 } 433 434 func TestRequestContext_FormFile(t *testing.T) { 435 t.Parallel() 436 437 s := `POST /upload HTTP/1.1 438 Host: localhost:10000 439 Content-Length: 521 440 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg 441 442 ------WebKitFormBoundaryJwfATyF8tmxSJnLg 443 Content-Disposition: form-data; name="f1" 444 445 value1 446 ------WebKitFormBoundaryJwfATyF8tmxSJnLg 447 Content-Disposition: form-data; name="fileaaa"; filename="TODO" 448 Content-Type: application/octet-stream 449 450 - SessionClient with referer and cookies support. 451 - Client with requests' pipelining support. 452 - ProxyHandler similar to FSHandler. 453 - WebSockets. See https://tools.ietf.org/html/rfc6455 . 454 - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . 455 456 ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- 457 tailfoobar` 458 459 mr := mock.NewZeroCopyReader(s) 460 461 ctx := NewContext(0) 462 if err := req.Read(&ctx.Request, mr); err != nil { 463 t.Fatalf("unexpected error: %s", err) 464 } 465 tail, err := ioutil.ReadAll(mr) 466 if err != nil { 467 t.Fatalf("unexpected error: %s", err) 468 } 469 if string(tail) != "tailfoobar" { 470 t.Fatalf("unexpected tail %q. Expecting %q", tail, "tailfoobar") 471 } 472 473 f, err := ctx.MultipartForm() 474 if err != nil { 475 t.Fatalf("unexpected error: %s", err) 476 } 477 defer ctx.Request.RemoveMultipartFormFiles() 478 479 // verify files 480 if len(f.File) != 1 { 481 t.Fatalf("unexpected number of file values in multipart form: %d. Expecting 1", len(f.File)) 482 } 483 for k, vv := range f.File { 484 if k != "fileaaa" { 485 t.Fatalf("unexpected file value name %q. Expecting %q", k, "fileaaa") 486 } 487 if len(vv) != 1 { 488 t.Fatalf("unexpected number of file values %d. Expecting 1", len(vv)) 489 } 490 v := vv[0] 491 if v.Filename != "TODO" { 492 t.Fatalf("unexpected filename %q. Expecting %q", v.Filename, "TODO") 493 } 494 ct := v.Header.Get("Content-Type") 495 if ct != consts.MIMEApplicationOctetStream { 496 t.Fatalf("unexpected content-type %q. Expecting %q", ct, "application/octet-stream") 497 } 498 } 499 500 err = ctx.SaveUploadedFile(f.File["fileaaa"][0], "TODO") 501 assert.Nil(t, err) 502 fileInfo, err := os.Stat("TODO") 503 assert.Nil(t, err) 504 assert.DeepEqual(t, "TODO", fileInfo.Name()) 505 assert.DeepEqual(t, f.File["fileaaa"][0].Size, fileInfo.Size()) 506 err = os.Remove("TODO") 507 assert.Nil(t, err) 508 509 ff, err := ctx.FormFile("fileaaa") 510 if err != nil || ff == nil { 511 t.Fatalf("unexpected error happened when ctx.FormFile()") 512 } 513 514 buf := make([]byte, ff.Size) 515 fff, _ := ff.Open() 516 fff.Read(buf) 517 518 if !strings.Contains(string(buf), "- SessionClient") { 519 t.Fatalf("unexpected file content. Expecting %q", "- SessionClient") 520 } 521 522 if !strings.Contains(string(buf), "rfc7540 .") { 523 t.Fatalf("unexpected file content. Expecting %q", "rfc7540 .") 524 } 525 } 526 527 func TestContextRenderFileFromFS(t *testing.T) { 528 t.Parallel() 529 530 ctx := NewContext(0) 531 var req protocol.Request 532 req.Header.SetMethod(consts.MethodGet) 533 req.SetRequestURI("/some/path") 534 req.CopyTo(&ctx.Request) 535 536 ctx.FileFromFS("./fs.go", &FS{ 537 Root: ".", 538 IndexNames: nil, 539 GenerateIndexPages: false, 540 AcceptByteRange: true, 541 }) 542 543 assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) 544 assert.True(t, strings.Contains(resp.GetHTTP1Response(&ctx.Response).String(), "func (fs *FS) initRequestHandler() {")) 545 // when Go version <= 1.16, mime.TypeByExtension will return Content-Type='text/plain; charset=utf-8', 546 // otherwise it will return Content-Type='text/x-go; charset=utf-8' 547 assert.NotEqual(t, "", string(ctx.Response.Header.Peek("Content-Type"))) 548 assert.DeepEqual(t, "/some/path", string(ctx.Request.URI().Path())) 549 } 550 551 func TestContextRenderFile(t *testing.T) { 552 t.Parallel() 553 554 ctx := NewContext(0) 555 var req protocol.Request 556 req.Header.SetMethod(consts.MethodGet) 557 req.SetRequestURI("/") 558 req.CopyTo(&ctx.Request) 559 560 ctx.File("./fs.go") 561 562 assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) 563 assert.True(t, strings.Contains(resp.GetHTTP1Response(&ctx.Response).String(), "func (fs *FS) initRequestHandler() {")) 564 // when Go version <= 1.16, mime.TypeByExtension will return Content-Type='text/plain; charset=utf-8', 565 // otherwise it will return Content-Type='text/x-go; charset=utf-8' 566 assert.NotEqual(t, "", string(ctx.Response.Header.Peek("Content-Type"))) 567 } 568 569 func TestContextRenderAttachment(t *testing.T) { 570 t.Parallel() 571 572 ctx := NewContext(0) 573 var req protocol.Request 574 req.Header.SetMethod(consts.MethodGet) 575 req.SetRequestURI("/") 576 req.CopyTo(&ctx.Request) 577 newFilename := "new_filename.go" 578 579 ctx.FileAttachment("./context.go", newFilename) 580 581 assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) 582 assert.True(t, strings.Contains(resp.GetHTTP1Response(&ctx.Response).String(), 583 "func (ctx *RequestContext) FileAttachment(filepath, filename string) {")) 584 assert.DeepEqual(t, fmt.Sprintf("attachment; filename=\"%s\"", newFilename), 585 string(ctx.Response.Header.Peek("Content-Disposition"))) 586 } 587 588 func TestRequestContext_Header(t *testing.T) { 589 c := NewContext(0) 590 591 c.Header("header_key", "header_val") 592 val := string(c.Response.Header.Peek("header_key")) 593 if val != "header_val" { 594 t.Fatalf("unexpected %q. Expecting %q", val, "header_val") 595 } 596 597 c.Response.Header.Del("header_key") 598 val = string(c.Response.Header.Peek("header_key")) 599 if val != "" { 600 t.Fatalf("unexpected %q. Expecting %q", val, "") 601 } 602 603 c.Header("header_key1", "header_val1") 604 c.Header("header_key1", "") 605 val = string(c.Response.Header.Peek("header_key1")) 606 if val != "" { 607 t.Fatalf("unexpected %q. Expecting %q", val, "") 608 } 609 } 610 611 func TestRequestContext_Keys(t *testing.T) { 612 c := NewContext(0) 613 rightVal := "123" 614 c.Set("key", rightVal) 615 val := c.GetString("key") 616 if val != rightVal { 617 t.Fatalf("unexpected %v. Expecting %v", val, rightVal) 618 } 619 } 620 621 func testFunc(c context.Context, ctx *RequestContext) { 622 ctx.Next(c) 623 } 624 625 func testFunc2(c context.Context, ctx *RequestContext) { 626 ctx.Set("key", "123") 627 } 628 629 func TestRequestContext_Handler(t *testing.T) { 630 c := NewContext(0) 631 c.handlers = HandlersChain{testFunc, testFunc2} 632 633 c.Handler()(context.Background(), c) 634 val := c.GetString("key") 635 if val != "123" { 636 t.Fatalf("unexpected %v. Expecting %v", val, "123") 637 } 638 639 c.handlers = nil 640 handler := c.Handler() 641 assert.Nil(t, handler) 642 } 643 644 func TestRequestContext_Handlers(t *testing.T) { 645 c := NewContext(0) 646 hc := HandlersChain{testFunc, testFunc2} 647 c.SetHandlers(hc) 648 c.Handlers()[1](context.Background(), c) 649 val := c.GetString("key") 650 if val != "123" { 651 t.Fatalf("unexpected %v. Expecting %v", val, "123") 652 } 653 } 654 655 func TestRequestContext_HandlerName(t *testing.T) { 656 c := NewContext(0) 657 c.handlers = HandlersChain{testFunc, testFunc2} 658 val := c.HandlerName() 659 if val != "github.com/cloudwego/hertz/pkg/app.testFunc2" { 660 t.Fatalf("unexpected %v. Expecting %v", val, "github.com/cloudwego/hertz.testFunc2") 661 } 662 } 663 664 func TestNext(t *testing.T) { 665 c := NewContext(0) 666 a := 0 667 668 testFunc1 := func(c context.Context, ctx *RequestContext) { 669 a = 1 670 } 671 testFunc3 := func(c context.Context, ctx *RequestContext) { 672 a = 3 673 } 674 c.handlers = HandlersChain{testFunc1, testFunc3} 675 676 c.Next(context.Background()) 677 678 assert.True(t, c.index == 2) 679 assert.DeepEqual(t, 3, a) 680 } 681 682 func TestContextError(t *testing.T) { 683 c := NewContext(0) 684 assert.Nil(t, c.Errors) 685 686 firstErr := errors.New("first error") 687 c.Error(firstErr) // nolint: errcheck 688 assert.DeepEqual(t, 1, len(c.Errors)) 689 assert.DeepEqual(t, "Error #01: first error\n", c.Errors.String()) 690 691 secondErr := errors.New("second error") 692 c.Error(&errs.Error{ // nolint: errcheck 693 Err: secondErr, 694 Meta: "some data 2", 695 Type: errs.ErrorTypePublic, 696 }) 697 assert.DeepEqual(t, 2, len(c.Errors)) 698 699 assert.DeepEqual(t, firstErr, c.Errors[0].Err) 700 assert.Nil(t, c.Errors[0].Meta) 701 assert.DeepEqual(t, errs.ErrorTypePrivate, c.Errors[0].Type) 702 703 assert.DeepEqual(t, secondErr, c.Errors[1].Err) 704 assert.DeepEqual(t, "some data 2", c.Errors[1].Meta) 705 assert.DeepEqual(t, errs.ErrorTypePublic, c.Errors[1].Type) 706 707 assert.DeepEqual(t, c.Errors.Last(), c.Errors[1]) 708 709 defer func() { 710 if recover() == nil { 711 t.Error("didn't panic") 712 } 713 }() 714 c.Error(nil) // nolint: errcheck 715 } 716 717 func TestContextAbortWithError(t *testing.T) { 718 c := NewContext(0) 719 720 c.AbortWithError(consts.StatusUnauthorized, errors.New("bad input")).SetMeta("some input") // nolint: errcheck 721 722 assert.DeepEqual(t, consts.StatusUnauthorized, c.Response.StatusCode()) 723 assert.DeepEqual(t, con.AbortIndex, c.index) 724 assert.True(t, c.IsAborted()) 725 } 726 727 func TestRender(t *testing.T) { 728 c := NewContext(0) 729 730 c.Render(consts.StatusOK, &render.Data{ 731 ContentType: consts.MIMEApplicationJSONUTF8, 732 Data: []byte("{\"test\":1}"), 733 }) 734 735 assert.DeepEqual(t, consts.StatusOK, c.Response.StatusCode()) 736 assert.True(t, strings.Contains(string(c.Response.Body()), "test")) 737 738 c.Reset() 739 c.Render(110, &render.Data{ 740 ContentType: "application/json; charset=utf-8", 741 Data: []byte("{\"test\":1}"), 742 }) 743 assert.DeepEqual(t, "application/json; charset=utf-8", string(c.Response.Header.ContentType())) 744 assert.DeepEqual(t, "", string(c.Response.Body())) 745 746 c.Reset() 747 c.Render(consts.StatusNoContent, &render.Data{ 748 ContentType: "application/json; charset=utf-8", 749 Data: []byte("{\"test\":1}"), 750 }) 751 assert.DeepEqual(t, "application/json; charset=utf-8", string(c.Response.Header.ContentType())) 752 assert.DeepEqual(t, "", string(c.Response.Body())) 753 754 c.Reset() 755 c.Render(consts.StatusNotModified, &render.Data{ 756 ContentType: "application/json; charset=utf-8", 757 Data: []byte("{\"test\":1}"), 758 }) 759 assert.DeepEqual(t, "application/json; charset=utf-8", string(c.Response.Header.ContentType())) 760 assert.DeepEqual(t, "", string(c.Response.Body())) 761 } 762 763 func TestHTML(t *testing.T) { 764 c := NewContext(0) 765 766 tmpl := template.Must(template.New(""). 767 Delims("{[{", "}]}"). 768 Funcs(template.FuncMap{}). 769 ParseFiles("../common/testdata/template/index.tmpl")) 770 771 r := &render.HTMLProduction{Template: tmpl} 772 c.HTMLRender = r 773 c.HTML(consts.StatusOK, "index.tmpl", utils.H{"title": "Main website"}) 774 775 assert.DeepEqual(t, []byte("text/html; charset=utf-8"), c.Response.Header.Peek("Content-Type")) 776 assert.DeepEqual(t, []byte("<html><h1>Main website</h1></html>"), c.Response.Body()) 777 } 778 779 type xmlmap map[string]interface{} 780 781 // Allows type H to be used with xml.Marshal 782 func (h xmlmap) MarshalXML(e *xml.Encoder, start xml.StartElement) error { 783 start.Name = xml.Name{ 784 Space: "", 785 Local: "map", 786 } 787 if err := e.EncodeToken(start); err != nil { 788 return err 789 } 790 for key, value := range h { 791 elem := xml.StartElement{ 792 Name: xml.Name{Space: "", Local: key}, 793 Attr: []xml.Attr{}, 794 } 795 if err := e.EncodeElement(value, elem); err != nil { 796 return err 797 } 798 } 799 800 return e.EncodeToken(xml.EndElement{Name: start.Name}) 801 } 802 803 func TestXML(t *testing.T) { 804 c := NewContext(0) 805 c.XML(consts.StatusOK, xmlmap{"foo": "bar"}) 806 assert.DeepEqual(t, []byte("<map><foo>bar</foo></map>"), c.Response.Body()) 807 assert.DeepEqual(t, []byte("application/xml; charset=utf-8"), c.Response.Header.Peek("Content-Type")) 808 } 809 810 func TestJSON(t *testing.T) { 811 c := NewContext(0) 812 c.JSON(consts.StatusOK, "test") 813 assert.DeepEqual(t, consts.StatusOK, c.Response.StatusCode()) 814 assert.True(t, strings.Contains(string(c.Response.Body()), "test")) 815 } 816 817 func TestDATA(t *testing.T) { 818 c := NewContext(0) 819 c.Data(consts.StatusOK, "application/json; charset=utf-8", []byte("{\"test\":1}")) 820 assert.DeepEqual(t, consts.StatusOK, c.Response.StatusCode()) 821 assert.True(t, strings.Contains(string(c.Response.Body()), "test")) 822 } 823 824 func TestContextReset(t *testing.T) { 825 c := NewContext(0) 826 827 c.index = 2 828 c.Params = param.Params{param.Param{}} 829 c.Error(errors.New("test")) // nolint: errcheck 830 c.Set("foo", "bar") 831 c.Finished() 832 c.Request.SetIsTLS(true) 833 c.ResetWithoutConn() 834 c.Request.URI() 835 assert.DeepEqual(t, "https", string(c.Request.Scheme())) 836 assert.False(t, c.IsAborted()) 837 assert.DeepEqual(t, 0, len(c.Errors)) 838 assert.Nil(t, c.Errors.Errors()) 839 assert.Nil(t, c.Errors.ByType(errs.ErrorTypeAny)) 840 assert.DeepEqual(t, 0, len(c.Params)) 841 assert.DeepEqual(t, int8(-1), c.index) 842 assert.Nil(t, c.finished) 843 } 844 845 func TestContextContentType(t *testing.T) { 846 c := NewContext(0) 847 c.Request.Header.Set("Content-Type", consts.MIMEApplicationJSONUTF8) 848 assert.DeepEqual(t, consts.MIMEApplicationJSONUTF8, bytesconv.B2s(c.ContentType())) 849 } 850 851 type MockIpConn struct { 852 *mock.Conn 853 RemoteIp string 854 Port int 855 } 856 857 func (c *MockIpConn) RemoteAddr() net.Addr { 858 return &net.UDPAddr{ 859 IP: net.ParseIP(c.RemoteIp), 860 Port: c.Port, 861 } 862 } 863 864 func newContextClientIPTest() *RequestContext { 865 c := NewContext(0) 866 c.conn = &MockIpConn{ 867 Conn: mock.NewConn(""), 868 RemoteIp: "127.0.0.1", 869 Port: 8080, 870 } 871 c.Request.Header.Set("X-Real-IP", " 10.10.10.10 ") 872 c.Request.Header.Set("X-Forwarded-For", " 20.20.20.20, 30.30.30.30") 873 return c 874 } 875 876 func TestClientIp(t *testing.T) { 877 c := newContextClientIPTest() 878 // default X-Forwarded-For and X-Real-IP behaviour 879 assert.DeepEqual(t, "20.20.20.20", c.ClientIP()) 880 881 c.Request.Header.DelBytes([]byte("X-Forwarded-For")) 882 assert.DeepEqual(t, "10.10.10.10", c.ClientIP()) 883 884 c.Request.Header.Set("X-Forwarded-For", "30.30.30.30 ") 885 assert.DeepEqual(t, "30.30.30.30", c.ClientIP()) 886 887 // No trusted CIDRS 888 c = newContextClientIPTest() 889 opts := ClientIPOptions{ 890 RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, 891 TrustedCIDRs: nil, 892 } 893 c.SetClientIPFunc(ClientIPWithOption(opts)) 894 assert.DeepEqual(t, "127.0.0.1", c.ClientIP()) 895 896 _, cidr, _ := net.ParseCIDR("30.30.30.30/32") 897 opts = ClientIPOptions{ 898 RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, 899 TrustedCIDRs: []*net.IPNet{cidr}, 900 } 901 c.SetClientIPFunc(ClientIPWithOption(opts)) 902 assert.DeepEqual(t, "127.0.0.1", c.ClientIP()) 903 904 _, cidr, _ = net.ParseCIDR("127.0.0.1/32") 905 opts = ClientIPOptions{ 906 RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, 907 TrustedCIDRs: []*net.IPNet{cidr}, 908 } 909 c.SetClientIPFunc(ClientIPWithOption(opts)) 910 assert.DeepEqual(t, "30.30.30.30", c.ClientIP()) 911 } 912 913 func TestSetClientIPFunc(t *testing.T) { 914 fn := func(ctx *RequestContext) string { 915 return "" 916 } 917 SetClientIPFunc(fn) 918 assert.DeepEqual(t, reflect.ValueOf(fn).Pointer(), reflect.ValueOf(defaultClientIP).Pointer()) 919 } 920 921 type mockValidator struct{} 922 923 func (m *mockValidator) ValidateStruct(interface{}) error { 924 return fmt.Errorf("test mock") 925 } 926 927 func (m *mockValidator) Engine() interface{} { 928 return nil 929 } 930 931 func (m *mockValidator) ValidateTag() string { 932 return "vt" 933 } 934 935 func TestSetValidator(t *testing.T) { 936 m := &mockValidator{} 937 c := NewContext(0) 938 c.SetValidator(m) 939 c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{Validator: m})) 940 type User struct { 941 Age int `vt:"$>=0&&$<=130"` 942 } 943 944 user := &User{ 945 Age: 135, 946 } 947 err := c.Validate(user) 948 if err == nil { 949 t.Fatalf("expected an error, but got nil") 950 } 951 assert.DeepEqual(t, "test mock", err.Error()) 952 } 953 954 func TestGetQuery(t *testing.T) { 955 c := NewContext(0) 956 c.Request.SetRequestURI("http://aaa.com?a=1&b=") 957 v, exists := c.GetQuery("b") 958 assert.DeepEqual(t, "", v) 959 assert.DeepEqual(t, true, exists) 960 } 961 962 func TestGetPostForm(t *testing.T) { 963 c := NewContext(0) 964 c.Request.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) 965 c.Request.SetBodyString("a=1&b=") 966 v, exists := c.GetPostForm("b") 967 assert.DeepEqual(t, "", v) 968 assert.DeepEqual(t, true, exists) 969 } 970 971 func TestGetPostFormArray(t *testing.T) { 972 c := NewContext(0) 973 c.Request.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) 974 c.Request.SetBodyString("a=1&b=2&b=3") 975 v, _ := c.GetPostFormArray("b") 976 assert.DeepEqual(t, []string{"2", "3"}, v) 977 } 978 979 func TestRemoteAddr(t *testing.T) { 980 c := NewContext(0) 981 c.Request.SetRequestURI("http://aaa.com?a=1&b=") 982 addr := c.RemoteAddr().String() 983 assert.DeepEqual(t, "0.0.0.0:0", addr) 984 } 985 986 func TestRequestBodyStream(t *testing.T) { 987 c := NewContext(0) 988 s := "testRequestBodyStream" 989 mr := bytes.NewBufferString(s) 990 c.Request.SetBodyStream(mr, -1) 991 data, err := ioutil.ReadAll(c.RequestBodyStream()) 992 assert.Nil(t, err) 993 assert.DeepEqual(t, "testRequestBodyStream", string(data)) 994 } 995 996 func TestContextIsAborted(t *testing.T) { 997 ctx := NewContext(0) 998 assert.False(t, ctx.IsAborted()) 999 1000 ctx.Abort() 1001 assert.True(t, ctx.IsAborted()) 1002 1003 ctx.Next(context.Background()) 1004 assert.True(t, ctx.IsAborted()) 1005 1006 ctx.index++ 1007 assert.True(t, ctx.IsAborted()) 1008 } 1009 1010 func TestContextAbortWithStatus(t *testing.T) { 1011 c := NewContext(0) 1012 1013 c.index = 4 1014 c.AbortWithStatus(consts.StatusUnauthorized) 1015 1016 assert.DeepEqual(t, con.AbortIndex, c.index) 1017 assert.DeepEqual(t, consts.StatusUnauthorized, c.Response.Header.StatusCode()) 1018 assert.True(t, c.IsAborted()) 1019 } 1020 1021 type testJSONAbortMsg struct { 1022 Foo string `json:"foo"` 1023 Bar string `json:"bar"` 1024 } 1025 1026 func TestContextAbortWithStatusJSON(t *testing.T) { 1027 c := NewContext(0) 1028 c.index = 4 1029 1030 in := new(testJSONAbortMsg) 1031 in.Bar = "barValue" 1032 in.Foo = "fooValue" 1033 1034 c.AbortWithStatusJSON(consts.StatusUnsupportedMediaType, in) 1035 1036 assert.DeepEqual(t, con.AbortIndex, c.index) 1037 assert.DeepEqual(t, consts.StatusUnsupportedMediaType, c.Response.Header.StatusCode()) 1038 assert.True(t, c.IsAborted()) 1039 1040 contentType := c.Response.Header.Peek("Content-Type") 1041 assert.DeepEqual(t, consts.MIMEApplicationJSONUTF8, string(contentType)) 1042 1043 jsonStringBody := c.Response.Body() 1044 assert.DeepEqual(t, "{\"foo\":\"fooValue\",\"bar\":\"barValue\"}", string(jsonStringBody)) 1045 } 1046 1047 func TestRequestCtxFormValue(t *testing.T) { 1048 ctx := NewContext(0) 1049 ctx.Request.SetRequestURI("/foo/bar?baz=123&aaa=bbb") 1050 ctx.Request.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) 1051 ctx.Request.SetBodyString("qqq=port&mmm=sddd") 1052 1053 v := ctx.FormValue("baz") 1054 if string(v) != "123" { 1055 t.Fatalf("unexpected value %q. Expecting %q", v, "123") 1056 } 1057 v = ctx.FormValue("mmm") 1058 if string(v) != "sddd" { 1059 t.Fatalf("unexpected value %q. Expecting %q", v, "sddd") 1060 } 1061 v = ctx.FormValue("aaaasdfsdf") 1062 if len(v) > 0 { 1063 t.Fatalf("unexpected value for unknown key %q", v) 1064 } 1065 ctx.Request.Reset() 1066 ctx.Request.SetFormData(map[string]string{ 1067 "a": "1", 1068 }) 1069 v = ctx.FormValue("a") 1070 if string(v) != "1" { 1071 t.Fatalf("unexpected value %q. Expecting %q", v, "1") 1072 } 1073 1074 ctx.Request.Reset() 1075 s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg 1076 Content-Disposition: form-data; name="f" 1077 1078 fff 1079 ------WebKitFormBoundaryJwfATyF8tmxSJnLg 1080 ` 1081 mr := bytes.NewBufferString(s) 1082 ctx.Request.SetBodyStream(mr, -1) 1083 ctx.Request.Header.SetContentLength(len(s)) 1084 ctx.Request.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) 1085 1086 v = ctx.FormValue("f") 1087 if string(v) != "fff" { 1088 t.Fatalf("unexpected value %q. Expecting %q", v, "fff") 1089 } 1090 } 1091 1092 func TestSetCustomFormValueFunc(t *testing.T) { 1093 ctx := NewContext(0) 1094 ctx.Request.SetRequestURI("/foo/bar?aaa=bbb") 1095 ctx.Request.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) 1096 ctx.Request.SetBodyString("aaa=port") 1097 1098 ctx.SetFormValueFunc(func(ctx *RequestContext, key string) []byte { 1099 v := ctx.PostArgs().Peek(key) 1100 if len(v) > 0 { 1101 return v 1102 } 1103 mf, err := ctx.MultipartForm() 1104 if err == nil && mf.Value != nil { 1105 vv := mf.Value[key] 1106 if len(vv) > 0 { 1107 return []byte(vv[0]) 1108 } 1109 } 1110 v = ctx.QueryArgs().Peek(key) 1111 if len(v) > 0 { 1112 return v 1113 } 1114 return nil 1115 }) 1116 1117 v := ctx.FormValue("aaa") 1118 if string(v) != "port" { 1119 t.Fatalf("unexpected value %q. Expecting %q", v, "port") 1120 } 1121 } 1122 1123 func TestContextSetGet(t *testing.T) { 1124 c := &RequestContext{} 1125 c.Set("foo", "bar") 1126 1127 value, err := c.Get("foo") 1128 assert.DeepEqual(t, "bar", value) 1129 assert.True(t, err) 1130 1131 value, err = c.Get("foo2") 1132 assert.Nil(t, value) 1133 assert.False(t, err) 1134 1135 assert.DeepEqual(t, "bar", c.MustGet("foo")) 1136 assert.Panic(t, func() { c.MustGet("no_exist") }) 1137 } 1138 1139 func TestContextSetGetValues(t *testing.T) { 1140 c := &RequestContext{} 1141 c.Set("string", "this is a string") 1142 c.Set("int32", int32(-42)) 1143 c.Set("int64", int64(42424242424242)) 1144 c.Set("uint32", uint32(42)) 1145 c.Set("uint64", uint64(42424242424242)) 1146 c.Set("float32", float32(4.2)) 1147 c.Set("float64", 4.2) 1148 var a interface{} = 1 1149 c.Set("intInterface", a) 1150 1151 assert.DeepEqual(t, c.MustGet("string").(string), "this is a string") 1152 assert.DeepEqual(t, c.MustGet("int32").(int32), int32(-42)) 1153 assert.DeepEqual(t, c.MustGet("int64").(int64), int64(42424242424242)) 1154 assert.DeepEqual(t, c.MustGet("uint32").(uint32), uint32(42)) 1155 assert.DeepEqual(t, c.MustGet("uint64").(uint64), uint64(42424242424242)) 1156 assert.DeepEqual(t, c.MustGet("float32").(float32), float32(4.2)) 1157 assert.DeepEqual(t, c.MustGet("float64").(float64), 4.2) 1158 assert.DeepEqual(t, c.MustGet("intInterface").(int), 1) 1159 } 1160 1161 func TestContextGetString(t *testing.T) { 1162 c := &RequestContext{} 1163 c.Set("string", "this is a string") 1164 assert.DeepEqual(t, "this is a string", c.GetString("string")) 1165 c.Set("bool", false) 1166 assert.DeepEqual(t, "", c.GetString("bool")) 1167 } 1168 1169 func TestContextSetGetBool(t *testing.T) { 1170 c := &RequestContext{} 1171 c.Set("bool", true) 1172 assert.True(t, c.GetBool("bool")) 1173 c.Set("string", "this is a string") 1174 assert.False(t, c.GetBool("string")) 1175 } 1176 1177 func TestContextGetInt(t *testing.T) { 1178 c := &RequestContext{} 1179 c.Set("int", 1) 1180 assert.DeepEqual(t, 1, c.GetInt("int")) 1181 c.Set("string", "this is a string") 1182 assert.DeepEqual(t, 0, c.GetInt("string")) 1183 } 1184 1185 func TestContextGetInt32(t *testing.T) { 1186 c := &RequestContext{} 1187 c.Set("int32", int32(-42)) 1188 assert.DeepEqual(t, int32(-42), c.GetInt32("int32")) 1189 c.Set("string", "this is a string") 1190 assert.DeepEqual(t, int32(0), c.GetInt32("string")) 1191 } 1192 1193 func TestContextGetInt64(t *testing.T) { 1194 c := &RequestContext{} 1195 c.Set("int64", int64(42424242424242)) 1196 assert.DeepEqual(t, int64(42424242424242), c.GetInt64("int64")) 1197 c.Set("string", "this is a string") 1198 assert.DeepEqual(t, int64(0), c.GetInt64("string")) 1199 } 1200 1201 func TestContextGetUint(t *testing.T) { 1202 c := &RequestContext{} 1203 c.Set("uint", uint(1)) 1204 assert.DeepEqual(t, uint(1), c.GetUint("uint")) 1205 c.Set("string", "this is a string") 1206 assert.DeepEqual(t, uint(0), c.GetUint("string")) 1207 } 1208 1209 func TestContextGetUint32(t *testing.T) { 1210 c := &RequestContext{} 1211 c.Set("uint32", uint32(42)) 1212 assert.DeepEqual(t, uint32(42), c.GetUint32("uint32")) 1213 c.Set("string", "this is a string") 1214 assert.DeepEqual(t, uint32(0), c.GetUint32("string")) 1215 } 1216 1217 func TestContextGetUint64(t *testing.T) { 1218 c := &RequestContext{} 1219 c.Set("uint64", uint64(42424242424242)) 1220 assert.DeepEqual(t, uint64(42424242424242), c.GetUint64("uint64")) 1221 c.Set("string", "this is a string") 1222 assert.DeepEqual(t, uint64(0), c.GetUint64("string")) 1223 } 1224 1225 func TestContextGetFloat32(t *testing.T) { 1226 c := &RequestContext{} 1227 c.Set("float32", float32(4.2)) 1228 assert.DeepEqual(t, float32(4.2), c.GetFloat32("float32")) 1229 c.Set("string", "this is a string") 1230 assert.DeepEqual(t, float32(0.0), c.GetFloat32("string")) 1231 } 1232 1233 func TestContextGetFloat64(t *testing.T) { 1234 c := &RequestContext{} 1235 c.Set("float64", 4.2) 1236 assert.DeepEqual(t, 4.2, c.GetFloat64("float64")) 1237 c.Set("string", "this is a string") 1238 assert.DeepEqual(t, 0.0, c.GetFloat64("string")) 1239 } 1240 1241 func TestContextGetTime(t *testing.T) { 1242 c := &RequestContext{} 1243 t1, _ := time.Parse("1/2/2006 15:04:05", "01/01/2017 12:00:00") 1244 c.Set("time", t1) 1245 assert.DeepEqual(t, t1, c.GetTime("time")) 1246 c.Set("string", "this is a string") 1247 assert.DeepEqual(t, time.Time{}, c.GetTime("string")) 1248 } 1249 1250 func TestContextGetDuration(t *testing.T) { 1251 c := &RequestContext{} 1252 c.Set("duration", time.Second) 1253 assert.DeepEqual(t, time.Second, c.GetDuration("duration")) 1254 c.Set("string", "this is a string") 1255 assert.DeepEqual(t, time.Duration(0), c.GetDuration("string")) 1256 } 1257 1258 func TestContextGetStringSlice(t *testing.T) { 1259 c := &RequestContext{} 1260 c.Set("slice", []string{"foo"}) 1261 assert.DeepEqual(t, []string{"foo"}, c.GetStringSlice("slice")) 1262 c.Set("string", "this is a string") 1263 var expected []string 1264 assert.DeepEqual(t, expected, c.GetStringSlice("string")) 1265 } 1266 1267 func TestContextGetStringMap(t *testing.T) { 1268 c := &RequestContext{} 1269 m := make(map[string]interface{}) 1270 m["foo"] = 1 1271 c.Set("map", m) 1272 1273 assert.DeepEqual(t, m, c.GetStringMap("map")) 1274 assert.DeepEqual(t, 1, c.GetStringMap("map")["foo"]) 1275 1276 c.Set("string", "this is a string") 1277 var expected map[string]interface{} 1278 assert.DeepEqual(t, expected, c.GetStringMap("string")) 1279 } 1280 1281 func TestContextGetStringMapString(t *testing.T) { 1282 c := &RequestContext{} 1283 m := make(map[string]string) 1284 m["foo"] = "bar" 1285 c.Set("map", m) 1286 1287 assert.DeepEqual(t, m, c.GetStringMapString("map")) 1288 assert.DeepEqual(t, "bar", c.GetStringMapString("map")["foo"]) 1289 1290 c.Set("string", "this is a string") 1291 var expected map[string]string 1292 assert.DeepEqual(t, expected, c.GetStringMapString("string")) 1293 } 1294 1295 func TestContextGetStringMapStringSlice(t *testing.T) { 1296 c := &RequestContext{} 1297 m := make(map[string][]string) 1298 m["foo"] = []string{"foo"} 1299 c.Set("map", m) 1300 1301 assert.DeepEqual(t, m, c.GetStringMapStringSlice("map")) 1302 assert.DeepEqual(t, []string{"foo"}, c.GetStringMapStringSlice("map")["foo"]) 1303 1304 c.Set("string", "this is a string") 1305 var expected map[string][]string 1306 assert.DeepEqual(t, expected, c.GetStringMapStringSlice("string")) 1307 } 1308 1309 func TestContextTraceInfo(t *testing.T) { 1310 ctx := NewContext(0) 1311 traceIn := traceinfo.NewTraceInfo() 1312 ctx.SetTraceInfo(traceIn) 1313 traceOut := ctx.GetTraceInfo() 1314 1315 assert.DeepEqual(t, traceIn, traceOut) 1316 } 1317 1318 func TestEnableTrace(t *testing.T) { 1319 ctx := NewContext(0) 1320 ctx.SetEnableTrace(true) 1321 trace := ctx.IsEnableTrace() 1322 assert.True(t, trace) 1323 } 1324 1325 func TestForEachKey(t *testing.T) { 1326 ctx := NewContext(0) 1327 ctx.Set("1", "2") 1328 handle := func(k string, v interface{}) { 1329 res := k + v.(string) 1330 assert.DeepEqual(t, res, "12") 1331 } 1332 ctx.ForEachKey(handle) 1333 val, ok := ctx.Get("1") 1334 assert.DeepEqual(t, val, "2") 1335 assert.True(t, ok) 1336 } 1337 1338 func TestFlush(t *testing.T) { 1339 ctx := NewContext(0) 1340 err := ctx.Flush() 1341 assert.Nil(t, err) 1342 } 1343 1344 func TestConn(t *testing.T) { 1345 ctx := NewContext(0) 1346 1347 conn := mock.NewConn("") 1348 1349 ctx.SetConn(conn) 1350 connRes := ctx.GetConn() 1351 1352 val1 := reflect.ValueOf(conn).Pointer() 1353 val2 := reflect.ValueOf(connRes).Pointer() 1354 assert.DeepEqual(t, val1, val2) 1355 } 1356 1357 func TestHijackHandler(t *testing.T) { 1358 ctx := NewContext(0) 1359 handle := func(c network.Conn) { 1360 c.SetReadTimeout(time.Duration(1) * time.Second) 1361 } 1362 ctx.SetHijackHandler(handle) 1363 handleRes := ctx.GetHijackHandler() 1364 1365 val1 := reflect.ValueOf(handle).Pointer() 1366 val2 := reflect.ValueOf(handleRes).Pointer() 1367 assert.DeepEqual(t, val1, val2) 1368 } 1369 1370 func TestGetReader(t *testing.T) { 1371 ctx := NewContext(0) 1372 1373 conn := mock.NewConn("") 1374 1375 ctx.SetConn(conn) 1376 connRes := ctx.GetReader() 1377 1378 val1 := reflect.ValueOf(conn).Pointer() 1379 val2 := reflect.ValueOf(connRes).Pointer() 1380 assert.DeepEqual(t, val1, val2) 1381 } 1382 1383 func TestGetWriter(t *testing.T) { 1384 ctx := NewContext(0) 1385 1386 conn := mock.NewConn("") 1387 1388 ctx.SetConn(conn) 1389 connRes := ctx.GetWriter() 1390 1391 val1 := reflect.ValueOf(conn).Pointer() 1392 val2 := reflect.ValueOf(connRes).Pointer() 1393 assert.DeepEqual(t, val1, val2) 1394 } 1395 1396 func TestIndex(t *testing.T) { 1397 ctx := NewContext(0) 1398 ctx.ResetWithoutConn() 1399 exc := int8(-1) 1400 res := ctx.GetIndex() 1401 assert.DeepEqual(t, exc, res) 1402 ctx.SetIndex(int8(1)) 1403 res = ctx.GetIndex() 1404 exc = int8(1) 1405 assert.DeepEqual(t, exc, res) 1406 } 1407 1408 func TestConcurrentHandlerName(t *testing.T) { 1409 SetConcurrentHandlerNameOperator() 1410 defer SetHandlerNameOperator(&inbuiltHandlerNameOperatorStruct{handlerNames: map[uintptr]string{}}) 1411 h := func(c context.Context, ctx *RequestContext) {} 1412 SetHandlerName(h, "test1") 1413 for i := 0; i < 50; i++ { 1414 go func() { 1415 name := GetHandlerName(h) 1416 assert.DeepEqual(t, "test1", name) 1417 }() 1418 } 1419 1420 time.Sleep(time.Second) 1421 1422 go func() { 1423 SetHandlerName(h, "test2") 1424 }() 1425 1426 time.Sleep(time.Second) 1427 1428 name := GetHandlerName(h) 1429 assert.DeepEqual(t, "test2", name) 1430 } 1431 1432 func TestHandlerName(t *testing.T) { 1433 h := func(c context.Context, ctx *RequestContext) {} 1434 SetHandlerName(h, "test1") 1435 name := GetHandlerName(h) 1436 assert.DeepEqual(t, "test1", name) 1437 } 1438 1439 func TestHijack(t *testing.T) { 1440 ctx := NewContext(0) 1441 h := func(c network.Conn) {} 1442 ctx.Hijack(h) 1443 assert.True(t, ctx.Hijacked()) 1444 } 1445 1446 func TestFinished(t *testing.T) { 1447 ctx := NewContext(0) 1448 ctx.Finished() 1449 1450 ch := make(chan struct{}) 1451 ctx.finished = ch 1452 chRes := ctx.Finished() 1453 1454 send := func() { 1455 time.Sleep(time.Duration(1) * time.Millisecond) 1456 ch <- struct{}{} 1457 } 1458 go send() 1459 val := <-chRes 1460 assert.DeepEqual(t, struct{}{}, val) 1461 } 1462 1463 func TestString(t *testing.T) { 1464 ctx := NewContext(0) 1465 ctx.String(consts.StatusOK, "ok") 1466 assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) 1467 } 1468 1469 func TestFullPath(t *testing.T) { 1470 ctx := NewContext(0) 1471 str := "/hello" 1472 ctx.SetFullPath(str) 1473 val := ctx.FullPath() 1474 assert.DeepEqual(t, str, val) 1475 } 1476 1477 func TestReset(t *testing.T) { 1478 ctx := NewContext(0) 1479 ctx.Reset() 1480 assert.DeepEqual(t, nil, ctx.conn) 1481 } 1482 1483 // func TestParam(t *testing.T) { 1484 // ctx := NewContext(0) 1485 // val := ctx.Param("/user/john") 1486 // assert.DeepEqual(t, "john", val) 1487 // } 1488 1489 func TestGetHeader(t *testing.T) { 1490 ctx := NewContext(0) 1491 ctx.Request.Header.SetContentTypeBytes([]byte(consts.MIMETextPlainUTF8)) 1492 val := ctx.GetHeader("Content-Type") 1493 assert.DeepEqual(t, consts.MIMETextPlainUTF8, string(val)) 1494 } 1495 1496 func TestGetRawData(t *testing.T) { 1497 ctx := NewContext(0) 1498 ctx.Request.SetBody([]byte("hello")) 1499 val := ctx.GetRawData() 1500 assert.DeepEqual(t, "hello", string(val)) 1501 1502 val2, err := ctx.Body() 1503 assert.DeepEqual(t, val, val2) 1504 assert.Nil(t, err) 1505 } 1506 1507 func TestRequestContext_GetRequest(t *testing.T) { 1508 c := &RequestContext{} 1509 c.Request.Header.Set("key1", "value1") 1510 c.Request.SetBody([]byte("test body")) 1511 req := c.GetRequest() 1512 if req.Header.Get("key1") != "value1" { 1513 t.Fatal("should have header: key1:value1") 1514 } 1515 if string(req.Body()) != "test body" { 1516 t.Fatal("should have body: test body") 1517 } 1518 } 1519 1520 func TestRequestContext_GetResponse(t *testing.T) { 1521 c := &RequestContext{} 1522 c.Response.Header.Set("key1", "value1") 1523 c.Response.SetBody([]byte("test body")) 1524 resp := c.GetResponse() 1525 if resp.Header.Get("key1") != "value1" { 1526 t.Fatal("should have header: key1:value1") 1527 } 1528 if string(resp.Body()) != "test body" { 1529 t.Fatal("should have body: test body") 1530 } 1531 } 1532 1533 func TestBindAndValidate(t *testing.T) { 1534 type Test struct { 1535 A string `query:"a"` 1536 B int `query:"b" vd:"$>10"` 1537 } 1538 1539 c := &RequestContext{} 1540 c.Request.SetRequestURI("/foo/bar?a=123&b=11") 1541 1542 var req Test 1543 err := c.BindAndValidate(&req) 1544 if err != nil { 1545 t.Fatalf("unexpected error: %v", err) 1546 } 1547 assert.DeepEqual(t, "123", req.A) 1548 assert.DeepEqual(t, 11, req.B) 1549 1550 c.Request.URI().Reset() 1551 c.Request.SetRequestURI("/foo/bar?a=123&b=9") 1552 req = Test{} 1553 err = c.BindAndValidate(&req) 1554 if err == nil { 1555 t.Fatalf("unexpected nil, expected an error") 1556 } 1557 1558 c.Request.URI().Reset() 1559 c.Request.SetRequestURI("/foo/bar?a=123&b=9") 1560 req = Test{} 1561 err = c.Bind(&req) 1562 if err != nil { 1563 t.Fatalf("unexpected error: %v", err) 1564 } 1565 assert.DeepEqual(t, "123", req.A) 1566 assert.DeepEqual(t, 9, req.B) 1567 1568 err = c.Validate(&req) 1569 if err == nil { 1570 t.Fatalf("unexpected nil, expected an error") 1571 } 1572 } 1573 1574 func TestBindForm(t *testing.T) { 1575 type Test struct { 1576 A string 1577 B int 1578 } 1579 1580 c := &RequestContext{} 1581 c.Request.SetRequestURI("/foo/bar?a=123&b=11") 1582 c.Request.SetBody([]byte("A=123&B=11")) 1583 c.Request.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) 1584 1585 var req Test 1586 err := c.BindForm(&req) 1587 if err != nil { 1588 t.Fatalf("unexpected error: %v", err) 1589 } 1590 assert.DeepEqual(t, "123", req.A) 1591 assert.DeepEqual(t, 11, req.B) 1592 1593 c.Request.SetBody([]byte("")) 1594 err = c.BindForm(&req) 1595 if err == nil { 1596 t.Fatalf("expected error, but get nil") 1597 } 1598 } 1599 1600 type mockBinder struct{} 1601 1602 func (m *mockBinder) Name() string { 1603 return "test binder" 1604 } 1605 1606 func (m *mockBinder) Bind(request *protocol.Request, i interface{}, params param.Params) error { 1607 return nil 1608 } 1609 1610 func (m *mockBinder) BindAndValidate(request *protocol.Request, i interface{}, params param.Params) error { 1611 return fmt.Errorf("test binder") 1612 } 1613 1614 func (m *mockBinder) BindQuery(request *protocol.Request, i interface{}) error { 1615 return nil 1616 } 1617 1618 func (m *mockBinder) BindHeader(request *protocol.Request, i interface{}) error { 1619 return nil 1620 } 1621 1622 func (m *mockBinder) BindPath(request *protocol.Request, i interface{}, params param.Params) error { 1623 return nil 1624 } 1625 1626 func (m *mockBinder) BindForm(request *protocol.Request, i interface{}) error { 1627 return nil 1628 } 1629 1630 func (m *mockBinder) BindJSON(request *protocol.Request, i interface{}) error { 1631 return nil 1632 } 1633 1634 func (m *mockBinder) BindProtobuf(request *protocol.Request, i interface{}) error { 1635 return nil 1636 } 1637 1638 func TestSetBinder(t *testing.T) { 1639 c := NewContext(0) 1640 c.SetBinder(&mockBinder{}) 1641 type T struct{} 1642 req := T{} 1643 err := c.Bind(&req) 1644 assert.Nil(t, err) 1645 err = c.BindAndValidate(&req) 1646 assert.NotNil(t, err) 1647 assert.DeepEqual(t, "test binder", err.Error()) 1648 err = c.BindProtobuf(&req) 1649 assert.Nil(t, err) 1650 err = c.BindJSON(&req) 1651 assert.Nil(t, err) 1652 err = c.BindForm(&req) 1653 assert.NotNil(t, err) 1654 err = c.BindPath(&req) 1655 assert.Nil(t, err) 1656 err = c.BindQuery(&req) 1657 assert.Nil(t, err) 1658 err = c.BindHeader(&req) 1659 assert.Nil(t, err) 1660 } 1661 1662 func TestRequestContext_SetCookie(t *testing.T) { 1663 c := NewContext(0) 1664 c.SetCookie("user", "hertz", 1, "/", "localhost", protocol.CookieSameSiteNoneMode, true, true) 1665 assert.DeepEqual(t, "user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=None", c.Response.Header.Get("Set-Cookie")) 1666 } 1667 1668 func TestRequestContext_SetPartitionedCookie(t *testing.T) { 1669 c := NewContext(0) 1670 c.SetPartitionedCookie("user", "hertz", 1, "/", "localhost", protocol.CookieSameSiteNoneMode, true, true) 1671 assert.DeepEqual(t, "user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=None; Partitioned", c.Response.Header.Get("Set-Cookie")) 1672 } 1673 1674 func TestRequestContext_SetCookiePathEmpty(t *testing.T) { 1675 c := NewContext(0) 1676 c.SetCookie("user", "hertz", 1, "", "localhost", protocol.CookieSameSiteDisabled, true, true) 1677 assert.DeepEqual(t, "user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure", c.Response.Header.Get("Set-Cookie")) 1678 } 1679 1680 func TestRequestContext_VisitAll(t *testing.T) { 1681 t.Run("VisitAllQueryArgs", func(t *testing.T) { 1682 c := NewContext(0) 1683 var s []string 1684 c.QueryArgs().Add("cloudwego", "hertz") 1685 c.QueryArgs().Add("hello", "world") 1686 c.VisitAllQueryArgs(func(key, value []byte) { 1687 s = append(s, string(key), string(value)) 1688 }) 1689 assert.DeepEqual(t, []string{"cloudwego", "hertz", "hello", "world"}, s) 1690 }) 1691 1692 t.Run("VisitAllPostArgs", func(t *testing.T) { 1693 c := NewContext(0) 1694 var s []string 1695 c.PostArgs().Add("cloudwego", "hertz") 1696 c.PostArgs().Add("hello", "world") 1697 c.VisitAllPostArgs(func(key, value []byte) { 1698 s = append(s, string(key), string(value)) 1699 }) 1700 assert.DeepEqual(t, []string{"cloudwego", "hertz", "hello", "world"}, s) 1701 }) 1702 1703 t.Run("VisitAllCookie", func(t *testing.T) { 1704 c := NewContext(0) 1705 var s []string 1706 c.Request.Header.Set("Cookie", "aaa=bbb;ccc=ddd") 1707 c.VisitAllCookie(func(key, value []byte) { 1708 s = append(s, string(key), string(value)) 1709 }) 1710 assert.DeepEqual(t, []string{"aaa", "bbb", "ccc", "ddd"}, s) 1711 }) 1712 1713 t.Run("VisitAllHeaders", func(t *testing.T) { 1714 c := NewContext(0) 1715 c.Request.Header.Set("xxx", "yyy") 1716 c.Request.Header.Set("xxx2", "yyy2") 1717 c.VisitAllHeaders( 1718 func(k, v []byte) { 1719 key := string(k) 1720 value := string(v) 1721 if key != "Xxx" && key != "Xxx2" { 1722 t.Fatalf("Unexpected %v. Expected %v", key, "xxx or yyy") 1723 } 1724 if key == "Xxx" && value != "yyy" { 1725 t.Fatalf("Unexpected %v. Expected %v", value, "yyy") 1726 } 1727 if key == "Xxx2" && value != "yyy2" { 1728 t.Fatalf("Unexpected %v. Expected %v", value, "yyy2") 1729 } 1730 }) 1731 }) 1732 } 1733 1734 func BenchmarkInbuiltHandlerNameOperator(b *testing.B) { 1735 for n := 0; n < b.N; n++ { 1736 fn := func(c context.Context, ctx *RequestContext) { 1737 } 1738 SetHandlerName(fn, fmt.Sprintf("%d", n)) 1739 GetHandlerName(fn) 1740 } 1741 } 1742 1743 func BenchmarkConcurrentHandlerNameOperator(b *testing.B) { 1744 SetConcurrentHandlerNameOperator() 1745 defer SetHandlerNameOperator(&inbuiltHandlerNameOperatorStruct{handlerNames: map[uintptr]string{}}) 1746 for n := 0; n < b.N; n++ { 1747 fn := func(c context.Context, ctx *RequestContext) { 1748 } 1749 SetHandlerName(fn, fmt.Sprintf("%d", n)) 1750 GetHandlerName(fn) 1751 } 1752 }