github.com/google/martian/v3@v3.3.3/proxyutil/header_test.go (about)

     1  // Copyright 2015 Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxyutil
    16  
    17  import (
    18  	"net/http"
    19  	"reflect"
    20  	"testing"
    21  )
    22  
    23  func TestRequestHeader(t *testing.T) {
    24  	req, err := http.NewRequest("GET", "http://example.com", nil)
    25  	if err != nil {
    26  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
    27  	}
    28  
    29  	h := RequestHeader(req)
    30  
    31  	tt := []struct {
    32  		name  string
    33  		value string
    34  	}{
    35  		{
    36  			name:  "Host",
    37  			value: "example.com",
    38  		},
    39  		{
    40  			name:  "Test-Header",
    41  			value: "true",
    42  		},
    43  		{
    44  			name:  "Content-Length",
    45  			value: "100",
    46  		},
    47  		{
    48  			name:  "Transfer-Encoding",
    49  			value: "chunked",
    50  		},
    51  	}
    52  
    53  	for i, tc := range tt {
    54  		if err := h.Set(tc.name, tc.value); err != nil {
    55  			t.Errorf("%d. h.Set(%q, %q): got %v, want no error", i, tc.name, tc.value, err)
    56  		}
    57  	}
    58  
    59  	if got, want := req.Host, "example.com"; got != want {
    60  		t.Errorf("req.Host: got %q, want %q", got, want)
    61  	}
    62  	if got, want := req.Header.Get("Test-Header"), "true"; got != want {
    63  		t.Errorf("req.Header.Get(%q): got %q, want %q", "Test-Header", got, want)
    64  	}
    65  	if got, want := req.ContentLength, int64(100); got != want {
    66  		t.Errorf("req.ContentLength: got %d, want %d", got, want)
    67  	}
    68  	if got, want := req.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(got, want) {
    69  		t.Errorf("req.TransferEncoding: got %v, want %v", got, want)
    70  	}
    71  
    72  	if got, want := len(h.Map()), 4; got != want {
    73  		t.Errorf("h.Map(): got %d entries, want %d entries", got, want)
    74  	}
    75  
    76  	for n, vs := range h.Map() {
    77  		var want string
    78  		switch n {
    79  		case "Host":
    80  			want = "example.com"
    81  		case "Content-Length":
    82  			want = "100"
    83  		case "Transfer-Encoding":
    84  			want = "chunked"
    85  		case "Test-Header":
    86  			want = "true"
    87  		default:
    88  			t.Errorf("h.Map(): got unexpected %s header", n)
    89  		}
    90  
    91  		if got := vs[0]; got != want {
    92  			t.Errorf("h.Map(): got %s header with value %s, want value %s", n, got, want)
    93  		}
    94  	}
    95  
    96  	for i, tc := range tt {
    97  		got, ok := h.All(tc.name)
    98  		if !ok {
    99  			t.Errorf("%d. h.All(%q): got false, want true", i, tc.name)
   100  		}
   101  
   102  		if want := []string{tc.value}; !reflect.DeepEqual(got, want) {
   103  			t.Errorf("%d. h.All(%q): got %v, want %v", i, tc.name, got, want)
   104  		}
   105  
   106  		if got, want := h.Get(tc.name), tc.value; got != want {
   107  			t.Errorf("%d. h.Get(%q): got %q, want %q", i, tc.name, got, want)
   108  		}
   109  
   110  		h.Del(tc.name)
   111  	}
   112  
   113  	if got, want := req.Host, ""; got != want {
   114  		t.Errorf("req.Host: got %q, want %q", got, want)
   115  	}
   116  	if got, want := req.Header.Get("Test-Header"), ""; got != want {
   117  		t.Errorf("req.Header.Get(%q): got %q, want %q", "Test-Header", got, want)
   118  	}
   119  	if got, want := req.ContentLength, int64(-1); got != want {
   120  		t.Errorf("req.ContentLength: got %d, want %d", got, want)
   121  	}
   122  	if got := req.TransferEncoding; got != nil {
   123  		t.Errorf("req.TransferEncoding: got %v, want nil", got)
   124  	}
   125  
   126  	for i, tc := range tt {
   127  		if got, want := h.Get(tc.name), ""; got != want {
   128  			t.Errorf("%d. h.Get(%q): got %q, want %q", i, tc.name, got, want)
   129  		}
   130  
   131  		got, ok := h.All(tc.name)
   132  		if ok {
   133  			t.Errorf("%d. h.All(%q): got ok, want !ok", i, tc.name)
   134  		}
   135  		if got != nil {
   136  			t.Errorf("%d. h.All(%q): got %v, want nil", i, tc.name, got)
   137  		}
   138  	}
   139  }
   140  
   141  func TestRequestHeaderAdd(t *testing.T) {
   142  	req, err := http.NewRequest("GET", "http://example.com", nil)
   143  	if err != nil {
   144  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   145  	}
   146  	req.Host = "" // Set to empty so add may overwrite.
   147  
   148  	h := RequestHeader(req)
   149  
   150  	tt := []struct {
   151  		name             string
   152  		values           []string
   153  		errOnSecondValue bool
   154  	}{
   155  		{
   156  			name:             "Host",
   157  			values:           []string{"example.com", "invalid.com"},
   158  			errOnSecondValue: true,
   159  		},
   160  		{
   161  			name:   "Test-Header",
   162  			values: []string{"first", "second"},
   163  		},
   164  		{
   165  			name:             "Content-Length",
   166  			values:           []string{"100", "101"},
   167  			errOnSecondValue: true,
   168  		},
   169  		{
   170  			name:   "Transfer-Encoding",
   171  			values: []string{"chunked", "gzip"},
   172  		},
   173  	}
   174  
   175  	for i, tc := range tt {
   176  		if err := h.Add(tc.name, tc.values[0]); err != nil {
   177  			t.Errorf("%d. h.Add(%q, %q): got %v, want no error", i, tc.name, tc.values[0], err)
   178  		}
   179  		if err := h.Add(tc.name, tc.values[1]); err != nil && !tc.errOnSecondValue {
   180  			t.Errorf("%d. h.Add(%q, %q): got %v, want no error", i, tc.name, tc.values[1], err)
   181  		}
   182  	}
   183  
   184  	if got, want := req.Host, "example.com"; got != want {
   185  		t.Errorf("req.Host: got %q, want %q", got, want)
   186  	}
   187  	if got, want := req.Header["Test-Header"], []string{"first", "second"}; !reflect.DeepEqual(got, want) {
   188  		t.Errorf("req.Header[%q]: got %v, want %v", "Test-Header", got, want)
   189  	}
   190  	if got, want := req.ContentLength, int64(100); got != want {
   191  		t.Errorf("req.ContentLength: got %d, want %d", got, want)
   192  	}
   193  	if got, want := req.TransferEncoding, []string{"chunked", "gzip"}; !reflect.DeepEqual(got, want) {
   194  		t.Errorf("req.TransferEncoding: got %v, want %v", got, want)
   195  	}
   196  }
   197  
   198  func TestResponseHeader(t *testing.T) {
   199  	res := NewResponse(200, nil, nil)
   200  
   201  	h := ResponseHeader(res)
   202  
   203  	tt := []struct {
   204  		name  string
   205  		value string
   206  	}{
   207  		{
   208  			name:  "Test-Header",
   209  			value: "true",
   210  		},
   211  		{
   212  			name:  "Content-Length",
   213  			value: "100",
   214  		},
   215  		{
   216  			name:  "Transfer-Encoding",
   217  			value: "chunked",
   218  		},
   219  	}
   220  
   221  	for i, tc := range tt {
   222  		if err := h.Set(tc.name, tc.value); err != nil {
   223  			t.Errorf("%d. h.Set(%q, %q): got %v, want no error", i, tc.name, tc.value, err)
   224  		}
   225  	}
   226  
   227  	if got, want := res.Header.Get("Test-Header"), "true"; got != want {
   228  		t.Errorf("res.Header.Get(%q): got %q, want %q", "Test-Header", got, want)
   229  	}
   230  	if got, want := res.ContentLength, int64(100); got != want {
   231  		t.Errorf("res.ContentLength: got %d, want %d", got, want)
   232  	}
   233  	if got, want := res.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(got, want) {
   234  		t.Errorf("res.TransferEncoding: got %v, want %v", got, want)
   235  	}
   236  
   237  	if got, want := len(h.Map()), 3; got != want {
   238  		t.Errorf("h.Map(): got %d entries, want %d entries", got, want)
   239  	}
   240  
   241  	for n, vs := range h.Map() {
   242  		var want string
   243  		switch n {
   244  		case "Content-Length":
   245  			want = "100"
   246  		case "Transfer-Encoding":
   247  			want = "chunked"
   248  		case "Test-Header":
   249  			want = "true"
   250  		default:
   251  			t.Errorf("h.Map(): got unexpected %s header", n)
   252  		}
   253  
   254  		if got := vs[0]; got != want {
   255  			t.Errorf("h.Map(): got %s header with value %s, want value %s", n, got, want)
   256  		}
   257  	}
   258  
   259  	for i, tc := range tt {
   260  		got, ok := h.All(tc.name)
   261  		if !ok {
   262  			t.Errorf("%d. h.All(%q): got false, want true", i, tc.name)
   263  		}
   264  
   265  		if want := []string{tc.value}; !reflect.DeepEqual(got, want) {
   266  			t.Errorf("%d. h.All(%q): got %v, want %v", i, tc.name, got, want)
   267  		}
   268  
   269  		if got, want := h.Get(tc.name), tc.value; got != want {
   270  			t.Errorf("%d. h.Get(%q): got %q, want %q", i, tc.name, got, want)
   271  		}
   272  
   273  		h.Del(tc.name)
   274  	}
   275  
   276  	if got, want := res.Header.Get("Test-Header"), ""; got != want {
   277  		t.Errorf("res.Header.Get(%q): got %q, want %q", "Test-Header", got, want)
   278  	}
   279  	if got, want := res.ContentLength, int64(-1); got != want {
   280  		t.Errorf("res.ContentLength: got %d, want %d", got, want)
   281  	}
   282  	if got := res.TransferEncoding; got != nil {
   283  		t.Errorf("res.TransferEncoding: got %v, want nil", got)
   284  	}
   285  
   286  	for i, tc := range tt {
   287  		if got, want := h.Get(tc.name), ""; got != want {
   288  			t.Errorf("%d. h.Get(%q): got %q, want %q", i, tc.name, got, want)
   289  		}
   290  
   291  		got, ok := h.All(tc.name)
   292  		if ok {
   293  			t.Errorf("%d. h.All(%q): got ok, want !ok", i, tc.name)
   294  		}
   295  		if got != nil {
   296  			t.Errorf("%d. h.All(%q): got %v, want nil", i, tc.name, got)
   297  		}
   298  	}
   299  }
   300  
   301  func TestResponseHeaderAdd(t *testing.T) {
   302  	res := NewResponse(200, nil, nil)
   303  
   304  	h := ResponseHeader(res)
   305  
   306  	tt := []struct {
   307  		name             string
   308  		values           []string
   309  		errOnSecondValue bool
   310  	}{
   311  		{
   312  			name:   "Test-Header",
   313  			values: []string{"first", "second"},
   314  		},
   315  		{
   316  			name:             "Content-Length",
   317  			values:           []string{"100", "101"},
   318  			errOnSecondValue: true,
   319  		},
   320  		{
   321  			name:   "Transfer-Encoding",
   322  			values: []string{"chunked", "gzip"},
   323  		},
   324  	}
   325  
   326  	for i, tc := range tt {
   327  		if err := h.Add(tc.name, tc.values[0]); err != nil {
   328  			t.Errorf("%d. h.Add(%q, %q): got %v, want no error", i, tc.name, tc.values[0], err)
   329  		}
   330  		if err := h.Add(tc.name, tc.values[1]); err != nil && !tc.errOnSecondValue {
   331  			t.Errorf("%d. h.Add(%q, %q): got %v, want no error", i, tc.name, tc.values[1], err)
   332  		}
   333  	}
   334  
   335  	if got, want := res.Header["Test-Header"], []string{"first", "second"}; !reflect.DeepEqual(got, want) {
   336  		t.Errorf("res.Header[%q]: got %v, want %v", "Test-Header", got, want)
   337  	}
   338  	if got, want := res.ContentLength, int64(100); got != want {
   339  		t.Errorf("res.ContentLength: got %d, want %d", got, want)
   340  	}
   341  	if got, want := res.TransferEncoding, []string{"chunked", "gzip"}; !reflect.DeepEqual(got, want) {
   342  		t.Errorf("res.TransferEncoding: got %v, want %v", got, want)
   343  	}
   344  }