github.com/cellofellow/gopkg@v0.0.0-20140722061823-eec0544a62ad/web/web_test.go (about)

     1  // Copyright 2013 <chaishushan{AT}gmail.com>. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package web
     6  
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"testing"
)
    23  
    24  func init() {
    25  	runtime.GOMAXPROCS(4)
    26  }
    27  
    28  // ioBuffer is a helper that implements io.ReadWriteCloser,
    29  // which is helpful in imitating a net.Conn
    30  type ioBuffer struct {
    31  	input  *bytes.Buffer
    32  	output *bytes.Buffer
    33  	closed bool
    34  }
    35  
    36  func (buf *ioBuffer) Write(p []uint8) (n int, err error) {
    37  	if buf.closed {
    38  		return 0, errors.New("Write after Close on ioBuffer")
    39  	}
    40  	return buf.output.Write(p)
    41  }
    42  
    43  func (buf *ioBuffer) Read(p []byte) (n int, err error) {
    44  	if buf.closed {
    45  		return 0, errors.New("Read after Close on ioBuffer")
    46  	}
    47  	return buf.input.Read(p)
    48  }
    49  
    50  //noop
    51  func (buf *ioBuffer) Close() error {
    52  	buf.closed = true
    53  	return nil
    54  }
    55  
    56  type testResponse struct {
    57  	statusCode int
    58  	status     string
    59  	body       string
    60  	headers    map[string][]string
    61  	cookies    map[string]string
    62  }
    63  
    64  func buildTestResponse(buf *bytes.Buffer) *testResponse {
    65  
    66  	response := testResponse{headers: make(map[string][]string), cookies: make(map[string]string)}
    67  	s := buf.String()
    68  	contents := strings.SplitN(s, "\r\n\r\n", 2)
    69  
    70  	header := contents[0]
    71  
    72  	if len(contents) > 1 {
    73  		response.body = contents[1]
    74  	}
    75  
    76  	headers := strings.Split(header, "\r\n")
    77  
    78  	statusParts := strings.SplitN(headers[0], " ", 3)
    79  	response.statusCode, _ = strconv.Atoi(statusParts[1])
    80  
    81  	for _, h := range headers[1:] {
    82  		split := strings.SplitN(h, ":", 2)
    83  		name := strings.TrimSpace(split[0])
    84  		value := strings.TrimSpace(split[1])
    85  		if _, ok := response.headers[name]; !ok {
    86  			response.headers[name] = []string{}
    87  		}
    88  
    89  		newheaders := make([]string, len(response.headers[name])+1)
    90  		copy(newheaders, response.headers[name])
    91  		newheaders[len(newheaders)-1] = value
    92  		response.headers[name] = newheaders
    93  
    94  		//if the header is a cookie, set it
    95  		if name == "Set-Cookie" {
    96  			i := strings.Index(value, ";")
    97  			cookie := value[0:i]
    98  			cookieParts := strings.SplitN(cookie, "=", 2)
    99  			response.cookies[strings.TrimSpace(cookieParts[0])] = strings.TrimSpace(cookieParts[1])
   100  		}
   101  	}
   102  
   103  	return &response
   104  }
   105  
   106  func getTestResponse(method string, path string, body string, headers map[string][]string, cookies []*http.Cookie) *testResponse {
   107  	req := buildTestRequest(method, path, body, headers, cookies)
   108  	var buf bytes.Buffer
   109  
   110  	tcpb := ioBuffer{input: nil, output: &buf}
   111  	c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &tcpb}
   112  	mainServer.Process(&c, req)
   113  	return buildTestResponse(&buf)
   114  }
   115  
   116  func testGet(path string, headers map[string]string) *testResponse {
   117  	var header http.Header
   118  	for k, v := range headers {
   119  		header.Set(k, v)
   120  	}
   121  	return getTestResponse("GET", path, "", header, nil)
   122  }
   123  
// Test describes one request/response case in the tests table. Entries in
// that table are written positionally, so the field order here is significant.
type Test struct {
	method         string              // HTTP method, e.g. "GET" or "POST"
	path           string              // request path, optionally with a query string
	headers        map[string][]string // extra request headers; nil for none
	body           string              // request body
	expectedStatus int                 // expected response status code
	expectedBody   string              // expected response body
}
   132  
