github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/plugins/hsts/hsts_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package hsts_test
    16  
    17  import (
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/google/go-cmp/cmp"
    22  	"github.com/google/go-safeweb/safehttp"
    23  	"github.com/google/go-safeweb/safehttp/plugins/hsts"
    24  	"github.com/google/go-safeweb/safehttp/safehttptest"
    25  )
    26  
    27  func TestHSTSReject(t *testing.T) {
    28  	var test = []struct {
    29  		name         string
    30  		interceptor  hsts.Interceptor
    31  		req          *safehttp.IncomingRequest
    32  		wantStatus   safehttp.StatusCode
    33  		wantBody     string
    34  		wantRedirect string
    35  	}{
    36  		{
    37  			name:         "HTTP",
    38  			interceptor:  hsts.Default(),
    39  			req:          safehttptest.NewRequest(safehttp.MethodGet, "http://localhost/", nil),
    40  			wantStatus:   safehttp.StatusMovedPermanently,
    41  			wantRedirect: "https://localhost/",
    42  		},
    43  		{
    44  			name:        "Negative MaxAge",
    45  			interceptor: hsts.Interceptor{MaxAge: -1 * time.Second},
    46  			req:         safehttptest.NewRequest(safehttp.MethodGet, "https://localhost/", nil),
    47  			wantStatus:  safehttp.StatusInternalServerError,
    48  		},
    49  	}
    50  
    51  	for _, tt := range test {
    52  		t.Run(tt.name, func(t *testing.T) {
    53  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
    54  
    55  			tt.interceptor.Before(fakeRW, tt.req, nil)
    56  
    57  			if gotStatus := rr.Code; gotStatus != int(tt.wantStatus) {
    58  				t.Errorf("rr.Code got: %v want: %v", gotStatus, tt.wantStatus)
    59  			}
    60  
    61  			if tt.wantRedirect != "" {
    62  				resp := fakeRW.Dispatcher.Written
    63  				redir, ok := resp.(safehttp.RedirectResponse)
    64  				if !ok {
    65  					t.Fatalf("got %T, wanted a RedirectResponse", resp)
    66  				}
    67  				if got, want := redir.Location, tt.wantRedirect; got != want {
    68  					t.Errorf("RedirectResponse.Location got %q, want %q", got, want)
    69  				}
    70  			}
    71  		})
    72  	}
    73  }
    74  
    75  func TestHSTSOK(t *testing.T) {
    76  	var test = []struct {
    77  		name         string
    78  		interceptor  hsts.Interceptor
    79  		req          *safehttp.IncomingRequest
    80  		wantHeaders  map[string][]string
    81  		wantRedirect string
    82  	}{
    83  		{
    84  			name:        "HTTPS",
    85  			interceptor: hsts.Default(),
    86  			req:         safehttptest.NewRequest(safehttp.MethodGet, "https://localhost/", nil),
    87  			wantHeaders: map[string][]string{
    88  				"Strict-Transport-Security": {"max-age=63072000; includeSubDomains"}, // 63072000 seconds is two years
    89  			},
    90  		},
    91  		{
    92  			name:        "HTTP behind proxy",
    93  			interceptor: hsts.Interceptor{BehindProxy: true},
    94  			req:         safehttptest.NewRequest(safehttp.MethodGet, "http://localhost/", nil),
    95  			wantHeaders: map[string][]string{
    96  				// max-age=0 tells the browser to expire the HSTS protection.
    97  				"Strict-Transport-Security": {"max-age=0; includeSubDomains"},
    98  			},
    99  		},
   100  		{
   101  			name:        "Preload",
   102  			interceptor: hsts.Interceptor{Preload: true, DisableIncludeSubDomains: true},
   103  			req:         safehttptest.NewRequest(safehttp.MethodGet, "https://localhost/", nil),
   104  			wantHeaders: map[string][]string{
   105  				// max-age=0 tells the browser to expire the HSTS protection.
   106  				"Strict-Transport-Security": {"max-age=0; preload"},
   107  			},
   108  		},
   109  		{
   110  			name:        "Preload and IncludeSubDomains",
   111  			interceptor: hsts.Interceptor{Preload: true},
   112  			req:         safehttptest.NewRequest(safehttp.MethodGet, "https://localhost/", nil),
   113  			wantHeaders: map[string][]string{
   114  				// max-age=0 tells the browser to expire the HSTS protection.
   115  				"Strict-Transport-Security": {"max-age=0; includeSubDomains; preload"},
   116  			},
   117  		},
   118  		{
   119  			name:        "No preload and no includeSubDomains",
   120  			interceptor: hsts.Interceptor{DisableIncludeSubDomains: true},
   121  			req:         safehttptest.NewRequest(safehttp.MethodGet, "https://localhost/", nil),
   122  			wantHeaders: map[string][]string{
   123  				// max-age=0 tells the browser to expire the HSTS protection.
   124  				"Strict-Transport-Security": {"max-age=0"},
   125  			},
   126  		},
   127  		{
   128  			name:        "Custom MaxAge",
   129  			interceptor: hsts.Interceptor{MaxAge: 3600 * time.Second},
   130  			req:         safehttptest.NewRequest(safehttp.MethodGet, "https://localhost/", nil),
   131  			wantHeaders: map[string][]string{
   132  				"Strict-Transport-Security": {"max-age=3600; includeSubDomains"}, // 3600 seconds is 1 hour
   133  			},
   134  		},
   135  	}
   136  
   137  	for _, tt := range test {
   138  		t.Run(tt.name, func(t *testing.T) {
   139  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
   140  
   141  			tt.interceptor.Before(fakeRW, tt.req, nil)
   142  
   143  			if gotStatus := rr.Code; gotStatus != int(safehttp.StatusOK) {
   144  				t.Errorf("rr.Code got: %v want: %v", gotStatus, safehttp.StatusOK)
   145  			}
   146  
   147  			if diff := cmp.Diff(tt.wantHeaders, map[string][]string(rr.Header())); diff != "" {
   148  				t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   149  			}
   150  		})
   151  	}
   152  }