Source: github.com/spotmaxtech/k8s-apimachinery-v0260@v0.0.1/pkg/util/httpstream/httpstream_test.go

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package httpstream
    18  
    19  import (
    20  	"net/http"
    21  	"reflect"
    22  	"testing"
    23  )
    24  
    25  type responseWriter struct {
    26  	header     http.Header
    27  	statusCode *int
    28  }
    29  
    30  func newResponseWriter() *responseWriter {
    31  	return &responseWriter{
    32  		header: make(http.Header),
    33  	}
    34  }
    35  
    36  func (r *responseWriter) Header() http.Header {
    37  	return r.header
    38  }
    39  
    40  func (r *responseWriter) WriteHeader(code int) {
    41  	r.statusCode = &code
    42  }
    43  
    44  func (r *responseWriter) Write([]byte) (int, error) {
    45  	return 0, nil
    46  }
    47  
    48  func TestHandshake(t *testing.T) {
    49  	tests := map[string]struct {
    50  		clientProtocols  []string
    51  		serverProtocols  []string
    52  		expectedProtocol string
    53  		expectError      bool
    54  	}{
    55  		"no common protocol": {
    56  			clientProtocols:  []string{"c"},
    57  			serverProtocols:  []string{"a", "b"},
    58  			expectedProtocol: "",
    59  			expectError:      true,
    60  		},
    61  		"no common protocol with comma separated list": {
    62  			clientProtocols:  []string{"c, d"},
    63  			serverProtocols:  []string{"a", "b"},
    64  			expectedProtocol: "",
    65  			expectError:      true,
    66  		},
    67  		"common protocol": {
    68  			clientProtocols:  []string{"b"},
    69  			serverProtocols:  []string{"a", "b"},
    70  			expectedProtocol: "b",
    71  		},
    72  		"common protocol with comma separated list": {
    73  			clientProtocols:  []string{"b, c"},
    74  			serverProtocols:  []string{"a", "b"},
    75  			expectedProtocol: "b",
    76  		},
    77  	}
    78  
    79  	for name, test := range tests {
    80  		req, err := http.NewRequest("GET", "http://www.example.com/", nil)
    81  		if err != nil {
    82  			t.Fatalf("%s: error creating request: %v", name, err)
    83  		}
    84  
    85  		for _, p := range test.clientProtocols {
    86  			req.Header.Add(HeaderProtocolVersion, p)
    87  		}
    88  
    89  		w := newResponseWriter()
    90  		negotiated, err := Handshake(req, w, test.serverProtocols)
    91  
    92  		// verify negotiated protocol
    93  		if e, a := test.expectedProtocol, negotiated; e != a {
    94  			t.Errorf("%s: protocol: expected %q, got %q", name, e, a)
    95  		}
    96  
    97  		if test.expectError {
    98  			if err == nil {
    99  				t.Errorf("%s: expected error but did not get one", name)
   100  			}
   101  			if w.statusCode == nil {
   102  				t.Errorf("%s: expected w.statusCode to be set", name)
   103  			} else if e, a := http.StatusForbidden, *w.statusCode; e != a {
   104  				t.Errorf("%s: w.statusCode: expected %d, got %d", name, e, a)
   105  			}
   106  			if e, a := test.serverProtocols, w.Header()[HeaderAcceptedProtocolVersions]; !reflect.DeepEqual(e, a) {
   107  				t.Errorf("%s: accepted server protocols: expected %v, got %v", name, e, a)
   108  			}
   109  			continue
   110  		}
   111  		if !test.expectError && err != nil {
   112  			t.Errorf("%s: unexpected error: %v", name, err)
   113  			continue
   114  		}
   115  		if w.statusCode != nil {
   116  			t.Errorf("%s: unexpected non-nil w.statusCode: %d", name, w.statusCode)
   117  		}
   118  
   119  		if len(test.expectedProtocol) == 0 {
   120  			if len(w.Header()[HeaderProtocolVersion]) > 0 {
   121  				t.Errorf("%s: unexpected protocol version response header: %s", name, w.Header()[HeaderProtocolVersion])
   122  			}
   123  			continue
   124  		}
   125  
   126  		// verify response headers
   127  		if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) {
   128  			t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a)
   129  		}
   130  	}
   131  }