// init registers the routes exercised by the tests in this file through the
// package-level Get/Post/Match helpers. It also sends mainServer's log output
// to ioutil.Discard so test runs stay quiet.
//
// Capture groups in a route pattern are handed to the handler as string
// arguments; the /echo and /multiecho entries in tests depend on this. A
// handler may also declare a leading *Context parameter to reach the request.
// NOTE(review): registration order is left untouched in case the router
// matches patterns in order — confirm before reordering.
func init() {
	mainServer.SetLogger(log.New(ioutil.Discard, "", 0))
	Get("/", func() string { return "index" })
	// Panicking handler; the tests table expects a 500 "Server Error" response.
	Get("/panic", func() { panic(0) })
	Get("/echo/(.*)", func(s string) string { return s })
	Get("/multiecho/(.*)/(.*)/(.*)/(.*)", func(a, b, c, d string) string { return a + b + c + d })
	Post("/post/echo/(.*)", func(s string) string { return s })
	// Echoes the request parameter named by the captured path segment. The
	// tests supply it either as a form-encoded body or in the query string.
	Post("/post/echoparam/(.*)", func(ctx *Context, name string) string { return ctx.Params[name] })

	// Aborts with the numeric status taken from the URL and its statusText message.
	Get("/error/code/(.*)", func(ctx *Context, code string) string {
		n, _ := strconv.Atoi(code)
		message := statusText[n]
		ctx.Abort(n, message)
		return ""
	})

	Get("/error/notfound/(.*)", func(ctx *Context, message string) { ctx.NotFound(message) })

	Get("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() })
	Post("/error/unauthorized", func(ctx *Context) { ctx.Unauthorized() })

	Get("/error/forbidden", func(ctx *Context) { ctx.Forbidden() })
	Post("/error/forbidden", func(ctx *Context) { ctx.Forbidden() })

	// Like /error/code, but the abort message also comes from the URL.
	Post("/posterror/code/(.*)/(.*)", func(ctx *Context, code string, message string) string {
		n, _ := strconv.Atoi(code)
		ctx.Abort(n, message)
		return ""
	})

	// Writes the body through the Context directly instead of returning it.
	Get("/writetest", func(ctx *Context) { ctx.WriteString("hello") })

	// NOTE(review): 60 is SetSecureCookie's third argument, presumably the
	// cookie lifetime in seconds — confirm against its definition.
	Post("/securecookie/set/(.+)/(.+)", func(ctx *Context, name string, val string) string {
		ctx.SetSecureCookie(name, val, 60)
		return ""
	})

	// Returns the secure cookie's value, or "" when GetSecureCookie reports !ok.
	Get("/securecookie/get/(.+)", func(ctx *Context, name string) string {
		val, ok := ctx.GetSecureCookie(name)
		if !ok {
			return ""
		}
		return val
	})
	Get("/getparam", func(ctx *Context) string { return ctx.Params["a"] })
	// ctx.Params maps each name to a single string, so every value of the
	// repeated "a" parameter is read from Request.Form instead.
	Get("/fullparams", func(ctx *Context) string {
		return strings.Join(ctx.Request.Form["a"], ",")
	})

	// /json and /jsonbytes produce the same JSON body: once returned as a
	// string, once as []byte. Marshal errors are ignored.
	Get("/json", func(ctx *Context) string {
		ctx.ContentType("json")
		data, _ := json.Marshal(ctx.Params)
		return string(data)
	})

	Get("/jsonbytes", func(ctx *Context) []byte {
		ctx.ContentType("json")
		data, _ := json.Marshal(ctx.Params)
		return data
	})

	// Decodes the JSON request body. encoding/json matches the lowercase
	// "a"/"b" keys to fields A/B case-insensitively. Decode errors are ignored.
	Post("/parsejson", func(ctx *Context) string {
		var tmp = struct {
			A string
			B string
		}{}
		json.NewDecoder(ctx.Request.Body).Decode(&tmp)
		return tmp.A + " " + tmp.B
	})

	// OPTIONS handler that sets CORS headers and writes an explicit 200 status.
	Match("OPTIONS", "/options", func(ctx *Context) {
		ctx.SetHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS", true)
		ctx.SetHeader("Access-Control-Max-Age", "1000", true)
		ctx.WriteHeader(200)
	})

	// The trailing true presumably asks SetHeader to replace any existing
	// Server header; TestDuplicateHeader checks that only one is sent.
	Get("/dupeheader", func(ctx *Context) string {
		ctx.SetHeader("Server", "myserver", true)
		return ""
	})

	// Echoes the Basic auth user and password concatenated, or "fail" when
	// GetBasicAuth returns an error.
	Get("/authorization", func(ctx *Context) string {
		user, pass, err := ctx.GetBasicAuth()
		if err != nil {
			return "fail"
		}
		return user + pass
	})
}
   223  
// tests is the shared route table driven by TestRouting, TestHead, TestScgi
// and TestScgiHead. Entries are positional, in Test's field order: method,
// path, request headers (nil for none), request body, expected status code,
// expected response body.
//
// NOTE(review): the /echo/hello and /post/echo/hello entries each appear
// twice — presumably to check that a repeated request behaves the same;
// confirm or drop the duplicates.
var tests = []Test{
	{"GET", "/", nil, "", 200, "index"},
	{"GET", "/echo/hello", nil, "", 200, "hello"},
	{"GET", "/echo/hello", nil, "", 200, "hello"},
	{"GET", "/multiecho/a/b/c/d", nil, "", 200, "abcd"},
	{"POST", "/post/echo/hello", nil, "", 200, "hello"},
	{"POST", "/post/echo/hello", nil, "", 200, "hello"},
	{"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello", 200, "hello"},
	{"POST", "/post/echoparam/c?c=hello", nil, "", 200, "hello"},
	// A NUL byte in a form value must survive the round trip unchanged.
	{"POST", "/post/echoparam/a", map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, "a=hello\x00", 200, "hello\x00"},
	//long url
	{"GET", "/echo/" + strings.Repeat("0123456789", 100), nil, "", 200, strings.Repeat("0123456789", 100)},
	{"GET", "/writetest", nil, "", 200, "hello"},
	{"GET", "/error/unauthorized", nil, "", 401, ""},
	{"POST", "/error/unauthorized", nil, "", 401, ""},
	{"GET", "/error/forbidden", nil, "", 403, ""},
	{"POST", "/error/forbidden", nil, "", 403, ""},
	{"GET", "/error/notfound/notfound", nil, "", 404, "notfound"},
	{"GET", "/doesnotexist", nil, "", 404, "Page not found"},
	{"POST", "/doesnotexist", nil, "", 404, "Page not found"},
	{"GET", "/error/code/500", nil, "", 500, statusText[500]},
	{"POST", "/posterror/code/410/failedrequest", nil, "", 410, "failedrequest"},
	{"GET", "/getparam?a=abcd", nil, "", 200, "abcd"},
	{"GET", "/getparam?b=abcd", nil, "", 200, ""},
	{"GET", "/fullparams?a=1&a=2&a=3", nil, "", 200, "1,2,3"},
	{"GET", "/panic", nil, "", 500, "Server Error"},
	{"GET", "/json?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`},
	{"GET", "/jsonbytes?a=1&b=2", nil, "", 200, `{"a":"1","b":"2"}`},
	{"POST", "/parsejson", map[string][]string{"Content-Type": {"application/json"}}, `{"a":"hello", "b":"world"}`, 200, "hello world"},
	// NOTE(review): the disabled entry below predates the headers field (it
	// has 5 values, Test needs 6), and no /testenv route is registered among
	// the routes visible here.
	//{"GET", "/testenv", "", 200, "hello world"},
	{"GET", "/authorization", map[string][]string{"Authorization": {BuildBasicAuthCredentials("foo", "bar")}}, "", 200, "foobar"},
}
   256  
   257  func buildTestRequest(method string, path string, body string, headers map[string][]string, cookies []*http.Cookie) *http.Request {
   258  	host := "127.0.0.1"
   259  	port := "80"
   260  	rawurl := "http://" + host + ":" + port + path
   261  	url_, _ := url.Parse(rawurl)
   262  	proto := "HTTP/1.1"
   263  
   264  	if headers == nil {
   265  		headers = map[string][]string{}
   266  	}
   267  
   268  	headers["User-Agent"] = []string{"web.go test"}
   269  	if method == "POST" {
   270  		headers["Content-Length"] = []string{fmt.Sprintf("%d", len(body))}
   271  		if headers["Content-Type"] == nil {
   272  			headers["Content-Type"] = []string{"text/plain"}
   273  		}
   274  	}
   275  
   276  	req := http.Request{Method: method,
   277  		URL:    url_,
   278  		Proto:  proto,
   279  		Host:   host,
   280  		Header: http.Header(headers),
   281  		Body:   ioutil.NopCloser(bytes.NewBufferString(body)),
   282  	}
   283  
   284  	for _, cookie := range cookies {
   285  		req.AddCookie(cookie)
   286  	}
   287  	return &req
   288  }
   289  
   290  func TestRouting(t *testing.T) {
   291  	for _, test := range tests {
   292  		resp := getTestResponse(test.method, test.path, test.body, test.headers, nil)
   293  
   294  		if resp.statusCode != test.expectedStatus {
   295  			t.Fatalf("%v(%v) expected status %d got %d", test.method, test.path, test.expectedStatus, resp.statusCode)
   296  		}
   297  		if resp.body != test.expectedBody {
   298  			t.Fatalf("%v(%v) expected %q got %q", test.method, test.path, test.expectedBody, resp.body)
   299  		}
   300  		if cl, ok := resp.headers["Content-Length"]; ok {
   301  			clExp, _ := strconv.Atoi(cl[0])
   302  			clAct := len(resp.body)
   303  			if clExp != clAct {
   304  				t.Fatalf("Content-length doesn't match. expected %d got %d", clExp, clAct)
   305  			}
   306  		}
   307  	}
   308  }
   309  
   310  func TestHead(t *testing.T) {
   311  	for _, test := range tests {
   312  
   313  		if test.method != "GET" {
   314  			continue
   315  		}
   316  		getresp := getTestResponse("GET", test.path, test.body, test.headers, nil)
   317  		headresp := getTestResponse("HEAD", test.path, test.body, test.headers, nil)
   318  
   319  		if getresp.statusCode != headresp.statusCode {
   320  			t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode)
   321  		}
   322  		if len(headresp.body) != 0 {
   323  			t.Fatalf("head request arrived with a body")
   324  		}
   325  
   326  		var cl []string
   327  		var getcl, headcl int
   328  		var hascl1, hascl2 bool
   329  
   330  		if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 {
   331  			getcl, _ = strconv.Atoi(cl[0])
   332  		}
   333  
   334  		if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 {
   335  			headcl, _ = strconv.Atoi(cl[0])
   336  		}
   337  
   338  		if hascl1 != hascl2 {
   339  			t.Fatalf("head and get: one has content-length, one doesn't")
   340  		}
   341  
   342  		if hascl1 == true && getcl != headcl {
   343  			t.Fatalf("head and get content-length differ")
   344  		}
   345  	}
   346  }
   347  
   348  func buildTestScgiRequest(method string, path string, body string, headers map[string][]string) *bytes.Buffer {
   349  	var headerBuf bytes.Buffer
   350  	scgiHeaders := make(map[string]string)
   351  
   352  	headerBuf.WriteString("CONTENT_LENGTH")
   353  	headerBuf.WriteByte(0)
   354  	headerBuf.WriteString(fmt.Sprintf("%d", len(body)))
   355  	headerBuf.WriteByte(0)
   356  
   357  	scgiHeaders["REQUEST_METHOD"] = method
   358  	scgiHeaders["HTTP_HOST"] = "127.0.0.1"
   359  	scgiHeaders["REQUEST_URI"] = path
   360  	scgiHeaders["SERVER_PORT"] = "80"
   361  	scgiHeaders["SERVER_PROTOCOL"] = "HTTP/1.1"
   362  	scgiHeaders["USER_AGENT"] = "web.go test framework"
   363  
   364  	for k, v := range headers {
   365  		//Skip content-length
   366  		if k == "Content-Length" {
   367  			continue
   368  		}
   369  		key := "HTTP_" + strings.ToUpper(strings.Replace(k, "-", "_", -1))
   370  		scgiHeaders[key] = v[0]
   371  	}
   372  	for k, v := range scgiHeaders {
   373  		headerBuf.WriteString(k)
   374  		headerBuf.WriteByte(0)
   375  		headerBuf.WriteString(v)
   376  		headerBuf.WriteByte(0)
   377  	}
   378  	headerData := headerBuf.Bytes()
   379  
   380  	var buf bytes.Buffer
   381  	//extra 1 is for the comma at the end
   382  	dlen := len(headerData)
   383  	fmt.Fprintf(&buf, "%d:", dlen)
   384  	buf.Write(headerData)
   385  	buf.WriteByte(',')
   386  	buf.WriteString(body)
   387  	return &buf
   388  }
   389  
   390  func TestScgi(t *testing.T) {
   391  	for _, test := range tests {
   392  		req := buildTestScgiRequest(test.method, test.path, test.body, test.headers)
   393  		var output bytes.Buffer
   394  		nb := ioBuffer{input: req, output: &output}
   395  		mainServer.handleScgiRequest(&nb)
   396  		resp := buildTestResponse(&output)
   397  
   398  		if resp.statusCode != test.expectedStatus {
   399  			t.Fatalf("expected status %d got %d", test.expectedStatus, resp.statusCode)
   400  		}
   401  
   402  		if resp.body != test.expectedBody {
   403  			t.Fatalf("Scgi expected %q got %q", test.expectedBody, resp.body)
   404  		}
   405  	}
   406  }
   407  
   408  func TestScgiHead(t *testing.T) {
   409  	for _, test := range tests {
   410  
   411  		if test.method != "GET" {
   412  			continue
   413  		}
   414  
   415  		req := buildTestScgiRequest("GET", test.path, test.body, make(map[string][]string))
   416  		var output bytes.Buffer
   417  		nb := ioBuffer{input: req, output: &output}
   418  		mainServer.handleScgiRequest(&nb)
   419  		getresp := buildTestResponse(&output)
   420  
   421  		req = buildTestScgiRequest("HEAD", test.path, test.body, make(map[string][]string))
   422  		var output2 bytes.Buffer
   423  		nb = ioBuffer{input: req, output: &output2}
   424  		mainServer.handleScgiRequest(&nb)
   425  		headresp := buildTestResponse(&output2)
   426  
   427  		if getresp.statusCode != headresp.statusCode {
   428  			t.Fatalf("head and get status differ. expected %d got %d", getresp.statusCode, headresp.statusCode)
   429  		}
   430  		if len(headresp.body) != 0 {
   431  			t.Fatalf("head request arrived with a body")
   432  		}
   433  
   434  		var cl []string
   435  		var getcl, headcl int
   436  		var hascl1, hascl2 bool
   437  
   438  		if cl, hascl1 = getresp.headers["Content-Length"]; hascl1 {
   439  			getcl, _ = strconv.Atoi(cl[0])
   440  		}
   441  
   442  		if cl, hascl2 = headresp.headers["Content-Length"]; hascl2 {
   443  			headcl, _ = strconv.Atoi(cl[0])
   444  		}
   445  
   446  		if hascl1 != hascl2 {
   447  			t.Fatalf("head and get: one has content-length, one doesn't")
   448  		}
   449  
   450  		if hascl1 == true && getcl != headcl {
   451  			t.Fatalf("head and get content-length differ")
   452  		}
   453  	}
   454  }
   455  
   456  func TestReadScgiRequest(t *testing.T) {
   457  	headers := map[string][]string{"User-Agent": {"web.go"}}
   458  	req := buildTestScgiRequest("POST", "/hello", "Hello world!", headers)
   459  	var s Server
   460  	httpReq, err := s.readScgiRequest(&ioBuffer{input: req, output: nil})
   461  	if err != nil {
   462  		t.Fatalf("Error while reading SCGI request: %v", err.Error())
   463  	}
   464  	if httpReq.ContentLength != 12 {
   465  		t.Fatalf("Content length mismatch, expected %d, got %d ", 12, httpReq.ContentLength)
   466  	}
   467  	var body bytes.Buffer
   468  	io.Copy(&body, httpReq.Body)
   469  	if body.String() != "Hello world!" {
   470  		t.Fatalf("Body mismatch, expected %q, got %q ", "Hello world!", body.String())
   471  	}
   472  }
   473  
   474  func makeCookie(vals map[string]string) []*http.Cookie {
   475  	var cookies []*http.Cookie
   476  	for k, v := range vals {
   477  		c := &http.Cookie{
   478  			Name:  k,
   479  			Value: v,
   480  		}
   481  		cookies = append(cookies, c)
   482  	}
   483  	return cookies
   484  }
   485  
   486  func TestSecureCookie(t *testing.T) {
   487  	mainServer.Config.CookieSecret = "7C19QRmwf3mHZ9CPAaPQ0hsWeufKd"
   488  	resp1 := getTestResponse("POST", "/securecookie/set/a/1", "", nil, nil)
   489  	sval, ok := resp1.cookies["a"]
   490  	if !ok {
   491  		t.Fatalf("Failed to get cookie ")
   492  	}
   493  	cookies := makeCookie(map[string]string{"a": sval})
   494  
   495  	resp2 := getTestResponse("GET", "/securecookie/get/a", "", nil, cookies)
   496  
   497  	if resp2.body != "1" {
   498  		t.Fatalf("SecureCookie test failed")
   499  	}
   500  }
   501  
   502  func TestEarlyClose(t *testing.T) {
   503  	var server1 Server
   504  	server1.Close()
   505  }
   506  
   507  func TestOptions(t *testing.T) {
   508  	resp := getTestResponse("OPTIONS", "/options", "", nil, nil)
   509  	if resp.headers["Access-Control-Allow-Methods"][0] != "POST, GET, OPTIONS" {
   510  		t.Fatalf("TestOptions - Access-Control-Allow-Methods failed")
   511  	}
   512  	if resp.headers["Access-Control-Max-Age"][0] != "1000" {
   513  		t.Fatalf("TestOptions - Access-Control-Max-Age failed")
   514  	}
   515  }
   516  
   517  func TestSlug(t *testing.T) {
   518  	tests := [][]string{
   519  		{"", ""},
   520  		{"a", "a"},
   521  		{"a/b", "a-b"},
   522  		{"a b", "a-b"},
   523  		{"a////b", "a-b"},
   524  		{" a////b ", "a-b"},
   525  		{" Manowar / Friends ", "manowar-friends"},
   526  	}
   527  
   528  	for _, test := range tests {
   529  		v := Slug(test[0], "-")
   530  		if v != test[1] {
   531  			t.Fatalf("TestSlug(%v) failed, expected %v, got %v", test[0], test[1], v)
   532  		}
   533  	}
   534  }
   535  
   536  // tests that we don't duplicate headers
   537  func TestDuplicateHeader(t *testing.T) {
   538  	resp := testGet("/dupeheader", nil)
   539  	if len(resp.headers["Server"]) > 1 {
   540  		t.Fatalf("Expected only one header, got %#v", resp.headers["Server"])
   541  	}
   542  	if resp.headers["Server"][0] != "myserver" {
   543  		t.Fatalf("Incorrect header, exp 'myserver', got %q", resp.headers["Server"][0])
   544  	}
   545  }
   546  
   547  func BuildBasicAuthCredentials(user string, pass string) string {
   548  	s := user + ":" + pass
   549  	return "Basic " + base64.StdEncoding.EncodeToString([]byte(s))
   550  }
   551  
   552  func BenchmarkProcessGet(b *testing.B) {
   553  	s := NewServer()
   554  	s.SetLogger(log.New(ioutil.Discard, "", 0))
   555  	s.Get("/echo/(.*)", func(s string) string {
   556  		return s
   557  	})
   558  	req := buildTestRequest("GET", "/echo/hi", "", nil, nil)
   559  	var buf bytes.Buffer
   560  	iob := ioBuffer{input: nil, output: &buf}
   561  	c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob}
   562  	b.ReportAllocs()
   563  	b.ResetTimer()
   564  	for i := 0; i < b.N; i++ {
   565  		s.Process(&c, req)
   566  	}
   567  }
   568  
   569  func BenchmarkProcessPost(b *testing.B) {
   570  	s := NewServer()
   571  	s.SetLogger(log.New(ioutil.Discard, "", 0))
   572  	s.Post("/echo/(.*)", func(s string) string {
   573  		return s
   574  	})
   575  	req := buildTestRequest("POST", "/echo/hi", "", nil, nil)
   576  	var buf bytes.Buffer
   577  	iob := ioBuffer{input: nil, output: &buf}
   578  	c := scgiConn{wroteHeaders: false, req: req, headers: make(map[string][]string), fd: &iob}
   579  	b.ReportAllocs()
   580  	b.ResetTimer()
   581  	for i := 0; i < b.N; i++ {
   582  		s.Process(&c, req)
   583  	}
   584  }