github.com/cellofellow/gopkg@v0.0.0-20140722061823-eec0544a62ad/web/web_test.go (about) 1 // Copyright 2013 <chaishushan{AT}gmail.com>. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package web 6 7 import ( 8 "bytes" 9 "encoding/base64" 10 "encoding/json" 11 "errors" 12 "fmt" 13 "io" 14 "io/ioutil" 15 "log" 16 "net/http" 17 "net/url" 18 "runtime" 19 "strconv" 20 "strings" 21 "testing" 22 ) 23 24 func init() { 25 runtime.GOMAXPROCS(4) 26 } 27 28 // ioBuffer is a helper that implements io.ReadWriteCloser, 29 // which is helpful in imitating a net.Conn 30 type ioBuffer struct { 31 input *bytes.Buffer 32 output *bytes.Buffer 33 closed bool 34 } 35 36 func (buf *ioBuffer) Write(p []uint8) (n int, err error) { 37 if buf.closed { 38 return 0, errors.New("Write after Close on ioBuffer") 39 } 40 return buf.output.Write(p) 41 } 42 43 func (buf *ioBuffer) Read(p []byte) (n int, err error) { 44 if buf.closed { 45 return 0, errors.New("Read after Close on ioBuffer") 46 } 47 return buf.input.Read(p) 48 } 49 50 //noop 51 func (buf *ioBuffer) Close() error { 52 buf.closed = true 53 return nil 54 } 55 56 type testResponse struct { 57 statusCode int 58 status string 59 body string 60 headers map[string][]string 61 cookies map[string]string 62 } 63 64 func buildTestResponse(buf *bytes.Buffer) *testResponse { 65 66 response := testResponse{headers: make(map[string][]string), cookies: make(map[string]string)} 67 s := buf.String() 68 contents := strings.SplitN(s, "\r\n\r\n", 2) 69 70 header := contents[0] 71 72 if len(contents) > 1 { 73 response.body = contents[1] 74 } 75 76 headers := strings.Split(header, "\r\n") 77 78 statusParts := strings.SplitN(headers[0], " ", 3) 79 response.statusCode, _ = strconv.Atoi(statusParts[1]) 80 81 for _, h := range headers[1:] { 82 split := strings.SplitN(h, ":", 2) 83 name := strings.TrimSpace(split[0]) 84 value := strings.TrimSpace(split[1]) 85 if _, ok := response.headers[name]; !ok { 86 response.headers[name] = []string{} 87 } 88 89 newheaders := make([]string, len(response.headers[name])+1) 90 copy(newheaders, response.headers[name]) 91 newheaders[len(newheaders)-1] = value 92 response.headers[name] = newheaders 93 94 //if the header is a cookie, set it 95 if name == "Set-Cookie" { 96 i := strings.Index(value, ";") 97 cookie := value[0:i] 98 cookieParts := strings.SplitN(cookie, "=", 2) 99 response.cookies[strings.TrimSpace(cookieParts[0])] = strings.TrimSpace(cookieParts[1]) 100 } 101 } 102 103 return &response 104 } 105 106 func getTestResponse(method string, path string, body string, headers map[string][]string, cookies []*http.Cookie) *testResponse { 107 req := buildTestRequest(method, path, body, headers, cookies) 108 var buf bytes.Buffer 109 110 tcpb := ioBuffer{input: nil, output: &buf} 111 c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &tcpb} 112 mainServer.Process(&c, req) 113 return buildTestResponse(&buf) 114 } 115 116 func testGet(path string, headers map[string]string) *testResponse { 117 var header http.Header 118 for k, v := range headers { 119 header.Set(k, v) 120 } 121 return getTestResponse("GET", path, "", header, nil) 122 } 123 124 type Test struct { 125 method string 126 path string 127 headers map[string][]string 128 body string 129 expectedStatus int 130 expectedBody string 131 } 132 133 //initialize the routes 134 func init() { 135 mainServer.SetLogger(log.New(ioutil.Discard, "", 0)) 136 Get("/", func() string { return "index" }) 137 Get("/panic", func() { panic(0) }) 138 Get("/echo/(.*)", func(s string) string { return s }) 139 Get("/multiecho/(.*)/(.*)/(.*)/(.*)", func(a, b, c, d string) string { return a + b + c + d }) 140 Post("/post/echo/(.*)", func(s string) string { return s }) 141 Post("/post/echoparam/(.*)", func(ctx *Context, name string) string { return ctx.Params[name] }) 142 143 Get("/error/code/(.*)", func(ctx *Context, code string) string { 144 n, _ := strconv.Atoi(code) 145 message := statusText[n] 146 ctx.Abort(n, message) 147 return "" 148 }) 149 150 Get("/error/notfound/(.*)", func(ctx *Context, message string) { ctx.NotFound(message) }) 151 152 Get("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) 153 Post("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() }) 154 155 Get("/error/forbidden", func(ctx *Context) { ctx.Forbidden() }) 156 Post("/error/forbidden", func(ctx *Context) { ctx.Forbidden() }) 157 158 Post("/posterror/code/(.*)/(.*)", func(ctx *Context, code string, message string) string { 159 n, _ := strconv.Atoi(code) 160 ctx.Abort(n, message) 161 return "" 162 }) 163 164 Get("/writetest", func(ctx *Context) { ctx.WriteString("hello") }) 165 166 Post("/securecookie/set/(.+)/(.+)", func(ctx *Context, name string, val string) string { 167 ctx.SetSecureCookie(name, val, 60) 168 return "" 169 }) 170 171 Get("/securecookie/get/(.+)", func(ctx *Context, name string) string { 172 val, ok := ctx.GetSecureCookie(name) 173 if !ok { 174 return "" 175 } 176 return val 177 }) 178 Get("/getparam", func(ctx *Context) string { return ctx.Params["a"] }) 179 Get("/fullparams", func(ctx *Context) string { 180 return strings.Join(ctx.Request.Form["a"], ",") 181 }) 182 183 Get("/json", func(ctx *Context) string { 184 ctx.ContentType("json") 185 data, _ := json.Marshal(ctx.Params) 186 return string(data) 187 }) 188 189 Get("/jsonbytes", func(ctx *Context) []byte { 190 ctx.ContentType("json") 191 data, _ := json.Marshal(ctx.Params) 192 return data 193 }) 194 195 Post("/parsejson", func(ctx *Context) string { 196 var tmp = struct { 197 A string 198 B string 199 }{} 200 json.NewDecoder(ctx.Request.Body).Decode(&tmp) 201 return tmp.A + " " + tmp.B 202 }) 203 204 Match("OPTIONS", "/options", func(ctx *Context) { 205 ctx.SetHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS", true) 206 ctx.SetHeader("Access-Control-Max-Age", "1000", true) 207 ctx.WriteHeader(200) 208 }) 209 210 Get("/dupeheader", func(ctx *Context) string { 211 ctx.SetHeader("Server", "myserver", true) 212 return "" 213 }) 214 215 Get("/authorization", func(ctx *Context) string { 216 user, pass, err := ctx.GetBasicAuth() 217 if err != nil { 218 return "fail" 219 } 220 return user + pass 221 }) 222 } 223 224 var tests = []Test{ 225 {"GET", "/", nil, "", 200, "index"}, 226 {"GET", "/echo/hello", nil, "", 200, "hello"}, 227 {"GET", "/echo/hello", nil, "", 200, "hello"}, 228 {"GET", "/multiecho/a/b/c/d", nil, "", 200, "abcd"}, 229 {"POST", "/post/echo/hello", nil, "", 200, "hello"}, 230 {"POST", "/post/echo/hello", nil, "", 200, "hello"}, 231 {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello", 200, "hello"}, 232 {"POST", "/post/echoparam/c?c=hello", nil, "", 200, "hello"}, 233 {"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello\x00", 200, "hello\x00"}, 234 //long url 235 {"GET", "/echo/" + strings.Repeat("0123456789", 100), nil, "", 200, strings.Repeat("0123456789", 100)}, 236 {"GET", "/writetest", nil, "", 200, "hello"}, 237 {"GET", "/error/unauthorized", nil, "", 401, ""}, 238 {"POST", "/error/unauthorized", nil, "", 401, ""}, 239 {"GET", "/error/forbidden", nil, "", 403, ""}, 240 {"POST", "/error/forbidden", nil, "", 403, ""}, 241 {"GET", "/error/notfound/notfound", nil, "", 404, "notfound"}, 242 {"GET", "/doesnotexist", nil, "", 404, "Page not found"}, 243 {"POST", "/doesnotexist", nil, "", 404, "Page not found"}, 244 {"GET", "/error/code/500", nil, "", 500, statusText[500]}, 245 {"POST", "/posterror/code/410/failedrequest", nil, "", 410, "failedrequest"}, 246 {"GET", "/getparam?a=abcd", nil, "", 200, "abcd"}, 247 {"GET", "/getparam?b=abcd", nil, "", 200, ""}, 248 {"GET", "/fullparams?a=1&a=2&a=3", nil, "", 200, "1,2,3"}, 249 {"GET", "/panic", nil, "", 500, "Server Error"}, 250 {"GET", "/json?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`}, 251 {"GET", "/jsonbytes?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`}, 252 {"POST", "/parsejson", map[string][]string{"Content-Type": {"application/json"}}, `{"a":"hello", "b":"world"}`, 200, "hello world"}, 253 //{"GET", "/testenv", "", 200, "hello world"}, 254 {"GET", "/authorization", map[string][]string{"Authorization": {BuildBasicAuthCredentials("foo", "bar")}}, "", 200, "foobar"}, 255 } 256 257 func buildTestRequest(method string, path string, body string, headers map[string][]string, cookies []*http.Cookie) *http.Request { 258 host := "127.0.0.1" 259 port := "80" 260 rawurl := "http://" + host + ":" + port + path 261 url_, _ := url.Parse(rawurl) 262 proto := "HTTP/1.1" 263 264 if headers == nil { 265 headers = map[string][]string{} 266 } 267 268 headers["User-Agent"] = []string{"web.go test"} 269 if method == "POST" { 270 headers["Content-Length"] = []string{fmt.Sprintf("%d", len(body))} 271 if headers["Content-Type"] == nil { 272 headers["Content-Type"] = []string{"text/plain"} 273 } 274 } 275 276 req := http.Request{Method: method, 277 URL: url_, 278 Proto: proto, 279 Host: host, 280 Header: http.Header(headers), 281 Body: ioutil.NopCloser(bytes.NewBufferString(body)), 282 } 283 284 for _, cookie := range cookies { 285 req.AddCookie(cookie) 286 } 287 return &req 288 } 289 290 func TestRouting(t *testing.T) { 291 for _, test := range tests { 292 resp := getTestResponse(test.method, test.path, test.body, test.headers, nil) 293 294 if resp.statusCode != test.expectedStatus { 295 t.Fatalf("%v(%v) expected status %d got %d", test.method, test.path, test.expectedStatus, resp.statusCode) 296 } 297 if resp.body != test.expectedBody { 298 t.Fatalf("%v(%v) expected %q got %q", test.method, test.path, test.expectedBody, resp.body) 299 } 300 if cl, ok := resp.headers["Content-Length"]; ok { 301 clExp, _ := strconv.Atoi(cl[0]) 302 clAct := len(resp.body) 303 if clExp != clAct { 304 t.Fatalf("Content-length doesn't match. expected %d got %d", clExp, clAct) 305 } 306 } 307 } 308 } 309 310 func TestHead(t *testing.T) { 311 for _, test := range tests { 312 313 if test.method != "GET" { 314 continue 315 } 316 getresp := getTestResponse("GET", test.path, test.body, test.headers, nil) 317 headresp := getTestResponse("HEAD", test.path, test.body, test.headers, nil) 318 319 if getresp.statusCode != headresp.statusCode { 320 t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode) 321 } 322 if len(headresp.body) != 0 { 323 t.Fatalf("head request arrived with a body") 324 } 325 326 var cl []string 327 var getcl, headcl int 328 var hascl1, hascl2 bool 329 330 if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 { 331 getcl, _ = strconv.Atoi(cl[0]) 332 } 333 334 if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 { 335 headcl, _ = strconv.Atoi(cl[0]) 336 } 337 338 if hascl1 != hascl2 { 339 t.Fatalf("head and get: one has content-length, one doesn't") 340 } 341 342 if hascl1 == true && getcl != headcl { 343 t.Fatalf("head and get content-length differ") 344 } 345 } 346 } 347 348 func buildTestScgiRequest(method string, path string, body string, headers map[string][]string) *bytes.Buffer { 349 var headerBuf bytes.Buffer 350 scgiHeaders := make(map[string]string) 351 352 headerBuf.WriteString("CONTENT_LENGTH") 353 headerBuf.WriteByte(0) 354 headerBuf.WriteString(fmt.Sprintf("%d", len(body))) 355 headerBuf.WriteByte(0) 356 357 scgiHeaders["REQUEST_METHOD"] = method 358 scgiHeaders["HTTP_HOST"] = "127.0.0.1" 359 scgiHeaders["REQUEST_URI"] = path 360 scgiHeaders["SERVER_PORT"] = "80" 361 scgiHeaders["SERVER_PROTOCOL"] = "HTTP/1.1" 362 scgiHeaders["USER_AGENT"] = "web.go test framework" 363 364 for k, v := range headers { 365 //Skip content-length 366 if k == "Content-Length" { 367 continue 368 } 369 key := "HTTP_" + strings.ToUpper(strings.Replace(k, "-", "_", -1)) 370 scgiHeaders[key] = v[0] 371 } 372 for k, v := range scgiHeaders { 373 headerBuf.WriteString(k) 374 headerBuf.WriteByte(0) 375 headerBuf.WriteString(v) 376 headerBuf.WriteByte(0) 377 } 378 headerData := headerBuf.Bytes() 379 380 var buf bytes.Buffer 381 //extra 1 is for the comma at the end 382 dlen := len(headerData) 383 fmt.Fprintf(&buf, "%d:", dlen) 384 buf.Write(headerData) 385 buf.WriteByte(',') 386 buf.WriteString(body) 387 return &buf 388 } 389 390 func TestScgi(t *testing.T) { 391 for _, test := range tests { 392 req := buildTestScgiRequest(test.method, test.path, test.body, test.headers) 393 var output bytes.Buffer 394 nb := ioBuffer{input: req, output: &output} 395 mainServer.handleScgiRequest(&nb) 396 resp := buildTestResponse(&output) 397 398 if resp.statusCode != test.expectedStatus { 399 t.Fatalf("expected status %d got %d", test.expectedStatus, resp.statusCode) 400 } 401 402 if resp.body != test.expectedBody { 403 t.Fatalf("Scgi expected %q got %q", test.expectedBody, resp.body) 404 } 405 } 406 } 407 408 func TestScgiHead(t *testing.T) { 409 for _, test := range tests { 410 411 if test.method != "GET" { 412 continue 413 } 414 415 req := buildTestScgiRequest("GET", test.path, test.body, make(map[string][]string)) 416 var output bytes.Buffer 417 nb := ioBuffer{input: req, output: &output} 418 mainServer.handleScgiRequest(&nb) 419 getresp := buildTestResponse(&output) 420 421 req = buildTestScgiRequest("HEAD", test.path, test.body, make(map[string][]string)) 422 var output2 bytes.Buffer 423 nb = ioBuffer{input: req, output: &output2} 424 mainServer.handleScgiRequest(&nb) 425 headresp := buildTestResponse(&output2) 426 427 if getresp.statusCode != headresp.statusCode { 428 t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode) 429 } 430 if len(headresp.body) != 0 { 431 t.Fatalf("head request arrived with a body") 432 } 433 434 var cl []string 435 var getcl, headcl int 436 var hascl1, hascl2 bool 437 438 if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 { 439 getcl, _ = strconv.Atoi(cl[0]) 440 } 441 442 if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 { 443 headcl, _ = strconv.Atoi(cl[0]) 444 } 445 446 if hascl1 != hascl2 { 447 t.Fatalf("head and get: one has content-length, one doesn't") 448 } 449 450 if hascl1 == true && getcl != headcl { 451 t.Fatalf("head and get content-length differ") 452 } 453 } 454 } 455 456 func TestReadScgiRequest(t *testing.T) { 457 headers := map[string][]string{"User-Agent": {"web.go"}} 458 req := buildTestScgiRequest("POST", "/hello", "Hello world!", headers) 459 var s Server 460 httpReq, err := s.readScgiRequest(&ioBuffer{input: req, output: nil}) 461 if err != nil { 462 t.Fatalf("Error while reading SCGI request: %v", err.Error()) 463 } 464 if httpReq.ContentLength != 12 { 465 t.Fatalf("Content length mismatch, expected %d, got %d ", 12, httpReq.ContentLength) 466 } 467 var body bytes.Buffer 468 io.Copy(&body, httpReq.Body) 469 if body.String() != "Hello world!" { 470 t.Fatalf("Body mismatch, expected %q, got %q ", "Hello world!", body.String()) 471 } 472 } 473 474 func makeCookie(vals map[string]string) []*http.Cookie { 475 var cookies []*http.Cookie 476 for k, v := range vals { 477 c := &http.Cookie{ 478 Name: k, 479 Value: v, 480 } 481 cookies = append(cookies, c) 482 } 483 return cookies 484 } 485 486 func TestSecureCookie(t *testing.T) { 487 mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd" 488 resp1 := getTestResponse("POST", "/securecookie/set/a/1", "", nil, nil) 489 sval, ok := resp1.cookies["a"] 490 if !ok { 491 t.Fatalf("Failed to get cookie ") 492 } 493 cookies := makeCookie(map[string]string{"a": sval}) 494 495 resp2 := getTestResponse("GET", "/securecookie/get/a", "", nil, cookies) 496 497 if resp2.body != "1" { 498 t.Fatalf("SecureCookie test failed") 499 } 500 } 501 502 func TestEarlyClose(t *testing.T) { 503 var server1 Server 504 server1.Close() 505 } 506 507 func TestOptions(t *testing.T) { 508 resp := getTestResponse("OPTIONS", "/options", "", nil, nil) 509 if resp.headers["Access-Control-Allow-Methods"][0] != "POST, GET, OPTIONS" { 510 t.Fatalf("TestOptions - Access-Control-Allow-Methods failed") 511 } 512 if resp.headers["Access-Control-Max-Age"][0] != "1000" { 513 t.Fatalf("TestOptions - Access-Control-Max-Age failed") 514 } 515 } 516 517 func TestSlug(t *testing.T) { 518 tests := [][]string{ 519 {"", ""}, 520 {"a", "a"}, 521 {"a/b", "a-b"}, 522 {"a b", "a-b"}, 523 {"a////b", "a-b"}, 524 {" a////b ", "a-b"}, 525 {" Manowar / Friends ", "manowar-friends"}, 526 } 527 528 for _, test := range tests { 529 v := Slug(test[0], "-") 530 if v != test[1] { 531 t.Fatalf("TestSlug(%v) failed, expected %v, got %v", test[0], test[1], v) 532 } 533 } 534 } 535 536 // tests that we don't duplicate headers 537 func TestDuplicateHeader(t *testing.T) { 538 resp := testGet("/dupeheader", nil) 539 if len(resp.headers["Server"]) > 1 { 540 t.Fatalf("Expected only one header, got %#v", resp.headers["Server"]) 541 } 542 if resp.headers["Server"][0] != "myserver" { 543 t.Fatalf("Incorrect header, exp 'myserver', got %q", resp.headers["Server"][0]) 544 } 545 } 546 547 func BuildBasicAuthCredentials(user string, pass string) string { 548 s := user + ":" + pass 549 return "Basic " + base64.StdEncoding.EncodeToString([]byte(s)) 550 } 551 552 func BenchmarkProcessGet(b *testing.B) { 553 s := NewServer() 554 s.SetLogger(log.New(ioutil.Discard, "", 0)) 555 s.Get("/echo/(.*)", func(s string) string { 556 return s 557 }) 558 req := buildTestRequest("GET", "/echo/hi", "", nil, nil) 559 var buf bytes.Buffer 560 iob := ioBuffer{input: nil, output: &buf} 561 c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} 562 b.ReportAllocs() 563 b.ResetTimer() 564 for i := 0; i < b.N; i++ { 565 s.Process(&c, req) 566 } 567 } 568 569 func BenchmarkProcessPost(b *testing.B) { 570 s := NewServer() 571 s.SetLogger(log.New(ioutil.Discard, "", 0)) 572 s.Post("/echo/(.*)", func(s string) string { 573 return s 574 }) 575 req := buildTestRequest("POST", "/echo/hi", "", nil, nil) 576 var buf bytes.Buffer 577 iob := ioBuffer{input: nil, output: &buf} 578 c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob} 579 b.ReportAllocs() 580 b.ResetTimer() 581 for i := 0; i < b.N; i++ { 582 s.Process(&c, req) 583 } 584 }