github.com/ethersphere/bee/v2@v2.2.0/pkg/api/cors_test.go (about)

     1  // Copyright 2021 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package api_test
     6  
     7  import (
     8  	"context"
     9  	"net/http"
    10  	"testing"
    11  
    12  	"github.com/ethersphere/bee/v2/pkg/api"
    13  	"github.com/ethersphere/bee/v2/pkg/jsonhttp/jsonhttptest"
    14  )
    15  
    16  func TestCORSHeaders(t *testing.T) {
    17  	t.Parallel()
    18  
    19  	for _, tc := range []struct {
    20  		name           string
    21  		origin         string
    22  		allowedOrigins []string
    23  		wantCORS       bool
    24  	}{
    25  		{
    26  			name: "none",
    27  		},
    28  		{
    29  			name:           "no origin",
    30  			allowedOrigins: []string{"https://gateway.ethswarm.org"},
    31  			wantCORS:       false,
    32  		},
    33  		{
    34  			name:           "single explicit",
    35  			origin:         "https://gateway.ethswarm.org",
    36  			allowedOrigins: []string{"https://gateway.ethswarm.org"},
    37  			wantCORS:       true,
    38  		},
    39  		{
    40  			name:           "single explicit blocked",
    41  			origin:         "http://a-hacker.me",
    42  			allowedOrigins: []string{"https://gateway.ethswarm.org"},
    43  			wantCORS:       false,
    44  		},
    45  		{
    46  			name:           "multiple explicit",
    47  			origin:         "https://staging.gateway.ethswarm.org",
    48  			allowedOrigins: []string{"https://gateway.ethswarm.org", "https://staging.gateway.ethswarm.org"},
    49  			wantCORS:       true,
    50  		},
    51  		{
    52  			name:           "multiple explicit blocked",
    53  			origin:         "http://a-hacker.me",
    54  			allowedOrigins: []string{"https://gateway.ethswarm.org", "https://staging.gateway.ethswarm.org"},
    55  			wantCORS:       false,
    56  		},
    57  		{
    58  			name:           "wildcard",
    59  			origin:         "http://localhost:1234",
    60  			allowedOrigins: []string{"*"},
    61  			wantCORS:       true,
    62  		},
    63  		{
    64  			name:           "wildcard",
    65  			origin:         "https://gateway.ethswarm.org",
    66  			allowedOrigins: []string{"*"},
    67  			wantCORS:       true,
    68  		},
    69  		{
    70  			name:           "with origin only",
    71  			origin:         "https://gateway.ethswarm.org",
    72  			allowedOrigins: nil,
    73  			wantCORS:       false,
    74  		},
    75  		{
    76  			name:           "with origin only not nil",
    77  			origin:         "https://gateway.ethswarm.org",
    78  			allowedOrigins: []string{},
    79  			wantCORS:       false,
    80  		},
    81  	} {
    82  		tc := tc
    83  		t.Run(tc.name, func(t *testing.T) {
    84  			t.Parallel()
    85  
    86  			ctx, cancel := context.WithCancel(context.Background())
    87  			t.Cleanup(cancel)
    88  
    89  			client, _, _, _ := newTestServer(t, testServerOptions{
    90  				CORSAllowedOrigins: tc.allowedOrigins,
    91  			})
    92  
    93  			req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
    94  			if err != nil {
    95  				t.Fatal(err)
    96  			}
    97  			if tc.origin != "" {
    98  				req.Header.Set(api.OriginHeader, tc.origin)
    99  			}
   100  
   101  			r, err := client.Do(req)
   102  			if err != nil {
   103  				t.Fatal(err)
   104  			}
   105  
   106  			got := r.Header.Get("Access-Control-Allow-Origin")
   107  
   108  			if tc.wantCORS {
   109  				if got != tc.origin {
   110  					t.Errorf("got Access-Control-Allow-Origin %q, want %q", got, tc.origin)
   111  				}
   112  			} else {
   113  				if got != "" {
   114  					t.Errorf("got Access-Control-Allow-Origin %q, want none", got)
   115  				}
   116  			}
   117  		})
   118  	}
   119  
   120  }
   121  
   122  // TestCors tests whether CORs work correctly with OPTIONS method
   123  func TestCors(t *testing.T) {
   124  	t.Parallel()
   125  
   126  	const origin = "example.com"
   127  	for _, tc := range []struct {
   128  		endpoint        string
   129  		expectedMethods string // expectedMethods contains HTTP methods like GET, POST, HEAD, PATCH, DELETE, OPTIONS. These are in alphabetical sorted order
   130  	}{
   131  		{
   132  			endpoint:        "tags",
   133  			expectedMethods: "GET, POST",
   134  		},
   135  		{
   136  			endpoint:        "bzz",
   137  			expectedMethods: "POST",
   138  		}, {
   139  			endpoint:        "bzz/0101011",
   140  			expectedMethods: "GET, HEAD",
   141  		},
   142  		{
   143  			endpoint:        "chunks",
   144  			expectedMethods: "POST",
   145  		},
   146  		{
   147  			endpoint:        "chunks/123213",
   148  			expectedMethods: "GET, HEAD",
   149  		},
   150  		{
   151  			endpoint:        "bytes",
   152  			expectedMethods: "POST",
   153  		},
   154  		{
   155  			endpoint:        "bytes/0121012",
   156  			expectedMethods: "GET, HEAD",
   157  		},
   158  	} {
   159  		tc := tc
   160  		t.Run(tc.endpoint, func(t *testing.T) {
   161  			t.Parallel()
   162  
   163  			client, _, _, _ := newTestServer(t, testServerOptions{
   164  				CORSAllowedOrigins: []string{origin},
   165  			})
   166  
   167  			jsonhttptest.Request(t, client, http.MethodOptions, "/"+tc.endpoint, http.StatusNoContent,
   168  				jsonhttptest.WithRequestHeader(api.OriginHeader, origin),
   169  				jsonhttptest.WithExpectedResponseHeader("Access-Control-Allow-Methods", tc.expectedMethods),
   170  			)
   171  		})
   172  	}
   173  }
   174  
   175  // TestCorsStatus tests whether CORs returns correct allowed method if wrong method is called
   176  func TestCorsStatus(t *testing.T) {
   177  	t.Parallel()
   178  	const origin = "example.com"
   179  	for _, tc := range []struct {
   180  		endpoint          string
   181  		notAllowedMethods string // notAllowedMethods contains HTTP methods like GET, POST, HEAD, PATCH, DELETE, OPTIONS. These are method which is not supported by endpoint
   182  		allowedMethods    string // expectedMethods contains HTTP methods like GET, POST, HEAD, PATCH, DELETE, OPTIONS. These are in alphabetical sorted order
   183  	}{
   184  		{
   185  			endpoint:          "tags",
   186  			notAllowedMethods: http.MethodDelete,
   187  			allowedMethods:    "GET, POST",
   188  		},
   189  		{
   190  			endpoint:          "bzz",
   191  			notAllowedMethods: http.MethodDelete,
   192  			allowedMethods:    "POST",
   193  		},
   194  		{
   195  			endpoint:          "chunks",
   196  			notAllowedMethods: http.MethodDelete,
   197  			allowedMethods:    "POST",
   198  		},
   199  		{
   200  			endpoint:          "chunks/0101011",
   201  			notAllowedMethods: http.MethodPost,
   202  			allowedMethods:    "GET, HEAD",
   203  		},
   204  		{
   205  			endpoint:          "bytes",
   206  			notAllowedMethods: http.MethodDelete,
   207  			allowedMethods:    "POST",
   208  		},
   209  		{
   210  			endpoint:          "bytes/0121012",
   211  			notAllowedMethods: http.MethodDelete,
   212  			allowedMethods:    "GET, HEAD",
   213  		},
   214  	} {
   215  		tc := tc
   216  		t.Run(tc.endpoint, func(t *testing.T) {
   217  			t.Parallel()
   218  
   219  			client, _, _, _ := newTestServer(t, testServerOptions{
   220  				CORSAllowedOrigins: []string{origin},
   221  			})
   222  
   223  			jsonhttptest.Request(t, client, tc.notAllowedMethods, "/"+tc.endpoint, http.StatusMethodNotAllowed,
   224  				jsonhttptest.WithRequestHeader(api.OriginHeader, origin),
   225  				jsonhttptest.WithExpectedResponseHeader("Access-Control-Allow-Methods", tc.allowedMethods),
   226  			)
   227  		})
   228  	}
   229  }