k8s.io/apiserver@v0.31.1/pkg/registry/generic/rest/streamer_test.go (about)

     1  /*
     2  Copyright 2014 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package rest
    18  
    19  import (
    20  	"bufio"
    21  	"bytes"
    22  	"context"
    23  	"fmt"
    24  	"io/ioutil"
    25  	"net/http"
    26  	"net/http/httptest"
    27  	"net/url"
    28  	"reflect"
    29  	"testing"
    30  
    31  	"github.com/stretchr/testify/assert"
    32  	"github.com/stretchr/testify/require"
    33  	"k8s.io/apimachinery/pkg/api/errors"
    34  	"k8s.io/apimachinery/pkg/runtime/schema"
    35  )
    36  
    37  func TestInputStreamReader(t *testing.T) {
    38  	resultString := "Test output"
    39  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    40  		w.Write([]byte(resultString))
    41  	}))
    42  	defer s.Close()
    43  	u, err := url.Parse(s.URL)
    44  	if err != nil {
    45  		t.Errorf("Error parsing server URL: %v", err)
    46  		return
    47  	}
    48  	streamer := &LocationStreamer{
    49  		Location: u,
    50  	}
    51  	readCloser, _, _, err := streamer.InputStream(context.Background(), "", "")
    52  	if err != nil {
    53  		t.Errorf("Unexpected error when getting stream: %v", err)
    54  		return
    55  	}
    56  	defer readCloser.Close()
    57  	result, _ := ioutil.ReadAll(readCloser)
    58  	if string(result) != resultString {
    59  		t.Errorf("Stream content does not match. Got: %s. Expected: %s.", string(result), resultString)
    60  	}
    61  }
    62  
    63  func TestInputStreamNullLocation(t *testing.T) {
    64  	streamer := &LocationStreamer{
    65  		Location: nil,
    66  	}
    67  	readCloser, _, _, err := streamer.InputStream(context.Background(), "", "")
    68  	if err != nil {
    69  		t.Errorf("Unexpected error when getting stream with null location: %v", err)
    70  	}
    71  	if readCloser != nil {
    72  		t.Errorf("Expected stream to be nil. Got: %#v", readCloser)
    73  	}
    74  }
    75  
    76  type testTransport struct {
    77  	body string
    78  }
    79  
    80  func (tt *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    81  	r := bufio.NewReader(bytes.NewBufferString(tt.body))
    82  	return http.ReadResponse(r, req)
    83  }
    84  
    85  func fakeTransport(mime, message string) http.RoundTripper {
    86  	content := fmt.Sprintf("HTTP/1.1 200 OK\nContent-Type: %s\n\n%s", mime, message)
    87  	return &testTransport{body: content}
    88  }
    89  
    90  func TestInputStreamContentType(t *testing.T) {
    91  	location, _ := url.Parse("http://www.example.com")
    92  	streamer := &LocationStreamer{
    93  		Location:  location,
    94  		Transport: fakeTransport("application/json", "hello world"),
    95  	}
    96  	readCloser, _, contentType, err := streamer.InputStream(context.Background(), "", "")
    97  	if err != nil {
    98  		t.Errorf("Unexpected error when getting stream: %v", err)
    99  		return
   100  	}
   101  	defer readCloser.Close()
   102  	if contentType != "application/json" {
   103  		t.Errorf("Unexpected content type. Got: %s. Expected: application/json", contentType)
   104  	}
   105  }
   106  
   107  func TestInputStreamTransport(t *testing.T) {
   108  	message := "hello world"
   109  	location, _ := url.Parse("http://www.example.com")
   110  	streamer := &LocationStreamer{
   111  		Location:  location,
   112  		Transport: fakeTransport("text/plain", message),
   113  	}
   114  	readCloser, _, _, err := streamer.InputStream(context.Background(), "", "")
   115  	if err != nil {
   116  		t.Errorf("Unexpected error when getting stream: %v", err)
   117  		return
   118  	}
   119  	defer readCloser.Close()
   120  	result, _ := ioutil.ReadAll(readCloser)
   121  	if string(result) != message {
   122  		t.Errorf("Stream content does not match. Got: %s. Expected: %s.", string(result), message)
   123  	}
   124  }
   125  
   126  func fakeInternalServerErrorTransport(mime, message string) http.RoundTripper {
   127  	content := fmt.Sprintf("HTTP/1.1 500 \"Internal Server Error\"\nContent-Type: %s\n\n%s", mime, message)
   128  	return &testTransport{body: content}
   129  }
   130  
   131  func TestInputStreamInternalServerErrorTransport(t *testing.T) {
   132  	message := "Pod is in PodPending"
   133  	location, _ := url.Parse("http://www.example.com")
   134  	streamer := &LocationStreamer{
   135  		Location:        location,
   136  		Transport:       fakeInternalServerErrorTransport("text/plain", message),
   137  		ResponseChecker: NewGenericHttpResponseChecker(schema.GroupResource{}, ""),
   138  	}
   139  	expectedError := errors.NewInternalError(fmt.Errorf("%s", message))
   140  
   141  	_, _, _, err := streamer.InputStream(context.Background(), "", "")
   142  	if err == nil {
   143  		t.Errorf("unexpected non-error")
   144  		return
   145  	}
   146  
   147  	if !reflect.DeepEqual(err, expectedError) {
   148  		t.Errorf("StreamInternalServerError does not match. Got: %s. Expected: %s.", err, expectedError)
   149  	}
   150  }
   151  
   152  func TestInputStreamRedirects(t *testing.T) {
   153  	const redirectPath = "/redirect"
   154  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   155  		if req.URL.Path == redirectPath {
   156  			t.Fatal("Redirects should not be followed")
   157  		} else {
   158  			http.Redirect(w, req, redirectPath, http.StatusFound)
   159  		}
   160  	}))
   161  	loc, err := url.Parse(s.URL)
   162  	require.NoError(t, err, "Error parsing server URL")
   163  
   164  	streamer := &LocationStreamer{
   165  		Location:        loc,
   166  		RedirectChecker: PreventRedirects,
   167  	}
   168  	_, _, _, err = streamer.InputStream(context.Background(), "", "")
   169  	assert.Error(t, err, "Redirect should trigger an error")
   170  }