github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/internal/requesttesting/headers/header_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package headers
    16  
    17  import (
    18  	"context"
    19  	"net/http"
    20  	"testing"
    21  
    22  	"github.com/google/go-cmp/cmp"
    23  
    24  	"github.com/google/go-safeweb/internal/requesttesting"
    25  )
    26  
    27  func TestHeaderParsing(t *testing.T) {
    28  	var tests = []struct {
    29  		name    string
    30  		request []byte
    31  		want    map[string][]string
    32  	}{
    33  		{
    34  			name: "Basic",
    35  			request: []byte("GET / HTTP/1.1\r\n" +
    36  				"Host: localhost:8080\r\n" +
    37  				"A: B\r\n" +
    38  				"\r\n"),
    39  			want: map[string][]string{"A": {"B"}},
    40  		},
    41  		{
    42  			name: "Ordering",
    43  			request: []byte("GET / HTTP/1.1\r\n" +
    44  				"Host: localhost:8080\r\n" +
    45  				"A: X\r\n" +
    46  				"A: Y\r\n" +
    47  				"\r\n"),
    48  			want: map[string][]string{"A": {"X", "Y"}},
    49  		},
    50  		{
    51  			name: "Casing",
    52  			request: []byte("GET / HTTP/1.1\r\n" +
    53  				"Host: localhost:8080\r\n" +
    54  				"BaBaBaBa-BaBaBaBa-BaBaBa: xXxXxXxX\r\n" +
    55  				"cDcDcDcD-cDcDcDcD-cDcDcD: YyYyYyYy\r\n" +
    56  				"\r\n"),
    57  			want: map[string][]string{
    58  				"Babababa-Babababa-Bababa": {"xXxXxXxX"},
    59  				"Cdcdcdcd-Cdcdcdcd-Cdcdcd": {"YyYyYyYy"},
    60  			},
    61  		},
    62  		{
    63  			name: "CasingOrdering",
    64  			request: []byte("GET / HTTP/1.1\r\n" +
    65  				"Host: localhost:8080\r\n" +
    66  				"a: X\r\n" +
    67  				"A: Y\r\n" +
    68  				"a: Z\r\n" +
    69  				"A: W\r\n" +
    70  				"\r\n"),
    71  			want: map[string][]string{"A": {"X", "Y", "Z", "W"}},
    72  		},
    73  		{
    74  			name: "MultiLineHeaderTab",
    75  			request: []byte("GET / HTTP/1.1\r\n" +
    76  				"Host: localhost:8080\r\n" +
    77  				"AAAA: aaaa aaa\r\n" +
    78  				"\taaa aaa\r\n" +
    79  				"\r\n"),
    80  			want: map[string][]string{"Aaaa": {"aaaa aaa aaa aaa"}},
    81  		},
    82  		{
    83  			name: "MultiLineHeaderManyLines",
    84  			request: []byte("GET / HTTP/1.1\r\n" +
    85  				"Host: localhost:8080\r\n" +
    86  				"AAAA: aaaa aaa\r\n" +
    87  				" aaa\r\n" +
    88  				" aaa\r\n" +
    89  				"\r\n"),
    90  			want: map[string][]string{"Aaaa": {"aaaa aaa aaa aaa"}},
    91  		},
    92  		{
    93  			name: "UnicodeValue",
    94  			request: []byte("GET / HTTP/1.1\r\n" +
    95  				"Host: localhost:8080\r\n" +
    96  				"A: \xf0\x9f\xa5\xb3\r\n" +
    97  				"\r\n"),
    98  			want: map[string][]string{"A": {"\xf0\x9f\xa5\xb3"}},
    99  		},
   100  	}
   101  
   102  	for _, tt := range tests {
   103  		t.Run(tt.name, func(t *testing.T) {
   104  			resp, err := requesttesting.MakeRequest(context.Background(), tt.request, func(r *http.Request) {
   105  				if diff := cmp.Diff(tt.want, map[string][]string(r.Header)); diff != "" {
   106  					t.Errorf("r.Header mismatch (-want +got):\n%s", diff)
   107  				}
   108  			})
   109  			if err != nil {
   110  				t.Fatalf("MakeRequest() got err: %v", err)
   111  			}
   112  
   113  			if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) {
   114  				t.Errorf("status code got: %q want: %q", got, want)
   115  			}
   116  		})
   117  	}
   118  }
   119  
   120  func TestStatusCode(t *testing.T) {
   121  	var tests = []struct {
   122  		name    string
   123  		request []byte
   124  		want    string
   125  	}{
   126  		{
   127  			name: "MultilineContinuationOnFirstLine",
   128  			request: []byte("GET / HTTP/1.1\r\n" +
   129  				" A: a\r\n" +
   130  				"Host: localhost:8080\r\n" +
   131  				"\r\n"),
   132  			want: statusBadRequestPrefix,
   133  		},
   134  		{
   135  			name: "MultilineHeaderName",
   136  			request: []byte("GET / HTTP/1.1\r\n" +
   137  				"Host: localhost:8080\r\n" +
   138  				"AA\r\n" +
   139  				" AA: aaaa\r\n" +
   140  				"\r\n"),
   141  			want: statusBadRequestPrefix,
   142  		},
   143  		{
   144  			name: "MultilineHeaderBeforeColon",
   145  			request: []byte("GET / HTTP/1.1\r\n" +
   146  				"Host: localhost:8080\r\n" +
   147  				"A\r\n" +
   148  				" : a\r\n" +
   149  				"\r\n"),
   150  			want: statusBadRequestPrefix,
   151  		},
   152  		{
   153  			name: "UnicodeHeaderName",
   154  			request: []byte("GET / HTTP/1.1\r\n" +
   155  				"Host: localhost:8080\r\n" +
   156  				"\xf0\x9f\xa5\xb3: a\r\n" +
   157  				"\r\n"),
   158  			want: statusBadRequestPrefix,
   159  		},
   160  		{
   161  			name: "SpecialCharactersHeaderName",
   162  			request: []byte("GET / HTTP/1.1\r\n" +
   163  				"Host: localhost:8080\r\n" +
   164  				"&%*(#@%()): a\r\n" +
   165  				"\r\n"),
   166  			want: statusBadRequestPrefix,
   167  		},
   168  		{
   169  			name: "NullByteInHeaderName",
   170  			request: []byte("GET / HTTP/1.1\r\n" +
   171  				"Host: localhost:8080\r\n" +
   172  				"A\x00A: a\r\n" +
   173  				"\r\n"),
   174  			want: statusBadRequestPrefix,
   175  		},
   176  		{
   177  			name: "NullByteInHeaderValue",
   178  			request: []byte("GET / HTTP/1.1\r\n" +
   179  				"Host: localhost:8080\r\n" +
   180  				"AA: a\x00a\r\n" +
   181  				"\r\n"),
   182  			want: statusBadRequestPrefix,
   183  		},
   184  	}
   185  
   186  	for _, tt := range tests {
   187  		t.Run(tt.name, func(t *testing.T) {
   188  			resp, err := requesttesting.MakeRequest(context.Background(), tt.request, nil)
   189  			if err != nil {
   190  				t.Fatalf("MakeRequest() got err: %v", err)
   191  			}
   192  
   193  			if got := extractStatus(resp); !matchStatus(got, tt.want) {
   194  				t.Errorf("status code got: %q want: %q", got, tt.want)
   195  			}
   196  		})
   197  	}
   198  }
   199  
   200  // TestValues verifies that the http.Header.Values() function
   201  // returns the header values in the order that they are sent
   202  // in the request to the server.
   203  func TestValues(t *testing.T) {
   204  	var tests = []struct {
   205  		name    string
   206  		request []byte
   207  		want    []string
   208  	}{
   209  		{
   210  			name: "Ordering1",
   211  			request: []byte("GET / HTTP/1.1\r\n" +
   212  				"Host: localhost:8080\r\n" +
   213  				"A: X\r\n" +
   214  				"A: Y\r\n" +
   215  				"\r\n"),
   216  			want: []string{"X", "Y"},
   217  		},
   218  		{
   219  			name: "Ordering2",
   220  			request: []byte("GET / HTTP/1.1\r\n" +
   221  				"Host: localhost:8080\r\n" +
   222  				"A: Y\r\n" +
   223  				"A: X\r\n" +
   224  				"\r\n"),
   225  			want: []string{"Y", "X"},
   226  		},
   227  	}
   228  
   229  	for _, tt := range tests {
   230  		t.Run(tt.name, func(t *testing.T) {
   231  			resp, err := requesttesting.MakeRequest(context.Background(), tt.request, func(r *http.Request) {
   232  				if diff := cmp.Diff(tt.want, r.Header.Values("A")); diff != "" {
   233  					t.Errorf(`r.Header.Values("A") mismatch (-want +got):\n%s`, diff)
   234  				}
   235  			})
   236  			if err != nil {
   237  				t.Fatalf("MakeRequest() got err: %v want: nil", err)
   238  			}
   239  
   240  			if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) {
   241  				t.Errorf("status code got: %q want: %q", got, want)
   242  			}
   243  		})
   244  	}
   245  }
   246  
   247  func TestMultiLineHeader(t *testing.T) {
   248  	// Multiline continuation has been deprecated as of RFC 7230.
   249  	// " Historically, HTTP header field values could be extended over
   250  	//   multiple lines by preceding each extra line with at least one space
   251  	//   or horizontal tab (obs-fold).  This specification deprecates such
   252  	//   line folding [...]
   253  	//    A server that receives an obs-fold in a request message that is not
   254  	//   within a message/http container MUST either reject the message by
   255  	//   sending a 400 (Bad Request), preferably with a representation
   256  	//   explaining that obsolete line folding is unacceptable, or replace
   257  	//   each received obs-fold with one or more SP octets prior to
   258  	//   interpreting the field value or forwarding the message downstream. "
   259  	// - RFC 7230 Section 3.2.4
   260  	//
   261  	// Currently obs-folds are replaced with spaces before the value of the
   262  	// header is interpreted. This is in line with the RFC. But it would
   263  	// be more robust and future proof to drop the support of multiline
   264  	// continuation entirely and instead respond with a 400 (Bad Request)
   265  	// like the RFC also suggests.
   266  
   267  	request := []byte("GET / HTTP/1.1\r\n" +
   268  		"Host: localhost:8080\r\n" +
   269  		"AAAA: aaaa aaa\r\n" +
   270  		" aaa aaa\r\n" +
   271  		"\r\n")
   272  
   273  	t.Run("Current behavior", func(t *testing.T) {
   274  		resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) {
   275  			want := map[string][]string{"Aaaa": {"aaaa aaa aaa aaa"}}
   276  			if diff := cmp.Diff(want, map[string][]string(r.Header)); diff != "" {
   277  				t.Errorf("r.Header mismatch (-want +got):\n%s", diff)
   278  			}
   279  		})
   280  		if err != nil {
   281  			t.Fatalf("MakeRequest() got err: %v want: nil", err)
   282  		}
   283  
   284  		if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) {
   285  			t.Errorf("status code got: %q want: %q", got, want)
   286  		}
   287  	})
   288  
   289  	t.Run("Desired behavior", func(t *testing.T) {
   290  		t.Skip()
   291  		resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) {
   292  			t.Error("Expected handler to not be called!")
   293  		})
   294  		if err != nil {
   295  			t.Fatalf("MakeRequest() got err: %v want: nil", err)
   296  		}
   297  
   298  		if got, want := extractStatus(resp), statusBadRequestPrefix; !matchStatus(got, want) {
   299  			t.Errorf("status code got: %q want: %q", got, want)
   300  		}
   301  	})
   302  }