k8s.io/apimachinery@v0.29.2/pkg/util/httpstream/wsstream/stream_test.go (about)

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package wsstream
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/base64"
    22  	"fmt"
    23  	"io"
    24  	"net/http"
    25  	"reflect"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  
    30  	"golang.org/x/net/websocket"
    31  )
    32  
    33  func TestStream(t *testing.T) {
    34  	input := "some random text"
    35  	r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
    36  	r.SetIdleTimeout(time.Second)
    37  	data, err := readWebSocket(r, t, nil)
    38  	if !reflect.DeepEqual(data, []byte(input)) {
    39  		t.Errorf("unexpected server read: %v", data)
    40  	}
    41  	if err != nil {
    42  		t.Fatal(err)
    43  	}
    44  }
    45  
    46  func TestStreamPing(t *testing.T) {
    47  	input := "some random text"
    48  	r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
    49  	r.SetIdleTimeout(time.Second)
    50  	err := expectWebSocketFrames(r, t, nil, [][]byte{
    51  		{},
    52  		[]byte(input),
    53  	})
    54  	if err != nil {
    55  		t.Fatal(err)
    56  	}
    57  }
    58  
    59  func TestStreamBase64(t *testing.T) {
    60  	input := "some random text"
    61  	encoded := base64.StdEncoding.EncodeToString([]byte(input))
    62  	r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
    63  	data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io")
    64  	if !reflect.DeepEqual(data, []byte(encoded)) {
    65  		t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
    66  	}
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  }
    71  
    72  func TestStreamVersionedBase64(t *testing.T) {
    73  	input := "some random text"
    74  	encoded := base64.StdEncoding.EncodeToString([]byte(input))
    75  	r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{
    76  		"":                        {Binary: true},
    77  		"binary.k8s.io":           {Binary: true},
    78  		"base64.binary.k8s.io":    {Binary: false},
    79  		"v1.binary.k8s.io":        {Binary: true},
    80  		"v1.base64.binary.k8s.io": {Binary: false},
    81  		"v2.binary.k8s.io":        {Binary: true},
    82  		"v2.base64.binary.k8s.io": {Binary: false},
    83  	})
    84  	data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io")
    85  	if !reflect.DeepEqual(data, []byte(encoded)) {
    86  		t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
    87  	}
    88  	if err != nil {
    89  		t.Fatal(err)
    90  	}
    91  }
    92  
    93  func TestStreamVersionedCopy(t *testing.T) {
    94  	for i, test := range versionTests() {
    95  		func() {
    96  			supportedProtocols := map[string]ReaderProtocolConfig{}
    97  			for p, binary := range test.supported {
    98  				supportedProtocols[p] = ReaderProtocolConfig{
    99  					Binary: binary,
   100  				}
   101  			}
   102  			input := "some random text"
   103  			r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols)
   104  			s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   105  				err := r.Copy(w, req)
   106  				if err != nil {
   107  					w.WriteHeader(503)
   108  				}
   109  			}))
   110  			defer s.Close()
   111  
   112  			config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
   113  			if err != nil {
   114  				t.Error(err)
   115  				return
   116  			}
   117  			config.Protocol = test.requested
   118  			client, err := websocket.DialConfig(config)
   119  			if err != nil {
   120  				if !test.error {
   121  					t.Errorf("test %d: didn't expect error: %v", i, err)
   122  				}
   123  				return
   124  			}
   125  			defer client.Close()
   126  			if test.error && err == nil {
   127  				t.Errorf("test %d: expected an error", i)
   128  				return
   129  			}
   130  
   131  			<-r.err
   132  			if got, expected := r.selectedProtocol, test.expected; got != expected {
   133  				t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
   134  			}
   135  		}()
   136  	}
   137  }
   138  
   139  func TestStreamError(t *testing.T) {
   140  	input := "some random text"
   141  	errs := &errorReader{
   142  		reads: [][]byte{
   143  			[]byte("some random"),
   144  			[]byte(" text"),
   145  		},
   146  		err: fmt.Errorf("bad read"),
   147  	}
   148  	r := NewReader(errs, false, NewDefaultReaderProtocols())
   149  
   150  	data, err := readWebSocket(r, t, nil)
   151  	if !reflect.DeepEqual(data, []byte(input)) {
   152  		t.Errorf("unexpected server read: %v", data)
   153  	}
   154  	if err == nil || err.Error() != "bad read" {
   155  		t.Fatal(err)
   156  	}
   157  }
   158  
   159  func TestStreamSurvivesPanic(t *testing.T) {
   160  	input := "some random text"
   161  	errs := &errorReader{
   162  		reads: [][]byte{
   163  			[]byte("some random"),
   164  			[]byte(" text"),
   165  		},
   166  		panicMessage: "bad read",
   167  	}
   168  	r := NewReader(errs, false, NewDefaultReaderProtocols())
   169  
   170  	// do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted.
   171  	r.handleCrash = func(additionalHandlers ...func(interface{})) { recover() }
   172  
   173  	data, err := readWebSocket(r, t, nil)
   174  	if !reflect.DeepEqual(data, []byte(input)) {
   175  		t.Errorf("unexpected server read: %v", data)
   176  	}
   177  	if err != nil {
   178  		t.Fatal(err)
   179  	}
   180  }
   181  
   182  func TestStreamClosedDuringRead(t *testing.T) {
   183  	for i := 0; i < 25; i++ {
   184  		ch := make(chan struct{})
   185  		input := "some random text"
   186  		errs := &errorReader{
   187  			reads: [][]byte{
   188  				[]byte("some random"),
   189  				[]byte(" text"),
   190  			},
   191  			err:   fmt.Errorf("stuff"),
   192  			pause: ch,
   193  		}
   194  		r := NewReader(errs, false, NewDefaultReaderProtocols())
   195  
   196  		data, err := readWebSocket(r, t, func(c *websocket.Conn) {
   197  			c.Close()
   198  			close(ch)
   199  		})
   200  		// verify that the data returned by the server on an early close always has a specific error
   201  		if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
   202  			t.Fatal(err)
   203  		}
   204  		// verify that the data returned is a strict subset of the input
   205  		if !bytes.HasPrefix([]byte(input), data) && len(data) != 0 {
   206  			t.Fatalf("unexpected server read: %q", string(data))
   207  		}
   208  	}
   209  }
   210  
   211  type errorReader struct {
   212  	reads        [][]byte
   213  	err          error
   214  	panicMessage string
   215  	pause        chan struct{}
   216  }
   217  
   218  func (r *errorReader) Read(p []byte) (int, error) {
   219  	if len(r.reads) == 0 {
   220  		if r.pause != nil {
   221  			<-r.pause
   222  		}
   223  		if len(r.panicMessage) != 0 {
   224  			panic(r.panicMessage)
   225  		}
   226  		return 0, r.err
   227  	}
   228  	next := r.reads[0]
   229  	r.reads = r.reads[1:]
   230  	copy(p, next)
   231  	return len(next), nil
   232  }
   233  
   234  func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
   235  	errCh := make(chan error, 1)
   236  	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   237  		errCh <- r.Copy(w, req)
   238  	}))
   239  	defer s.Close()
   240  
   241  	config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
   242  	config.Protocol = protocols
   243  	client, err := websocket.DialConfig(config)
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  	defer client.Close()
   248  
   249  	if fn != nil {
   250  		fn(client)
   251  	}
   252  
   253  	data, err := io.ReadAll(client)
   254  	if err != nil {
   255  		return data, err
   256  	}
   257  	return data, <-errCh
   258  }
   259  
   260  func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
   261  	errCh := make(chan error, 1)
   262  	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   263  		errCh <- r.Copy(w, req)
   264  	}))
   265  	defer s.Close()
   266  
   267  	config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
   268  	config.Protocol = protocols
   269  	ws, err := websocket.DialConfig(config)
   270  	if err != nil {
   271  		return err
   272  	}
   273  	defer ws.Close()
   274  
   275  	if fn != nil {
   276  		fn(ws)
   277  	}
   278  
   279  	for i := range frames {
   280  		var data []byte
   281  		if err := websocket.Message.Receive(ws, &data); err != nil {
   282  			return err
   283  		}
   284  		if !reflect.DeepEqual(frames[i], data) {
   285  			return fmt.Errorf("frame %d did not match expected: %v", data, err)
   286  		}
   287  	}
   288  	var data []byte
   289  	if err := websocket.Message.Receive(ws, &data); err != io.EOF {
   290  		return fmt.Errorf("expected no more frames: %v (%v)", err, data)
   291  	}
   292  	return <-errCh
   293  }