github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/internal/requesttesting/headers/referer_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package headers
    16  
    17  import (
    18  	"context"
    19  	"net/http"
    20  	"testing"
    21  
    22  	"github.com/google/go-cmp/cmp"
    23  
    24  	"github.com/google/go-safeweb/internal/requesttesting"
    25  )
    26  
    27  func TestReferer(t *testing.T) {
    28  	type testWant struct {
    29  		headers map[string][]string
    30  		referer string
    31  	}
    32  
    33  	var tests = []struct {
    34  		name    string
    35  		request []byte
    36  		want    testWant
    37  	}{
    38  		{
    39  			name: "Basic",
    40  			request: []byte("GET / HTTP/1.1\r\n" +
    41  				"Host: localhost:8080\r\n" +
    42  				"Referer: http://example.com\r\n" +
    43  				"\r\n"),
    44  			want: testWant{
    45  				headers: map[string][]string{"Referer": {"http://example.com"}},
    46  				referer: "http://example.com",
    47  			},
    48  		},
    49  		{
    50  			name: "CasingOrdering1",
    51  			request: []byte("GET / HTTP/1.1\r\n" +
    52  				"Host: localhost:8080\r\n" +
    53  				"referer: http://example.com\r\n" +
    54  				"Referer: http://evil.com\r\n" +
    55  				"\r\n"),
    56  			want: testWant{
    57  				headers: map[string][]string{"Referer": {"http://example.com", "http://evil.com"}},
    58  				referer: "http://example.com",
    59  			},
    60  		},
    61  		{
    62  			name: "CasingOrdering2",
    63  			request: []byte("GET / HTTP/1.1\r\n" +
    64  				"Host: localhost:8080\r\n" +
    65  				"Referer: http://example.com\r\n" +
    66  				"referer: http://evil.com\r\n" +
    67  				"\r\n"),
    68  			want: testWant{
    69  				headers: map[string][]string{"Referer": {"http://example.com", "http://evil.com"}},
    70  				referer: "http://example.com",
    71  			},
    72  		},
    73  	}
    74  
    75  	for _, tt := range tests {
    76  		t.Run(tt.name, func(t *testing.T) {
    77  			resp, err := requesttesting.MakeRequest(context.Background(), tt.request, func(r *http.Request) {
    78  				if diff := cmp.Diff(tt.want.headers, map[string][]string(r.Header)); diff != "" {
    79  					t.Errorf("r.Header mismatch (-want +got):\n%s", diff)
    80  				}
    81  
    82  				if r.Referer() != tt.want.referer {
    83  					t.Errorf("r.Referer() got: %q want: %q", r.Referer(), tt.want.referer)
    84  				}
    85  			})
    86  			if err != nil {
    87  				t.Fatalf("MakeRequest() got err: %v", err)
    88  			}
    89  
    90  			if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) {
    91  				t.Errorf("status code got: %q want: %q", got, want)
    92  			}
    93  		})
    94  	}
    95  }
    96  
    97  func TestRefererOrdering(t *testing.T) {
    98  	// The documentation of http.Request.Referer() doesn't clearly specify
    99  	// that only the first Referer header is used and that the other ones
   100  	// are ignored. This could potentially lead to security issues if two
   101  	// HTTP servers that look at different headers are chained together.
   102  	//
   103  	// The desired behavior would be to respond with 400 (Bad Request)
   104  	// when there is more than one Referer header.
   105  
   106  	request := []byte("GET / HTTP/1.1\r\n" +
   107  		"Host: localhost:8080\r\n" +
   108  		"Referer: http://example.com\r\n" +
   109  		"Referer: http://evil.com\r\n" +
   110  		"\r\n")
   111  
   112  	t.Run("Current behavior", func(t *testing.T) {
   113  		resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) {
   114  			wantHeaders := map[string][]string{"Referer": {"http://example.com", "http://evil.com"}}
   115  			if diff := cmp.Diff(wantHeaders, map[string][]string(r.Header)); diff != "" {
   116  				t.Errorf("r.Header mismatch (-want +got):\n%s", diff)
   117  			}
   118  
   119  			if want := "http://example.com"; r.Referer() != want {
   120  				t.Errorf("r.Referer() got: %q want: %q", r.Referer(), want)
   121  			}
   122  		})
   123  		if err != nil {
   124  			t.Fatalf("MakeRequest() got err: %v want: nil", err)
   125  		}
   126  
   127  		if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) {
   128  			t.Errorf("status code got: %q want: %q", got, want)
   129  		}
   130  	})
   131  
   132  	t.Run("Desired behavior", func(t *testing.T) {
   133  		t.Skip()
   134  		resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) {
   135  			t.Error("Expected handler to not be called!")
   136  		})
   137  		if err != nil {
   138  			t.Fatalf("MakeRequest() got err: %v want: nil", err)
   139  		}
   140  
   141  		if got, want := extractStatus(resp), statusBadRequestPrefix; !matchStatus(got, want) {
   142  			t.Errorf("status code got: %q want: %q", got, want)
   143  		}
   144  	})
   145  }