github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/internal/requesttesting/headers/basicauth_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package headers
    16  
    17  import (
    18  	"context"
    19  	"net/http"
    20  	"testing"
    21  
    22  	"github.com/google/go-cmp/cmp"
    23  
    24  	"github.com/google/go-safeweb/internal/requesttesting"
    25  )
    26  
    27  func TestBasicAuth(t *testing.T) {
    28  	type basicAuth struct {
    29  		username string
    30  		password string
    31  		ok       bool
    32  	}
    33  
    34  	var tests = []struct {
    35  		name          string
    36  		request       []byte
    37  		wantBasicAuth basicAuth
    38  		wantHeaders   map[string][]string
    39  	}{
    40  		{
    41  			name: "Basic",
    42  			request: []byte("GET / HTTP/1.1\r\n" +
    43  				"Host: localhost:8080\r\n" +
    44  				// Base64 encoding of "Pizza:Password".
    45  				"Authorization: Basic UGl6emE6UGFzc3dvcmQ=\r\n" +
    46  				"\r\n"),
    47  			wantBasicAuth: basicAuth{
    48  				username: "Pizza",
    49  				password: "Password",
    50  				ok:       true,
    51  			},
    52  			// Same Base64 as above.
    53  			wantHeaders: map[string][]string{"Authorization": {"Basic UGl6emE6UGFzc3dvcmQ="}},
    54  		},
    55  		{
    56  			name: "NoTrailingEquals",
    57  			request: []byte("GET / HTTP/1.1\r\n" +
    58  				"Host: localhost:8080\r\n" +
    59  				// Base64 encoding of "Pizza:Password" without trailing equals.
    60  				"Authorization: Basic UGl6emE6UGFzc3dvcmQ\r\n" +
    61  				"\r\n"),
    62  			wantBasicAuth: basicAuth{
    63  				username: "",
    64  				password: "",
    65  				ok:       false,
    66  			},
    67  			// Same Base64 as above.
    68  			wantHeaders: map[string][]string{"Authorization": {"Basic UGl6emE6UGFzc3dvcmQ"}},
    69  		},
    70  		{
    71  			name: "DoubleColon",
    72  			request: []byte("GET / HTTP/1.1\r\n" +
    73  				"Host: localhost:8080\r\n" +
    74  				// Base64 encoding of "Pizza:Password:Password".
    75  				"Authorization: Basic UGl6emE6UGFzc3dvcmQ6UGFzc3dvcmQ=\r\n" +
    76  				"\r\n"),
    77  			wantBasicAuth: basicAuth{
    78  				username: "Pizza",
    79  				password: "Password:Password",
    80  				ok:       true,
    81  			},
    82  			// Same Base64 as above.
    83  			wantHeaders: map[string][]string{"Authorization": {"Basic UGl6emE6UGFzc3dvcmQ6UGFzc3dvcmQ="}},
    84  		},
    85  		{
    86  			name: "NotBasic",
    87  			request: []byte("GET / HTTP/1.1\r\n" +
    88  				"Host: localhost:8080\r\n" +
    89  				// Base64 encoding of "Pizza:Password:Password".
    90  				"Authorization: xasic UGl6emE6UGFzc3dvcmQ6UGFzc3dvcmQ=\r\n" +
    91  				"\r\n"),
    92  			wantBasicAuth: basicAuth{
    93  				username: "",
    94  				password: "",
    95  				ok:       false,
    96  			},
    97  			// Same Base64 as above.
    98  			wantHeaders: map[string][]string{"Authorization": {"xasic UGl6emE6UGFzc3dvcmQ6UGFzc3dvcmQ="}},
    99  		},
   100  		{
   101  			name: "CasingOrdering1",
   102  			request: []byte("GET / HTTP/1.1\r\n" +
   103  				"Host: localhost:8080\r\n" +
   104  				// Base64 encoding of "AAA:aaa".
   105  				"Authorization: basic QUFBOmFhYQ==\r\n" +
   106  				// Base64 encoding of "BBB:bbb".
   107  				"authorization: basic QkJCOmJiYg==\r\n" +
   108  				"\r\n"),
   109  			wantBasicAuth: basicAuth{
   110  				username: "AAA",
   111  				password: "aaa",
   112  				ok:       true,
   113  			},
   114  			// Base64 encoding of "AAA:aaa" and then of "BBB:bbb" in that order.
   115  			wantHeaders: map[string][]string{"Authorization": {"basic QUFBOmFhYQ==", "basic QkJCOmJiYg=="}},
   116  		},
   117  		{
   118  			name: "CasingOrdering2",
   119  			request: []byte("GET / HTTP/1.1\r\n" +
   120  				"Host: localhost:8080\r\n" +
   121  				// Base64 encoding of "AAA:aaa".
   122  				"authorization: basic QUFBOmFhYQ==\r\n" +
   123  				// Base64 encoding of "BBB:bbb".
   124  				"Authorization: basic QkJCOmJiYg==\r\n" +
   125  				"\r\n"),
   126  			wantBasicAuth: basicAuth{
   127  				username: "AAA",
   128  				password: "aaa",
   129  				ok:       true,
   130  			},
   131  			// Base64 encoding of "AAA:aaa" and then of "BBB:bbb" in that order.
   132  			wantHeaders: map[string][]string{"Authorization": {"basic QUFBOmFhYQ==", "basic QkJCOmJiYg=="}},
   133  		},
   134  	}
   135  
   136  	for _, tt := range tests {
   137  		t.Run(tt.name, func(t *testing.T) {
   138  			resp, err := requesttesting.MakeRequest(context.Background(), tt.request, func(r *http.Request) {
   139  				if diff := cmp.Diff(tt.wantHeaders, map[string][]string(r.Header)); diff != "" {
   140  					t.Errorf("r.Header mismatch (-want +got):\n%s", diff)
   141  				}
   142  
   143  				username, password, ok := r.BasicAuth()
   144  				if ok != tt.wantBasicAuth.ok {
   145  					t.Errorf("_, _, ok := r.BasicAuth() got: %v want: %v", ok, tt.wantBasicAuth.ok)
   146  				}
   147  
   148  				if username != tt.wantBasicAuth.username {
   149  					t.Errorf("username, _, _ := r.BasicAuth() got: %q want: %q", username, tt.wantBasicAuth.username)
   150  				}
   151  
   152  				if password != tt.wantBasicAuth.password {
   153  					t.Errorf("_, password, _ := r.BasicAuth() got: %q want: %q", password, tt.wantBasicAuth.password)
   154  				}
   155  			})
   156  			if err != nil {
   157  				t.Fatalf("MakeRequest() got err: %v", err)
   158  			}
   159  
   160  			if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) {
   161  				t.Errorf("status code got: %q want: %q", got, want)
   162  			}
   163  		})
   164  	}
   165  }
   166  
   167  func TestBasicAuthOrdering(t *testing.T) {
   168  	// The documentation of http.Request.BasicAuth() doesn't clearly specify
   169  	// that only the first Authorization header is used and the other ones
   170  	// are ignored. This could potentially lead to security issues if two
   171  	// HTTP servers that look at different headers are chained together.
   172  	//
   173  	// The desired behavior would be to respond with 400 (Bad Request) when
   174  	// there is more than one Authorization header.
   175  
   176  	request := []byte("GET / HTTP/1.1\r\n" +
   177  		"Host: localhost:8080\r\n" +
   178  		// Base64 encoding of "AAA:aaa".
   179  		"Authorization: basic QUFBOmFhYQ==\r\n" +
   180  		// Base64 encoding of "BBB:bbb".
   181  		"Authorization: basic QkJCOmJiYg==\r\n" +
   182  		"\r\n")
   183  
   184  	t.Run("Current behavior", func(t *testing.T) {
   185  		resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) {
   186  			// Base64 encoding of "AAA:aaa" and then of "BBB:bbb" in that order.
   187  			wantHeaders := map[string][]string{"Authorization": {"basic QUFBOmFhYQ==", "basic QkJCOmJiYg=="}}
   188  			if diff := cmp.Diff(wantHeaders, map[string][]string(r.Header)); diff != "" {
   189  				t.Errorf("r.Header mismatch (-want +got):\n%s", diff)
   190  			}
   191  
   192  			username, password, ok := r.BasicAuth()
   193  			if want := true; ok != want {
   194  				t.Errorf("_, _, ok := r.BasicAuth() got: %v want: %v", ok, want)
   195  			}
   196  
   197  			if want := "AAA"; username != want {
   198  				t.Errorf("username, _, _ := r.BasicAuth() got: %q want: %q", username, want)
   199  			}
   200  
   201  			if want := "aaa"; password != want {
   202  				t.Errorf("_, password, _ := r.BasicAuth() got: %q want: %q", password, want)
   203  			}
   204  		})
   205  		if err != nil {
   206  			t.Fatalf("MakeRequest() got err: %v want: nil", err)
   207  		}
   208  
   209  		if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) {
   210  			t.Errorf("status code got: %q want: %q", got, want)
   211  		}
   212  	})
   213  
   214  	t.Run("Desired behavior", func(t *testing.T) {
   215  		t.Skip()
   216  		resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) {
   217  			t.Error("Expected handler to not be called!")
   218  		})
   219  		if err != nil {
   220  			t.Fatalf("MakeRequest() got err: %v want: nil", err)
   221  		}
   222  
   223  		if got, want := extractStatus(resp), statusBadRequestPrefix; !matchStatus(got, want) {
   224  			t.Errorf("status code got: %q want prefix: %q", got, want)
   225  		}
   226  	})
   227  }