github.com/keltia/go-ipfs@v0.3.8-0.20150909044612-210793031c63/commands/http/handler_test.go (about)

     1  package http
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"net/url"
     7  	"testing"
     8  
     9  	cors "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/rs/cors"
    10  
    11  	cmds "github.com/ipfs/go-ipfs/commands"
    12  	ipfscmd "github.com/ipfs/go-ipfs/core/commands"
    13  	coremock "github.com/ipfs/go-ipfs/core/mock"
    14  )
    15  
    16  func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) {
    17  	for name, value := range reqHeaders {
    18  		if resHeaders.Get(name) != value {
    19  			t.Errorf("Invalid header '%s', wanted '%s', got '%s'", name, value, resHeaders.Get(name))
    20  		}
    21  	}
    22  }
    23  
    24  func assertStatus(t *testing.T, actual, expected int) {
    25  	if actual != expected {
    26  		t.Errorf("Expected status: %d got: %d", expected, actual)
    27  	}
    28  }
    29  
    30  func originCfg(origins []string) *ServerConfig {
    31  	return &ServerConfig{
    32  		CORSOpts: &cors.Options{
    33  			AllowedOrigins: origins,
    34  			AllowedMethods: []string{"GET", "PUT", "POST"},
    35  		},
    36  	}
    37  }
    38  
    39  type testCase struct {
    40  	Method       string
    41  	Path         string
    42  	Code         int
    43  	Origin       string
    44  	Referer      string
    45  	AllowOrigins []string
    46  	ReqHeaders   map[string]string
    47  	ResHeaders   map[string]string
    48  }
    49  
    50  var defaultOrigins = []string{
    51  	"http://localhost",
    52  	"http://127.0.0.1",
    53  	"https://localhost",
    54  	"https://127.0.0.1",
    55  }
    56  
    57  func getTestServer(t *testing.T, origins []string) *httptest.Server {
    58  	cmdsCtx, err := coremock.MockCmdsCtx()
    59  	if err != nil {
    60  		t.Error("failure to initialize mock cmds ctx", err)
    61  		return nil
    62  	}
    63  
    64  	cmdRoot := &cmds.Command{
    65  		Subcommands: map[string]*cmds.Command{
    66  			"version": ipfscmd.VersionCmd,
    67  		},
    68  	}
    69  
    70  	if len(origins) == 0 {
    71  		origins = defaultOrigins
    72  	}
    73  
    74  	handler := NewHandler(cmdsCtx, cmdRoot, originCfg(origins))
    75  	return httptest.NewServer(handler)
    76  }
    77  
    78  func (tc *testCase) test(t *testing.T) {
    79  	// defaults
    80  	method := tc.Method
    81  	if method == "" {
    82  		method = "GET"
    83  	}
    84  
    85  	path := tc.Path
    86  	if path == "" {
    87  		path = "/api/v0/version"
    88  	}
    89  
    90  	expectCode := tc.Code
    91  	if expectCode == 0 {
    92  		expectCode = 200
    93  	}
    94  
    95  	// request
    96  	req, err := http.NewRequest(method, path, nil)
    97  	if err != nil {
    98  		t.Error(err)
    99  		return
   100  	}
   101  
   102  	for k, v := range tc.ReqHeaders {
   103  		req.Header.Add(k, v)
   104  	}
   105  	if tc.Origin != "" {
   106  		req.Header.Add("Origin", tc.Origin)
   107  	}
   108  	if tc.Referer != "" {
   109  		req.Header.Add("Referer", tc.Referer)
   110  	}
   111  
   112  	// server
   113  	server := getTestServer(t, tc.AllowOrigins)
   114  	if server == nil {
   115  		return
   116  	}
   117  	defer server.Close()
   118  
   119  	req.URL, err = url.Parse(server.URL + path)
   120  	if err != nil {
   121  		t.Error(err)
   122  		return
   123  	}
   124  
   125  	res, err := http.DefaultClient.Do(req)
   126  	if err != nil {
   127  		t.Error(err)
   128  		return
   129  	}
   130  
   131  	// checks
   132  	t.Log("GET", server.URL+path, req.Header, res.Header)
   133  	assertHeaders(t, res.Header, tc.ResHeaders)
   134  	assertStatus(t, res.StatusCode, expectCode)
   135  }
   136  
   137  func TestDisallowedOrigins(t *testing.T) {
   138  	gtc := func(origin string, allowedOrigins []string) testCase {
   139  		return testCase{
   140  			Origin:       origin,
   141  			AllowOrigins: allowedOrigins,
   142  			ResHeaders: map[string]string{
   143  				ACAOrigin:                       "",
   144  				ACAMethods:                      "",
   145  				ACACredentials:                  "",
   146  				"Access-Control-Max-Age":        "",
   147  				"Access-Control-Expose-Headers": "",
   148  			},
   149  			Code: http.StatusForbidden,
   150  		}
   151  	}
   152  
   153  	tcs := []testCase{
   154  		gtc("http://barbaz.com", nil),
   155  		gtc("http://barbaz.com", []string{"http://localhost"}),
   156  		gtc("http://127.0.0.1", []string{"http://localhost"}),
   157  		gtc("http://localhost", []string{"http://127.0.0.1"}),
   158  		gtc("http://127.0.0.1:1234", nil),
   159  		gtc("http://localhost:1234", nil),
   160  	}
   161  
   162  	for _, tc := range tcs {
   163  		tc.test(t)
   164  	}
   165  }
   166  
   167  func TestAllowedOrigins(t *testing.T) {
   168  	gtc := func(origin string, allowedOrigins []string) testCase {
   169  		return testCase{
   170  			Origin:       origin,
   171  			AllowOrigins: allowedOrigins,
   172  			ResHeaders: map[string]string{
   173  				ACAOrigin:                       origin,
   174  				ACAMethods:                      "",
   175  				ACACredentials:                  "",
   176  				"Access-Control-Max-Age":        "",
   177  				"Access-Control-Expose-Headers": "",
   178  			},
   179  			Code: http.StatusOK,
   180  		}
   181  	}
   182  
   183  	tcs := []testCase{
   184  		gtc("http://barbaz.com", []string{"http://barbaz.com", "http://localhost"}),
   185  		gtc("http://localhost", []string{"http://barbaz.com", "http://localhost"}),
   186  		gtc("http://localhost", nil),
   187  		gtc("http://127.0.0.1", nil),
   188  	}
   189  
   190  	for _, tc := range tcs {
   191  		tc.test(t)
   192  	}
   193  }
   194  
   195  func TestWildcardOrigin(t *testing.T) {
   196  	gtc := func(origin string, allowedOrigins []string) testCase {
   197  		return testCase{
   198  			Origin:       origin,
   199  			AllowOrigins: allowedOrigins,
   200  			ResHeaders: map[string]string{
   201  				ACAOrigin:                       origin,
   202  				ACAMethods:                      "",
   203  				ACACredentials:                  "",
   204  				"Access-Control-Max-Age":        "",
   205  				"Access-Control-Expose-Headers": "",
   206  			},
   207  			Code: http.StatusOK,
   208  		}
   209  	}
   210  
   211  	tcs := []testCase{
   212  		gtc("http://barbaz.com", []string{"*"}),
   213  		gtc("http://barbaz.com", []string{"http://localhost", "*"}),
   214  		gtc("http://127.0.0.1", []string{"http://localhost", "*"}),
   215  		gtc("http://localhost", []string{"http://127.0.0.1", "*"}),
   216  		gtc("http://127.0.0.1", []string{"*"}),
   217  		gtc("http://localhost", []string{"*"}),
   218  		gtc("http://127.0.0.1:1234", []string{"*"}),
   219  		gtc("http://localhost:1234", []string{"*"}),
   220  	}
   221  
   222  	for _, tc := range tcs {
   223  		tc.test(t)
   224  	}
   225  }
   226  
   227  func TestDisallowedReferer(t *testing.T) {
   228  	gtc := func(referer string, allowedOrigins []string) testCase {
   229  		return testCase{
   230  			Origin:       "http://localhost",
   231  			Referer:      referer,
   232  			AllowOrigins: allowedOrigins,
   233  			ResHeaders: map[string]string{
   234  				ACAOrigin:                       "http://localhost",
   235  				ACAMethods:                      "",
   236  				ACACredentials:                  "",
   237  				"Access-Control-Max-Age":        "",
   238  				"Access-Control-Expose-Headers": "",
   239  			},
   240  			Code: http.StatusForbidden,
   241  		}
   242  	}
   243  
   244  	tcs := []testCase{
   245  		gtc("http://foobar.com", nil),
   246  		gtc("http://localhost:1234", nil),
   247  		gtc("http://127.0.0.1:1234", nil),
   248  	}
   249  
   250  	for _, tc := range tcs {
   251  		tc.test(t)
   252  	}
   253  }
   254  
   255  func TestAllowedReferer(t *testing.T) {
   256  	gtc := func(referer string, allowedOrigins []string) testCase {
   257  		return testCase{
   258  			Origin:       "http://localhost",
   259  			AllowOrigins: allowedOrigins,
   260  			ResHeaders: map[string]string{
   261  				ACAOrigin:                       "http://localhost",
   262  				ACAMethods:                      "",
   263  				ACACredentials:                  "",
   264  				"Access-Control-Max-Age":        "",
   265  				"Access-Control-Expose-Headers": "",
   266  			},
   267  			Code: http.StatusOK,
   268  		}
   269  	}
   270  
   271  	tcs := []testCase{
   272  		gtc("http://barbaz.com", []string{"http://barbaz.com", "http://localhost"}),
   273  		gtc("http://localhost", []string{"http://barbaz.com", "http://localhost"}),
   274  		gtc("http://localhost", nil),
   275  		gtc("http://127.0.0.1", nil),
   276  	}
   277  
   278  	for _, tc := range tcs {
   279  		tc.test(t)
   280  	}
   281  }
   282  
   283  func TestWildcardReferer(t *testing.T) {
   284  	gtc := func(origin string, allowedOrigins []string) testCase {
   285  		return testCase{
   286  			Origin:       origin,
   287  			AllowOrigins: allowedOrigins,
   288  			ResHeaders: map[string]string{
   289  				ACAOrigin:                       origin,
   290  				ACAMethods:                      "",
   291  				ACACredentials:                  "",
   292  				"Access-Control-Max-Age":        "",
   293  				"Access-Control-Expose-Headers": "",
   294  			},
   295  			Code: http.StatusOK,
   296  		}
   297  	}
   298  
   299  	tcs := []testCase{
   300  		gtc("http://barbaz.com", []string{"*"}),
   301  		gtc("http://barbaz.com", []string{"http://localhost", "*"}),
   302  		gtc("http://127.0.0.1", []string{"http://localhost", "*"}),
   303  		gtc("http://localhost", []string{"http://127.0.0.1", "*"}),
   304  		gtc("http://127.0.0.1", []string{"*"}),
   305  		gtc("http://localhost", []string{"*"}),
   306  		gtc("http://127.0.0.1:1234", []string{"*"}),
   307  		gtc("http://localhost:1234", []string{"*"}),
   308  	}
   309  
   310  	for _, tc := range tcs {
   311  		tc.test(t)
   312  	}
   313  }
   314  
   315  func TestAllowedMethod(t *testing.T) {
   316  	gtc := func(method string, ok bool) testCase {
   317  		code := http.StatusOK
   318  		hdrs := map[string]string{
   319  			ACAOrigin:                       "http://localhost",
   320  			ACAMethods:                      method,
   321  			ACACredentials:                  "",
   322  			"Access-Control-Max-Age":        "",
   323  			"Access-Control-Expose-Headers": "",
   324  		}
   325  
   326  		if !ok {
   327  			hdrs[ACAOrigin] = ""
   328  			hdrs[ACAMethods] = ""
   329  		}
   330  
   331  		return testCase{
   332  			Method:       "OPTIONS",
   333  			Origin:       "http://localhost",
   334  			AllowOrigins: []string{"*"},
   335  			ReqHeaders: map[string]string{
   336  				"Access-Control-Request-Method": method,
   337  			},
   338  			ResHeaders: hdrs,
   339  			Code:       code,
   340  		}
   341  	}
   342  
   343  	tcs := []testCase{
   344  		gtc("PUT", true),
   345  		gtc("GET", true),
   346  		gtc("FOOBAR", false),
   347  	}
   348  
   349  	for _, tc := range tcs {
   350  		tc.test(t)
   351  	}
   352  }