github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/plugins/fetchmetadata/framing_test.go (about)

     1  // Copyright 2022 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package fetchmetadata_test
    16  
    17  import (
    18  	"net/http"
    19  	"testing"
    20  
    21  	"github.com/google/go-cmp/cmp"
    22  	"github.com/google/go-safeweb/safehttp/plugins/fetchmetadata"
    23  	"github.com/google/go-safeweb/safehttp/plugins/framing/internalunsafeframing"
    24  
    25  	"github.com/google/go-safeweb/safehttp"
    26  	"github.com/google/go-safeweb/safehttp/safehttptest"
    27  )
    28  
    29  var (
    30  	allowedFIPHeaders = []testHeaders{
    31  		{
    32  			name: "Fetch Metadata not supported",
    33  		},
    34  		{
    35  			name: "Same origin",
    36  			site: "same-origin",
    37  			mode: "navigate",
    38  			dest: "frame",
    39  		},
    40  		{
    41  			name: "User agent initiated",
    42  			site: "none",
    43  			mode: "navigate",
    44  			dest: "frame",
    45  		},
    46  		{
    47  			name: "Non-navigational",
    48  			site: "cross-site",
    49  			mode: "cors",
    50  			dest: "frame",
    51  		},
    52  		{
    53  			name: "Non-frameable",
    54  			site: "cross-site",
    55  			mode: "navigate",
    56  			dest: "script",
    57  		},
    58  	}
    59  	disallowedFIPHeaders = []testHeaders{
    60  		{
    61  			name: "Cross origin frame",
    62  			site: "cross-origin",
    63  			mode: "navigate",
    64  			dest: "frame",
    65  		},
    66  		{
    67  			name: "Same site, corss origin embed",
    68  			site: "same-site",
    69  			mode: "nested-navigate",
    70  			dest: "embed",
    71  		},
    72  	}
    73  )
    74  
    75  func TestAllowedFramingIsolationEnforceMode(t *testing.T) {
    76  	for _, test := range allowedFIPHeaders {
    77  		t.Run(test.name, func(t *testing.T) {
    78  			req := safehttptest.NewRequest("GET", "https://spaghetti.com/carbonara", nil)
    79  			req.Header.Add("Sec-Fetch-Site", test.site)
    80  			req.Header.Add("Sec-Fetch-Mode", test.mode)
    81  			req.Header.Add("Sec-Fetch-Dest", test.dest)
    82  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
    83  
    84  			p := fetchmetadata.FramingIsolationPolicy()
    85  			p.Before(fakeRW, req, nil)
    86  
    87  			if want, got := int(safehttp.StatusOK), rr.Code; got != want {
    88  				t.Errorf("rr.Code got: %v want: %v", got, want)
    89  			}
    90  			if diff := cmp.Diff(http.Header{}, rr.Header()); diff != "" {
    91  				t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
    92  			}
    93  			if want, got := "", rr.Body.String(); got != want {
    94  				t.Errorf("rr.Body.String() got: %q want: %q", got, want)
    95  			}
    96  		})
    97  	}
    98  }
    99  
   100  func TestRejectedFramingIsolationEnforceMode(t *testing.T) {
   101  	for _, test := range disallowedFIPHeaders {
   102  		t.Run(test.name, func(t *testing.T) {
   103  			req := safehttptest.NewRequest("GET", "https://spaghetti.com/carbonara", nil)
   104  			req.Header.Add("Sec-Fetch-Site", test.site)
   105  			req.Header.Add("Sec-Fetch-Mode", test.mode)
   106  			req.Header.Add("Sec-Fetch-Dest", test.dest)
   107  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
   108  
   109  			p := fetchmetadata.FramingIsolationPolicy()
   110  			p.Before(fakeRW, req, nil)
   111  
   112  			if want, got := safehttp.StatusForbidden, safehttp.StatusCode(rr.Code); want != got {
   113  				t.Errorf("rr.Code got: %v want: %v", got, want)
   114  			}
   115  			if diff := cmp.Diff(http.Header{}, rr.Header()); diff != "" {
   116  				t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   117  			}
   118  		})
   119  	}
   120  }
   121  
   122  func TestDisableFramingIsolationPolicy(t *testing.T) {
   123  	type reportTests struct {
   124  		name, site, mode, dest, block string
   125  	}
   126  	var tests []reportTests
   127  	for _, t := range allowedFIPHeaders {
   128  		tests = append(tests, reportTests{
   129  			name:  t.name,
   130  			site:  t.site,
   131  			mode:  t.mode,
   132  			dest:  t.dest,
   133  			block: "false",
   134  		})
   135  	}
   136  	for _, t := range disallowedFIPHeaders {
   137  		tests = append(tests, reportTests{
   138  			name:  t.name,
   139  			site:  t.site,
   140  			mode:  t.mode,
   141  			dest:  t.dest,
   142  			block: "true",
   143  		})
   144  	}
   145  	overrides := []struct {
   146  		name  string
   147  		value safehttp.InterceptorConfig
   148  	}{
   149  		{"disable", internalunsafeframing.Disable{SkipReports: true}},
   150  		{"allowlist", internalunsafeframing.AllowList{}},
   151  	}
   152  	for _, override := range overrides {
   153  		for _, test := range tests {
   154  			t.Run(test.name+" "+override.name, func(t *testing.T) {
   155  				req := safehttptest.NewRequest("GET", "https://spaghetti.com/carbonara", nil)
   156  				req.Header.Add("Sec-Fetch-Site", test.site)
   157  				req.Header.Add("Sec-Fetch-Mode", test.mode)
   158  				req.Header.Add("Sec-Fetch-Dest", test.dest)
   159  				fakeRW, rr := safehttptest.NewFakeResponseWriter()
   160  
   161  				p := fetchmetadata.FramingIsolationPolicy()
   162  				p.Before(fakeRW, req, override.value)
   163  
   164  				if want, got := safehttp.StatusOK, safehttp.StatusCode(rr.Code); want != got {
   165  					t.Errorf("rr.Code got: %v want: %v", got, want)
   166  				}
   167  				if diff := cmp.Diff(http.Header{}, rr.Header()); diff != "" {
   168  					t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   169  				}
   170  				if want, got := "", rr.Body.String(); got != want {
   171  					t.Errorf("rr.Body.String() got: %q want: %q", got, want)
   172  				}
   173  			})
   174  		}
   175  	}
   176  }