github.com/tilt-dev/tilt@v0.33.15-0.20240515162809-0a22ed45d8a0/internal/k8s/portforward/portforward_test.go (about)

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package portforward
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"fmt"
    23  	"net"
    24  	"net/http"
    25  	"os"
    26  	"reflect"
    27  	"sort"
    28  	"strings"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/stretchr/testify/assert"
    33  
    34  	v1 "k8s.io/api/core/v1"
    35  	"k8s.io/apimachinery/pkg/util/httpstream"
    36  
    37  	"github.com/tilt-dev/tilt/pkg/logger"
    38  )
    39  
    40  type fakeDialer struct {
    41  	dialed             bool
    42  	conn               httpstream.Connection
    43  	err                error
    44  	negotiatedProtocol string
    45  }
    46  
    47  func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
    48  	d.dialed = true
    49  	return d.conn, d.negotiatedProtocol, d.err
    50  }
    51  
    52  type fakeConnection struct {
    53  	closed      bool
    54  	closeChan   chan bool
    55  	dataStream  *fakeStream
    56  	errorStream *fakeStream
    57  }
    58  
    59  func newFakeConnection() *fakeConnection {
    60  	return &fakeConnection{
    61  		closeChan:   make(chan bool),
    62  		dataStream:  &fakeStream{},
    63  		errorStream: &fakeStream{},
    64  	}
    65  }
    66  
    67  func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
    68  	switch headers.Get(v1.StreamType) {
    69  	case v1.StreamTypeData:
    70  		return c.dataStream, nil
    71  	case v1.StreamTypeError:
    72  		return c.errorStream, nil
    73  	default:
    74  		return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType))
    75  	}
    76  }
    77  
    78  func (c *fakeConnection) RemoveStreams(_ ...httpstream.Stream) {
    79  }
    80  
    81  func (c *fakeConnection) Close() error {
    82  	if !c.closed {
    83  		c.closed = true
    84  		close(c.closeChan)
    85  	}
    86  	return nil
    87  }
    88  
    89  func (c *fakeConnection) CloseChan() <-chan bool {
    90  	return c.closeChan
    91  }
    92  
    93  func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
    94  	// no-op
    95  }
    96  
    97  type fakeListener struct {
    98  	net.Listener
    99  	closeChan chan bool
   100  }
   101  
   102  func newFakeListener() fakeListener {
   103  	return fakeListener{
   104  		closeChan: make(chan bool),
   105  	}
   106  }
   107  
   108  //nolint:gosimple // Copy of upstream code @ https://github.com/kubernetes/client-go/blob/77f63643f951f19681397a995fe0916d2d5cb992/tools/portforward/portforward_test.go#L105-L110
   109  func (l *fakeListener) Accept() (net.Conn, error) {
   110  	select {
   111  	case <-l.closeChan:
   112  		return nil, fmt.Errorf("listener closed")
   113  	}
   114  }
   115  
   116  func (l *fakeListener) Close() error {
   117  	close(l.closeChan)
   118  	return nil
   119  }
   120  
   121  func (l *fakeListener) Addr() net.Addr {
   122  	return fakeAddr{}
   123  }
   124  
   125  type fakeAddr struct{}
   126  
   127  func (fakeAddr) Network() string { return "fake" }
   128  func (fakeAddr) String() string  { return "fake" }
   129  
   130  type fakeStream struct {
   131  	headers   http.Header
   132  	readFunc  func(p []byte) (int, error)
   133  	writeFunc func(p []byte) (int, error)
   134  }
   135  
   136  func (s *fakeStream) Read(p []byte) (n int, err error)  { return s.readFunc(p) }
   137  func (s *fakeStream) Write(p []byte) (n int, err error) { return s.writeFunc(p) }
   138  func (*fakeStream) Close() error                        { return nil }
   139  func (*fakeStream) Reset() error                        { return nil }
   140  func (s *fakeStream) Headers() http.Header              { return s.headers }
   141  func (*fakeStream) Identifier() uint32                  { return 0 }
   142  
   143  type fakeConn struct {
   144  	sendBuffer    *bytes.Buffer
   145  	receiveBuffer *bytes.Buffer
   146  }
   147  
   148  func (f fakeConn) Read(p []byte) (int, error)       { return f.sendBuffer.Read(p) }
   149  func (f fakeConn) Write(p []byte) (int, error)      { return f.receiveBuffer.Write(p) }
   150  func (fakeConn) Close() error                       { return nil }
   151  func (fakeConn) LocalAddr() net.Addr                { return nil }
   152  func (fakeConn) RemoteAddr() net.Addr               { return nil }
   153  func (fakeConn) SetDeadline(t time.Time) error      { return nil }
   154  func (fakeConn) SetReadDeadline(t time.Time) error  { return nil }
   155  func (fakeConn) SetWriteDeadline(t time.Time) error { return nil }
   156  
   157  func TestParsePortsAndNew(t *testing.T) {
   158  	tests := []struct {
   159  		input                   []string
   160  		addresses               []string
   161  		expectedPorts           []ForwardedPort
   162  		expectedAddresses       []listenAddress
   163  		expectPortParseError    bool
   164  		expectAddressParseError bool
   165  		expectNewError          bool
   166  	}{
   167  		{input: []string{}, expectNewError: true},
   168  		{input: []string{"a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   169  		{input: []string{":a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   170  		{input: []string{"-1"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   171  		{input: []string{"65536"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   172  		{input: []string{"0"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   173  		{input: []string{"0:0"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   174  		{input: []string{"a:5000"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   175  		{input: []string{"5000:a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   176  		{input: []string{"5000:5000"}, addresses: []string{"127.0.0.257"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true},
   177  		{input: []string{"5000:5000"}, addresses: []string{"::g"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true},
   178  		{input: []string{"5000:5000"}, addresses: []string{"domain.invalid"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true},
   179  		{
   180  			input:     []string{"5000:5000"},
   181  			addresses: []string{"localhost"},
   182  			expectedPorts: []ForwardedPort{
   183  				{5000, 5000},
   184  			},
   185  			expectedAddresses: []listenAddress{
   186  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "all"},
   187  				{protocol: "tcp6", address: "::1", failureMode: "all"},
   188  			},
   189  		},
   190  		{
   191  			input:     []string{"5000:5000"},
   192  			addresses: []string{"localhost", "127.0.0.1"},
   193  			expectedPorts: []ForwardedPort{
   194  				{5000, 5000},
   195  			},
   196  			expectedAddresses: []listenAddress{
   197  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   198  				{protocol: "tcp6", address: "::1", failureMode: "all"},
   199  			},
   200  		},
   201  		{
   202  			input:     []string{"5000:5000"},
   203  			addresses: []string{"localhost", "::1"},
   204  			expectedPorts: []ForwardedPort{
   205  				{5000, 5000},
   206  			},
   207  			expectedAddresses: []listenAddress{
   208  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "all"},
   209  				{protocol: "tcp6", address: "::1", failureMode: "any"},
   210  			},
   211  		},
   212  		{
   213  			input:     []string{"5000:5000"},
   214  			addresses: []string{"localhost", "127.0.0.1", "::1"},
   215  			expectedPorts: []ForwardedPort{
   216  				{5000, 5000},
   217  			},
   218  			expectedAddresses: []listenAddress{
   219  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   220  				{protocol: "tcp6", address: "::1", failureMode: "any"},
   221  			},
   222  		},
   223  		{
   224  			input:     []string{"5000:5000"},
   225  			addresses: []string{"localhost", "127.0.0.1", "10.10.10.1"},
   226  			expectedPorts: []ForwardedPort{
   227  				{5000, 5000},
   228  			},
   229  			expectedAddresses: []listenAddress{
   230  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   231  				{protocol: "tcp6", address: "::1", failureMode: "all"},
   232  				{protocol: "tcp4", address: "10.10.10.1", failureMode: "any"},
   233  			},
   234  		},
   235  		{
   236  			input:     []string{"5000:5000"},
   237  			addresses: []string{"127.0.0.1", "::1", "localhost"},
   238  			expectedPorts: []ForwardedPort{
   239  				{5000, 5000},
   240  			},
   241  			expectedAddresses: []listenAddress{
   242  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   243  				{protocol: "tcp6", address: "::1", failureMode: "any"},
   244  			},
   245  		},
   246  		{
   247  			input:     []string{"5000:5000"},
   248  			addresses: []string{"10.0.0.1", "127.0.0.1"},
   249  			expectedPorts: []ForwardedPort{
   250  				{5000, 5000},
   251  			},
   252  			expectedAddresses: []listenAddress{
   253  				{protocol: "tcp4", address: "10.0.0.1", failureMode: "any"},
   254  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   255  			},
   256  		},
   257  		{
   258  			input:     []string{"5000", "5000:5000", "8888:5000", "5000:8888", ":5000", "0:5000"},
   259  			addresses: []string{"127.0.0.1", "::1"},
   260  			expectedPorts: []ForwardedPort{
   261  				{5000, 5000},
   262  				{5000, 5000},
   263  				{8888, 5000},
   264  				{5000, 8888},
   265  				{0, 5000},
   266  				{0, 5000},
   267  			},
   268  			expectedAddresses: []listenAddress{
   269  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   270  				{protocol: "tcp6", address: "::1", failureMode: "any"},
   271  			},
   272  		},
   273  	}
   274  
   275  	for i, test := range tests {
   276  		parsedPorts, err := parsePorts(test.input)
   277  		haveError := err != nil
   278  		if e, a := test.expectPortParseError, haveError; e != a {
   279  			t.Fatalf("%d: parsePorts: error expected=%t, got %t: %s", i, e, a, err)
   280  		}
   281  
   282  		// default to localhost
   283  		if len(test.addresses) == 0 && len(test.expectedAddresses) == 0 {
   284  			test.addresses = []string{"localhost"}
   285  			test.expectedAddresses = []listenAddress{{protocol: "tcp4", address: "127.0.0.1"}, {protocol: "tcp6", address: "::1"}}
   286  		}
   287  		// assert address parser
   288  		parsedAddresses, err := parseAddresses(test.addresses)
   289  		haveError = err != nil
   290  		if e, a := test.expectAddressParseError, haveError; e != a {
   291  			t.Fatalf("%d: parseAddresses: error expected=%t, got %t: %s", i, e, a, err)
   292  		}
   293  
   294  		ctx := newCtx()
   295  		dialer := &fakeDialer{}
   296  		readyChan := make(chan struct{})
   297  
   298  		var pf *PortForwarder
   299  		if len(test.addresses) > 0 {
   300  			pf, err = NewOnAddresses(ctx, dialer, test.addresses, test.input, readyChan)
   301  		} else {
   302  			pf, err = New(ctx, dialer, test.input, readyChan)
   303  		}
   304  		haveError = err != nil
   305  		if e, a := test.expectNewError, haveError; e != a {
   306  			t.Fatalf("%d: New: error expected=%t, got %t: %s", i, e, a, err)
   307  		}
   308  
   309  		if test.expectPortParseError || test.expectAddressParseError || test.expectNewError {
   310  			continue
   311  		}
   312  
   313  		sort.Slice(test.expectedAddresses, func(i, j int) bool { return test.expectedAddresses[i].address < test.expectedAddresses[j].address })
   314  		sort.Slice(parsedAddresses, func(i, j int) bool { return parsedAddresses[i].address < parsedAddresses[j].address })
   315  
   316  		if !reflect.DeepEqual(test.expectedAddresses, parsedAddresses) {
   317  			t.Fatalf("%d: expectedAddresses: %v, got: %v", i, test.expectedAddresses, parsedAddresses)
   318  		}
   319  
   320  		for pi, expectedPort := range test.expectedPorts {
   321  			if e, a := expectedPort.Local, parsedPorts[pi].Local; e != a {
   322  				t.Fatalf("%d: local expected: %d, got: %d", i, e, a)
   323  			}
   324  			if e, a := expectedPort.Remote, parsedPorts[pi].Remote; e != a {
   325  				t.Fatalf("%d: remote expected: %d, got: %d", i, e, a)
   326  			}
   327  		}
   328  
   329  		if dialer.dialed {
   330  			t.Fatalf("%d: expected not dialed", i)
   331  		}
   332  		if _, portErr := pf.GetPorts(); portErr == nil {
   333  			t.Fatalf("%d: GetPorts: error expected but got nil", i)
   334  		}
   335  
   336  		// mock-signal the Ready channel
   337  		close(readyChan)
   338  
   339  		if ports, portErr := pf.GetPorts(); portErr != nil {
   340  			t.Fatalf("%d: GetPorts: unable to retrieve ports: %s", i, portErr)
   341  		} else if !reflect.DeepEqual(test.expectedPorts, ports) {
   342  			t.Fatalf("%d: ports: expected %#v, got %#v", i, test.expectedPorts, ports)
   343  		}
   344  		if pf.Ready == nil {
   345  			t.Fatalf("%d: Ready should be non-nil", i)
   346  		}
   347  	}
   348  }
   349  
   350  type GetListenerTestCase struct {
   351  	Hostname                string
   352  	Protocol                string
   353  	ShouldRaiseError        bool
   354  	ExpectedListenerAddress string
   355  }
   356  
   357  func TestGetListener(t *testing.T) {
   358  	var pf PortForwarder
   359  	testCases := []GetListenerTestCase{
   360  		{
   361  			Hostname:                "localhost",
   362  			Protocol:                "tcp4",
   363  			ShouldRaiseError:        false,
   364  			ExpectedListenerAddress: "127.0.0.1",
   365  		},
   366  		{
   367  			Hostname:                "127.0.0.1",
   368  			Protocol:                "tcp4",
   369  			ShouldRaiseError:        false,
   370  			ExpectedListenerAddress: "127.0.0.1",
   371  		},
   372  		{
   373  			Hostname:                "::1",
   374  			Protocol:                "tcp6",
   375  			ShouldRaiseError:        false,
   376  			ExpectedListenerAddress: "::1",
   377  		},
   378  		{
   379  			Hostname:         "::1",
   380  			Protocol:         "tcp4",
   381  			ShouldRaiseError: true,
   382  		},
   383  		{
   384  			Hostname:         "127.0.0.1",
   385  			Protocol:         "tcp6",
   386  			ShouldRaiseError: true,
   387  		},
   388  	}
   389  
   390  	for i, testCase := range testCases {
   391  		forwardedPort := &ForwardedPort{Local: 0, Remote: 12345}
   392  		listener, err := pf.getListener(testCase.Protocol, testCase.Hostname, forwardedPort)
   393  		if err != nil && strings.Contains(err.Error(), "cannot assign requested address") {
   394  			t.Logf("Can't test #%d: %v", i, err)
   395  			continue
   396  		}
   397  		expectedListenerPort := fmt.Sprintf("%d", forwardedPort.Local)
   398  		errorRaised := err != nil
   399  
   400  		if testCase.ShouldRaiseError != errorRaised {
   401  			t.Errorf("Test case #%d failed: Data %v an error has been raised(%t) where it should not (or reciprocally): %v", i, testCase, testCase.ShouldRaiseError, err)
   402  			continue
   403  		}
   404  		if errorRaised {
   405  			continue
   406  		}
   407  
   408  		if listener == nil {
   409  			t.Errorf("Test case #%d did not raise an error but failed in initializing listener", i)
   410  			continue
   411  		}
   412  
   413  		host, port, _ := net.SplitHostPort(listener.Addr().String())
   414  		t.Logf("Asked a %s forward for: %s:0, got listener %s:%s, expected: %s", testCase.Protocol, testCase.Hostname, host, port, expectedListenerPort)
   415  		if host != testCase.ExpectedListenerAddress {
   416  			t.Errorf("Test case #%d failed: Listener does not listen on expected address: asked '%v' got '%v'", i, testCase.ExpectedListenerAddress, host)
   417  		}
   418  		if port != expectedListenerPort {
   419  			t.Errorf("Test case #%d failed: Listener does not listen on expected port: asked %v got %v", i, expectedListenerPort, port)
   420  
   421  		}
   422  		listener.Close()
   423  	}
   424  }
   425  
   426  func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
   427  	dialer := &fakeDialer{
   428  		conn: newFakeConnection(),
   429  	}
   430  
   431  	ctx, cancel := context.WithCancel(newCtx())
   432  	readyChan := make(chan struct{})
   433  	errChan := make(chan error)
   434  
   435  	defer func() {
   436  		cancel()
   437  
   438  		forwardErr := <-errChan
   439  		if forwardErr != nil {
   440  			t.Fatalf("ForwardPorts returned error: %s", forwardErr)
   441  		}
   442  	}()
   443  
   444  	pf, err := New(ctx, dialer, []string{":5000"}, readyChan)
   445  
   446  	if err != nil {
   447  		t.Fatalf("error while calling New: %s", err)
   448  	}
   449  
   450  	go func() {
   451  		errChan <- pf.ForwardPorts()
   452  		close(errChan)
   453  	}()
   454  
   455  	<-pf.Ready
   456  
   457  	ports, err := pf.GetPorts()
   458  	if err != nil {
   459  		t.Fatalf("Failed to get ports. error: %v", err)
   460  	}
   461  
   462  	if len(ports) != 1 {
   463  		t.Fatalf("expected 1 port, got %d", len(ports))
   464  	}
   465  
   466  	port := ports[0]
   467  	if port.Local == 0 {
   468  		t.Fatalf("local port is 0, expected != 0")
   469  	}
   470  }
   471  
   472  func newCtx() context.Context {
   473  	return logger.WithLogger(context.Background(), logger.NewTestLogger(os.Stdout))
   474  }
   475  
   476  func TestHandleConnection(t *testing.T) {
   477  	readyChan := make(chan struct{})
   478  
   479  	pf, err := New(newCtx(), &fakeDialer{}, []string{":2222"}, readyChan)
   480  	if err != nil {
   481  		t.Fatalf("error while calling New: %s", err)
   482  	}
   483  
   484  	// Setup fake local connection
   485  	localConnection := &fakeConn{
   486  		sendBuffer:    bytes.NewBufferString("test data from local"),
   487  		receiveBuffer: bytes.NewBufferString(""),
   488  	}
   489  
   490  	// Setup fake remote connection to send data on the data stream after it receives data from the local connection
   491  	remoteDataToSend := bytes.NewBufferString("test data from remote")
   492  	remoteDataReceived := bytes.NewBufferString("")
   493  	remoteErrorToSend := bytes.NewBufferString("")
   494  	blockRemoteSend := make(chan struct{})
   495  	remoteConnection := newFakeConnection()
   496  	remoteConnection.dataStream.readFunc = func(p []byte) (int, error) {
   497  		<-blockRemoteSend // Wait for the expected data to be received before responding
   498  		return remoteDataToSend.Read(p)
   499  	}
   500  	remoteConnection.dataStream.writeFunc = func(p []byte) (int, error) {
   501  		n, err := remoteDataReceived.Write(p)
   502  		if remoteDataReceived.String() == "test data from local" {
   503  			close(blockRemoteSend)
   504  		}
   505  		return n, err
   506  	}
   507  	remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
   508  	pf.streamConn = remoteConnection
   509  
   510  	// Test handleConnection
   511  	pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
   512  
   513  	assert.Equal(t, "test data from local", remoteDataReceived.String())
   514  	assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String())
   515  }
   516  
   517  func TestHandleConnectionSendsRemoteError(t *testing.T) {
   518  	readyChan := make(chan struct{})
   519  
   520  	pf, err := New(newCtx(), &fakeDialer{}, []string{":2222"}, readyChan)
   521  	if err != nil {
   522  		t.Fatalf("error while calling New: %s", err)
   523  	}
   524  
   525  	// Setup fake local connection
   526  	localConnection := &fakeConn{
   527  		sendBuffer:    bytes.NewBufferString(""),
   528  		receiveBuffer: bytes.NewBufferString(""),
   529  	}
   530  
   531  	// Setup fake remote connection to return an error message on the error stream
   532  	remoteDataToSend := bytes.NewBufferString("")
   533  	remoteDataReceived := bytes.NewBufferString("")
   534  	remoteErrorToSend := bytes.NewBufferString("error")
   535  	remoteConnection := newFakeConnection()
   536  	remoteConnection.dataStream.readFunc = remoteDataToSend.Read
   537  	remoteConnection.dataStream.writeFunc = remoteDataReceived.Write
   538  	remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
   539  	pf.streamConn = remoteConnection
   540  
   541  	// Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan
   542  	pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
   543  
   544  	assert.Equal(t, "", remoteDataReceived.String())
   545  	assert.Equal(t, "", localConnection.receiveBuffer.String())
   546  }
   547  
   548  func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
   549  	readyChan := make(chan struct{})
   550  
   551  	pf, err := New(newCtx(), &fakeDialer{}, []string{":2222"}, readyChan)
   552  	if err != nil {
   553  		t.Fatalf("error while calling New: %s", err)
   554  	}
   555  
   556  	listener := newFakeListener()
   557  
   558  	pf.streamConn = newFakeConnection()
   559  	pf.streamConn.Close()
   560  
   561  	port := ForwardedPort{}
   562  	pf.waitForConnection(&listener, port)
   563  }