github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/plugins/xsrf/xsrfblockall/xsrf_test.go (about)

     1  // Copyright 2022 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package xsrfblockall
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/google/go-safeweb/safehttp"
    21  	"github.com/google/go-safeweb/safehttp/safehttptest"
    22  )
    23  
    24  func TestStateChanging(t *testing.T) {
    25  	test := []struct {
    26  		name       string
    27  		host       string
    28  		method     string
    29  		wantStatus safehttp.StatusCode
    30  	}{
    31  		{
    32  			name:       "POST request",
    33  			host:       "foo.com",
    34  			method:     safehttp.MethodPost,
    35  			wantStatus: safehttp.StatusForbidden,
    36  		},
    37  		{
    38  			name:       "PUT request",
    39  			host:       "foo.com",
    40  			method:     safehttp.MethodPut,
    41  			wantStatus: safehttp.StatusForbidden,
    42  		},
    43  		{
    44  			name:       "DELETE request",
    45  			host:       "foo.com",
    46  			method:     safehttp.MethodDelete,
    47  			wantStatus: safehttp.StatusForbidden,
    48  		},
    49  		{
    50  			name:       "PATCH request",
    51  			host:       "foo.com",
    52  			method:     safehttp.MethodPatch,
    53  			wantStatus: safehttp.StatusForbidden,
    54  		},
    55  		{
    56  			name:       "TRACE request",
    57  			host:       "foo.com",
    58  			method:     safehttp.MethodTrace,
    59  			wantStatus: safehttp.StatusForbidden,
    60  		},
    61  		{
    62  			name:       "CONNECT request",
    63  			host:       "foo.com",
    64  			method:     safehttp.MethodConnect,
    65  			wantStatus: safehttp.StatusForbidden,
    66  		},
    67  	}
    68  
    69  	for _, test := range test {
    70  		t.Run(test.name, func(t *testing.T) {
    71  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
    72  			req := safehttptest.NewRequest(test.method, "https://foo.com/", nil)
    73  
    74  			i := Interceptor{}
    75  			i.Before(fakeRW, req, nil)
    76  
    77  			if got := rr.Code; got != int(test.wantStatus) {
    78  				t.Errorf("rr.Code: got %v, want %v", got, test.wantStatus)
    79  			}
    80  		})
    81  	}
    82  }
    83  
    84  func TestNonStateChanging(t *testing.T) {
    85  	test := []struct {
    86  		name       string
    87  		host       string
    88  		method     string
    89  		wantStatus safehttp.StatusCode
    90  	}{
    91  		{
    92  			name:       "GET request",
    93  			host:       "foo.com",
    94  			method:     safehttp.MethodGet,
    95  			wantStatus: safehttp.StatusOK,
    96  		},
    97  		{
    98  			name:       "HEAD request",
    99  			host:       "foo.com",
   100  			method:     safehttp.MethodHead,
   101  			wantStatus: safehttp.StatusOK,
   102  		},
   103  		{
   104  			name:       "OPTIONS request",
   105  			host:       "foo.com",
   106  			method:     safehttp.MethodOptions,
   107  			wantStatus: safehttp.StatusOK,
   108  		},
   109  	}
   110  
   111  	for _, test := range test {
   112  		t.Run(test.name, func(t *testing.T) {
   113  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
   114  			req := safehttptest.NewRequest(test.method, "https://foo.com/", nil)
   115  
   116  			i := Interceptor{}
   117  			i.Before(fakeRW, req, nil)
   118  
   119  			if got := rr.Code; got != int(test.wantStatus) {
   120  				t.Errorf("rr.Code: got %v, want %v", got, test.wantStatus)
   121  			}
   122  		})
   123  	}
   124  }