github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/internal/requesttesting/headers/useragent_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package headers
    16  
    17  import (
    18  	"context"
    19  	"net/http"
    20  	"testing"
    21  
    22  	"github.com/google/go-cmp/cmp"
    23  	"github.com/google/go-safeweb/internal/requesttesting"
    24  )
    25  
    26  func TestUserAgent(t *testing.T) {
    27  	type testWant struct {
    28  		headers   map[string][]string
    29  		useragent string
    30  	}
    31  
    32  	var tests = []struct {
    33  		name    string
    34  		request []byte
    35  		want    testWant
    36  	}{
    37  		{
    38  			name: "Basic",
    39  			request: []byte("GET / HTTP/1.1\r\n" +
    40  				"Host: localhost:8080\r\n" +
    41  				"User-Agent: BlahBlah\r\n" +
    42  				"\r\n"),
    43  			want: testWant{
    44  				headers:   map[string][]string{"User-Agent": {"BlahBlah"}},
    45  				useragent: "BlahBlah",
    46  			},
    47  		},
    48  		{
    49  			name: "CasingOrdering1",
    50  			request: []byte("GET / HTTP/1.1\r\n" +
    51  				"Host: localhost:8080\r\n" +
    52  				"user-Agent: BlahBlah\r\n" +
    53  				"User-Agent: FooFoo\r\n" +
    54  				"\r\n"),
    55  			want: testWant{
    56  				headers:   map[string][]string{"User-Agent": {"BlahBlah", "FooFoo"}},
    57  				useragent: "BlahBlah",
    58  			},
    59  		},
    60  		{
    61  			name: "CasingOrdering1",
    62  			request: []byte("GET / HTTP/1.1\r\n" +
    63  				"Host: localhost:8080\r\n" +
    64  				"User-Agent: BlahBlah\r\n" +
    65  				"user-Agent: FooFoo\r\n" +
    66  				"\r\n"),
    67  			want: testWant{
    68  				headers:   map[string][]string{"User-Agent": {"BlahBlah", "FooFoo"}},
    69  				useragent: "BlahBlah",
    70  			},
    71  		},
    72  	}
    73  
    74  	for _, tt := range tests {
    75  		t.Run(tt.name, func(t *testing.T) {
    76  			resp, err := requesttesting.MakeRequest(context.Background(), tt.request, func(r *http.Request) {
    77  				if diff := cmp.Diff(tt.want.headers, map[string][]string(r.Header)); diff != "" {
    78  					t.Errorf("r.Header mismatch (-want +got):\n%s", diff)
    79  				}
    80  
    81  				if r.UserAgent() != tt.want.useragent {
    82  					t.Errorf("r.UserAgent() got: %q want: %q", r.UserAgent(), tt.want.useragent)
    83  				}
    84  			})
    85  			if err != nil {
    86  				t.Fatalf("MakeRequest() got err: %v", err)
    87  			}
    88  
    89  			if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) {
    90  				t.Errorf("status code got: %q want: %q", got, want)
    91  			}
    92  		})
    93  	}
    94  }
    95  
    96  func TestUserAgentOrdering(t *testing.T) {
    97  	// The documentation of http.Request.UserAgent() doesn't clearly specify
    98  	// that only the first User-Agent header is used and that the other ones
    99  	// are ignored. This could potentially lead to security issues if two
   100  	// HTTP servers that look at different headers are chained together.
   101  	//
   102  	// The desired behavior would be to respond with 400 (Bad Request)
   103  	// when there is more than one User-Agent header.
   104  
   105  	request := []byte("GET / HTTP/1.1\r\n" +
   106  		"Host: localhost:8080\r\n" +
   107  		"User-Agent: BlahBlah\r\n" +
   108  		"User-Agent: FooFoo\r\n" +
   109  		"\r\n")
   110  
   111  	t.Run("Current behavior", func(t *testing.T) {
   112  		resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) {
   113  			wantHeaders := map[string][]string{"User-Agent": {"BlahBlah", "FooFoo"}}
   114  			if diff := cmp.Diff(wantHeaders, map[string][]string(r.Header)); diff != "" {
   115  				t.Errorf("r.Header mismatch (-want +got):\n%s", diff)
   116  			}
   117  
   118  			if want := "BlahBlah"; r.UserAgent() != want {
   119  				t.Errorf("r.UserAgent() got: %q want: %q", r.UserAgent(), want)
   120  			}
   121  		})
   122  		if err != nil {
   123  			t.Fatalf("MakeRequest() got err: %v want: nil", err)
   124  		}
   125  
   126  		if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) {
   127  			t.Errorf("status code got: %q want: %q", got, want)
   128  		}
   129  	})
   130  
   131  	t.Run("Desired behavior", func(t *testing.T) {
   132  		t.Skip()
   133  		resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) {
   134  			t.Error("Expected handler to not be called!")
   135  		})
   136  		if err != nil {
   137  			t.Fatalf("MakeRequest() got err: %v want: nil", err)
   138  		}
   139  
   140  		if got, want := extractStatus(resp), statusBadRequestPrefix; !matchStatus(got, want) {
   141  			t.Errorf("status code got: %q want: %q", got, want)
   142  		}
   143  	})
   144  }