github.com/tilt-dev/tilt@v0.36.0/internal/k8s/portforward/portforward_test.go (about) 1 /* 2 Copyright 2015 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package portforward 18 19 import ( 20 "bytes" 21 "fmt" 22 "net" 23 "net/http" 24 "os" 25 "reflect" 26 "sort" 27 "strings" 28 "testing" 29 "time" 30 31 "github.com/stretchr/testify/assert" 32 33 v1 "k8s.io/api/core/v1" 34 "k8s.io/apimachinery/pkg/util/httpstream" 35 ) 36 37 type fakeDialer struct { 38 dialed bool 39 conn httpstream.Connection 40 err error 41 negotiatedProtocol string 42 } 43 44 func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, error) { 45 d.dialed = true 46 return d.conn, d.negotiatedProtocol, d.err 47 } 48 49 type fakeConnection struct { 50 closed bool 51 closeChan chan bool 52 dataStream *fakeStream 53 errorStream *fakeStream 54 streamCount int 55 } 56 57 func newFakeConnection() *fakeConnection { 58 return &fakeConnection{ 59 closeChan: make(chan bool), 60 dataStream: &fakeStream{}, 61 errorStream: &fakeStream{}, 62 } 63 } 64 65 func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) { 66 switch headers.Get(v1.StreamType) { 67 case v1.StreamTypeData: 68 c.streamCount++ 69 return c.dataStream, nil 70 case v1.StreamTypeError: 71 c.streamCount++ 72 return c.errorStream, nil 73 default: 74 return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType)) 75 } 76 } 77 78 func (c *fakeConnection) Close() error { 79 if !c.closed { 80 c.closed = true 81 close(c.closeChan) 82 } 83 return nil 84 } 85 86 func (c *fakeConnection) CloseChan() <-chan bool { 87 return c.closeChan 88 } 89 90 func (c *fakeConnection) RemoveStreams(streams ...httpstream.Stream) { 91 for range streams { 92 c.streamCount-- 93 } 94 } 95 96 func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) { 97 // no-op 98 } 99 100 type fakeListener struct { 101 net.Listener 102 closeChan chan bool 103 } 104 105 func newFakeListener() fakeListener { 106 return fakeListener{ 107 closeChan: make(chan bool), 108 } 109 } 110 111 func (l *fakeListener) Accept() (net.Conn, error) { 112 <-l.closeChan 113 return nil, fmt.Errorf("listener closed") 114 } 115 116 func (l *fakeListener) Close() error { 117 close(l.closeChan) 118 return nil 119 } 120 121 func (l *fakeListener) Addr() net.Addr { 122 return fakeAddr{} 123 } 124 125 type fakeAddr struct{} 126 127 func (fakeAddr) Network() string { return "fake" } 128 func (fakeAddr) String() string { return "fake" } 129 130 type fakeStream struct { 131 headers http.Header 132 readFunc func(p []byte) (int, error) 133 writeFunc func(p []byte) (int, error) 134 } 135 136 func (s *fakeStream) Read(p []byte) (n int, err error) { return s.readFunc(p) } 137 func (s *fakeStream) Write(p []byte) (n int, err error) { return s.writeFunc(p) } 138 func (*fakeStream) Close() error { return nil } 139 func (*fakeStream) Reset() error { return nil } 140 func (s *fakeStream) Headers() http.Header { return s.headers } 141 func (*fakeStream) Identifier() uint32 { return 0 } 142 143 type fakeConn struct { 144 sendBuffer *bytes.Buffer 145 receiveBuffer *bytes.Buffer 146 } 147 148 func (f fakeConn) Read(p []byte) (int, error) { return f.sendBuffer.Read(p) } 149 func (f fakeConn) Write(p []byte) (int, error) { return f.receiveBuffer.Write(p) } 150 func (fakeConn) Close() error { return nil } 151 func (fakeConn) LocalAddr() net.Addr { return nil } 152 func (fakeConn) RemoteAddr() net.Addr { return nil } 153 func (fakeConn) SetDeadline(t time.Time) error { return nil } 154 func (fakeConn) SetReadDeadline(t time.Time) error { return nil } 155 func (fakeConn) SetWriteDeadline(t time.Time) error { return nil } 156 157 func TestParsePortsAndNew(t *testing.T) { 158 tests := []struct { 159 input []string 160 addresses []string 161 expectedPorts []ForwardedPort 162 expectedAddresses []listenAddress 163 expectPortParseError bool 164 expectAddressParseError bool 165 expectNewError bool 166 }{ 167 {input: []string{}, expectNewError: true}, 168 {input: []string{"a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true}, 169 {input: []string{":a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true}, 170 {input: []string{"-1"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true}, 171 {input: []string{"65536"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true}, 172 {input: []string{"0"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true}, 173 {input: []string{"0:0"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true}, 174 {input: []string{"a:5000"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true}, 175 {input: []string{"5000:a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true}, 176 {input: []string{"5000:5000"}, addresses: []string{"127.0.0.257"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true}, 177 {input: []string{"5000:5000"}, addresses: []string{"::g"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true}, 178 {input: []string{"5000:5000"}, addresses: []string{"domain.invalid"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true}, 179 { 180 input: []string{"5000:5000"}, 181 addresses: []string{"localhost"}, 182 expectedPorts: []ForwardedPort{ 183 {5000, 5000}, 184 }, 185 expectedAddresses: []listenAddress{ 186 {protocol: "tcp4", address: "127.0.0.1", failureMode: "all"}, 187 {protocol: "tcp6", address: "::1", failureMode: "all"}, 188 }, 189 }, 190 { 191 input: []string{"5000:5000"}, 192 addresses: []string{"localhost", "127.0.0.1"}, 193 expectedPorts: []ForwardedPort{ 194 {5000, 5000}, 195 }, 196 expectedAddresses: []listenAddress{ 197 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"}, 198 {protocol: "tcp6", address: "::1", failureMode: "all"}, 199 }, 200 }, 201 { 202 input: []string{"5000:5000"}, 203 addresses: []string{"localhost", "::1"}, 204 expectedPorts: []ForwardedPort{ 205 {5000, 5000}, 206 }, 207 expectedAddresses: []listenAddress{ 208 {protocol: "tcp4", address: "127.0.0.1", failureMode: "all"}, 209 {protocol: "tcp6", address: "::1", failureMode: "any"}, 210 }, 211 }, 212 { 213 input: []string{"5000:5000"}, 214 addresses: []string{"localhost", "127.0.0.1", "::1"}, 215 expectedPorts: []ForwardedPort{ 216 {5000, 5000}, 217 }, 218 expectedAddresses: []listenAddress{ 219 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"}, 220 {protocol: "tcp6", address: "::1", failureMode: "any"}, 221 }, 222 }, 223 { 224 input: []string{"5000:5000"}, 225 addresses: []string{"localhost", "127.0.0.1", "10.10.10.1"}, 226 expectedPorts: []ForwardedPort{ 227 {5000, 5000}, 228 }, 229 expectedAddresses: []listenAddress{ 230 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"}, 231 {protocol: "tcp6", address: "::1", failureMode: "all"}, 232 {protocol: "tcp4", address: "10.10.10.1", failureMode: "any"}, 233 }, 234 }, 235 { 236 input: []string{"5000:5000"}, 237 addresses: []string{"127.0.0.1", "::1", "localhost"}, 238 expectedPorts: []ForwardedPort{ 239 {5000, 5000}, 240 }, 241 expectedAddresses: []listenAddress{ 242 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"}, 243 {protocol: "tcp6", address: "::1", failureMode: "any"}, 244 }, 245 }, 246 { 247 input: []string{"5000:5000"}, 248 addresses: []string{"10.0.0.1", "127.0.0.1"}, 249 expectedPorts: []ForwardedPort{ 250 {5000, 5000}, 251 }, 252 expectedAddresses: []listenAddress{ 253 {protocol: "tcp4", address: "10.0.0.1", failureMode: "any"}, 254 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"}, 255 }, 256 }, 257 { 258 input: []string{"5000", "5000:5000", "8888:5000", "5000:8888", ":5000", "0:5000"}, 259 addresses: []string{"127.0.0.1", "::1"}, 260 expectedPorts: []ForwardedPort{ 261 {5000, 5000}, 262 {5000, 5000}, 263 {8888, 5000}, 264 {5000, 8888}, 265 {0, 5000}, 266 {0, 5000}, 267 }, 268 expectedAddresses: []listenAddress{ 269 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"}, 270 {protocol: "tcp6", address: "::1", failureMode: "any"}, 271 }, 272 }, 273 } 274 275 for i, test := range tests { 276 parsedPorts, err := parsePorts(test.input) 277 haveError := err != nil 278 if e, a := test.expectPortParseError, haveError; e != a { 279 t.Fatalf("%d: parsePorts: error expected=%t, got %t: %s", i, e, a, err) 280 } 281 282 // default to localhost 283 if len(test.addresses) == 0 && len(test.expectedAddresses) == 0 { 284 test.addresses = []string{"localhost"} 285 test.expectedAddresses = []listenAddress{{protocol: "tcp4", address: "127.0.0.1"}, {protocol: "tcp6", address: "::1"}} 286 } 287 // assert address parser 288 parsedAddresses, err := parseAddresses(test.addresses) 289 haveError = err != nil 290 if e, a := test.expectAddressParseError, haveError; e != a { 291 t.Fatalf("%d: parseAddresses: error expected=%t, got %t: %s", i, e, a, err) 292 } 293 294 dialer := &fakeDialer{} 295 expectedStopChan := make(chan struct{}) 296 readyChan := make(chan struct{}) 297 298 var pf *PortForwarder 299 if len(test.addresses) > 0 { 300 pf, err = NewOnAddresses(dialer, test.addresses, test.input, expectedStopChan, readyChan, os.Stdout, os.Stderr) 301 } else { 302 pf, err = New(dialer, test.input, expectedStopChan, readyChan, os.Stdout, os.Stderr) 303 } 304 haveError = err != nil 305 if e, a := test.expectNewError, haveError; e != a { 306 t.Fatalf("%d: New: error expected=%t, got %t: %s", i, e, a, err) 307 } 308 309 if test.expectPortParseError || test.expectAddressParseError || test.expectNewError { 310 continue 311 } 312 313 sort.Slice(test.expectedAddresses, func(i, j int) bool { return test.expectedAddresses[i].address < test.expectedAddresses[j].address }) 314 sort.Slice(parsedAddresses, func(i, j int) bool { return parsedAddresses[i].address < parsedAddresses[j].address }) 315 316 if !reflect.DeepEqual(test.expectedAddresses, parsedAddresses) { 317 t.Fatalf("%d: expectedAddresses: %v, got: %v", i, test.expectedAddresses, parsedAddresses) 318 } 319 320 for pi, expectedPort := range test.expectedPorts { 321 if e, a := expectedPort.Local, parsedPorts[pi].Local; e != a { 322 t.Fatalf("%d: local expected: %d, got: %d", i, e, a) 323 } 324 if e, a := expectedPort.Remote, parsedPorts[pi].Remote; e != a { 325 t.Fatalf("%d: remote expected: %d, got: %d", i, e, a) 326 } 327 } 328 329 if dialer.dialed { 330 t.Fatalf("%d: expected not dialed", i) 331 } 332 if _, portErr := pf.GetPorts(); portErr == nil { 333 t.Fatalf("%d: GetPorts: error expected but got nil", i) 334 } 335 336 // mock-signal the Ready channel 337 close(readyChan) 338 339 if ports, portErr := pf.GetPorts(); portErr != nil { 340 t.Fatalf("%d: GetPorts: unable to retrieve ports: %s", i, portErr) 341 } else if !reflect.DeepEqual(test.expectedPorts, ports) { 342 t.Fatalf("%d: ports: expected %#v, got %#v", i, test.expectedPorts, ports) 343 } 344 if e, a := expectedStopChan, pf.stopChan; e != a { 345 t.Fatalf("%d: stopChan: expected %#v, got %#v", i, e, a) 346 } 347 if pf.Ready == nil { 348 t.Fatalf("%d: Ready should be non-nil", i) 349 } 350 } 351 } 352 353 type GetListenerTestCase struct { 354 Hostname string 355 Protocol string 356 ShouldRaiseError bool 357 ExpectedListenerAddress string 358 } 359 360 func TestGetListener(t *testing.T) { 361 var pf PortForwarder 362 testCases := []GetListenerTestCase{ 363 { 364 Hostname: "localhost", 365 Protocol: "tcp4", 366 ShouldRaiseError: false, 367 ExpectedListenerAddress: "127.0.0.1", 368 }, 369 { 370 Hostname: "127.0.0.1", 371 Protocol: "tcp4", 372 ShouldRaiseError: false, 373 ExpectedListenerAddress: "127.0.0.1", 374 }, 375 { 376 Hostname: "::1", 377 Protocol: "tcp6", 378 ShouldRaiseError: false, 379 ExpectedListenerAddress: "::1", 380 }, 381 { 382 Hostname: "::1", 383 Protocol: "tcp4", 384 ShouldRaiseError: true, 385 }, 386 { 387 Hostname: "127.0.0.1", 388 Protocol: "tcp6", 389 ShouldRaiseError: true, 390 }, 391 } 392 393 for i, testCase := range testCases { 394 forwardedPort := &ForwardedPort{Local: 0, Remote: 12345} 395 listener, err := pf.getListener(testCase.Protocol, testCase.Hostname, forwardedPort) 396 if err != nil && strings.Contains(err.Error(), "cannot assign requested address") { 397 t.Logf("Can't test #%d: %v", i, err) 398 continue 399 } 400 expectedListenerPort := fmt.Sprintf("%d", forwardedPort.Local) 401 errorRaised := err != nil 402 403 if testCase.ShouldRaiseError != errorRaised { 404 t.Errorf("Test case #%d failed: Data %v an error has been raised(%t) where it should not (or reciprocally): %v", i, testCase, testCase.ShouldRaiseError, err) 405 continue 406 } 407 if errorRaised { 408 continue 409 } 410 411 if listener == nil { 412 t.Errorf("Test case #%d did not raise an error but failed in initializing listener", i) 413 continue 414 } 415 416 host, port, _ := net.SplitHostPort(listener.Addr().String()) 417 t.Logf("Asked a %s forward for: %s:0, got listener %s:%s, expected: %s", testCase.Protocol, testCase.Hostname, host, port, expectedListenerPort) 418 if host != testCase.ExpectedListenerAddress { 419 t.Errorf("Test case #%d failed: Listener does not listen on expected address: asked '%v' got '%v'", i, testCase.ExpectedListenerAddress, host) 420 } 421 if port != expectedListenerPort { 422 t.Errorf("Test case #%d failed: Listener does not listen on expected port: asked %v got %v", i, expectedListenerPort, port) 423 424 } 425 listener.Close() 426 } 427 } 428 429 func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) { 430 dialer := &fakeDialer{ 431 conn: newFakeConnection(), 432 negotiatedProtocol: PortForwardProtocolV1Name, 433 } 434 435 stopChan := make(chan struct{}) 436 readyChan := make(chan struct{}) 437 errChan := make(chan error) 438 439 defer func() { 440 close(stopChan) 441 442 forwardErr := <-errChan 443 if forwardErr != nil { 444 t.Fatalf("ForwardPorts returned error: %s", forwardErr) 445 } 446 }() 447 448 pf, err := New(dialer, []string{":5000"}, stopChan, readyChan, os.Stdout, os.Stderr) 449 450 if err != nil { 451 t.Fatalf("error while calling New: %s", err) 452 } 453 454 go func() { 455 errChan <- pf.ForwardPorts() 456 close(errChan) 457 }() 458 459 <-pf.Ready 460 461 ports, err := pf.GetPorts() 462 if err != nil { 463 t.Fatalf("Failed to get ports. error: %v", err) 464 } 465 466 if len(ports) != 1 { 467 t.Fatalf("expected 1 port, got %d", len(ports)) 468 } 469 470 port := ports[0] 471 if port.Local == 0 { 472 t.Fatalf("local port is 0, expected != 0") 473 } 474 } 475 476 func TestHandleConnection(t *testing.T) { 477 out := bytes.NewBufferString("") 478 479 pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, nil) 480 if err != nil { 481 t.Fatalf("error while calling New: %s", err) 482 } 483 484 // Setup fake local connection 485 localConnection := &fakeConn{ 486 sendBuffer: bytes.NewBufferString("test data from local"), 487 receiveBuffer: bytes.NewBufferString(""), 488 } 489 490 // Setup fake remote connection to send data on the data stream after it receives data from the local connection 491 remoteDataToSend := bytes.NewBufferString("test data from remote") 492 remoteDataReceived := bytes.NewBufferString("") 493 remoteErrorToSend := bytes.NewBufferString("") 494 blockRemoteSend := make(chan struct{}) 495 remoteConnection := newFakeConnection() 496 remoteConnection.dataStream.readFunc = func(p []byte) (int, error) { 497 <-blockRemoteSend // Wait for the expected data to be received before responding 498 return remoteDataToSend.Read(p) 499 } 500 remoteConnection.dataStream.writeFunc = func(p []byte) (int, error) { 501 n, err := remoteDataReceived.Write(p) 502 if remoteDataReceived.String() == "test data from local" { 503 close(blockRemoteSend) 504 } 505 return n, err 506 } 507 remoteConnection.errorStream.readFunc = remoteErrorToSend.Read 508 pf.streamConn = remoteConnection 509 510 // Test handleConnection 511 pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) 512 assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero") 513 assert.Equal(t, "test data from local", remoteDataReceived.String()) 514 assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String()) 515 } 516 517 func TestHandleConnectionSendsRemoteError(t *testing.T) { 518 out := bytes.NewBufferString("") 519 errOut := bytes.NewBufferString("") 520 521 pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut) 522 if err != nil { 523 t.Fatalf("error while calling New: %s", err) 524 } 525 526 // Setup fake local connection 527 localConnection := &fakeConn{ 528 sendBuffer: bytes.NewBufferString(""), 529 receiveBuffer: bytes.NewBufferString(""), 530 } 531 532 // Setup fake remote connection to return an error message on the error stream 533 remoteDataToSend := bytes.NewBufferString("") 534 remoteDataReceived := bytes.NewBufferString("") 535 remoteErrorToSend := bytes.NewBufferString("error") 536 remoteConnection := newFakeConnection() 537 remoteConnection.dataStream.readFunc = remoteDataToSend.Read 538 remoteConnection.dataStream.writeFunc = remoteDataReceived.Write 539 remoteConnection.errorStream.readFunc = remoteErrorToSend.Read 540 pf.streamConn = remoteConnection 541 542 // Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan 543 pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) 544 545 assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero") 546 assert.Equal(t, "", remoteDataReceived.String()) 547 assert.Equal(t, "", localConnection.receiveBuffer.String()) 548 } 549 550 func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) { 551 out := bytes.NewBufferString("") 552 errOut := bytes.NewBufferString("") 553 554 pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut) 555 if err != nil { 556 t.Fatalf("error while calling New: %s", err) 557 } 558 559 listener := newFakeListener() 560 561 pf.streamConn = newFakeConnection() 562 pf.streamConn.Close() 563 564 port := ForwardedPort{} 565 pf.waitForConnection(&listener, port) 566 } 567 568 func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) { 569 dialer := &fakeDialer{ 570 conn: newFakeConnection(), 571 negotiatedProtocol: PortForwardProtocolV1Name, 572 } 573 574 stopChan := make(chan struct{}) 575 readyChan := make(chan struct{}) 576 errChan := make(chan error) 577 578 pf, err := New(dialer, []string{":5000"}, stopChan, readyChan, os.Stdout, os.Stderr) 579 if err != nil { 580 t.Fatalf("failed to create new PortForwarder: %s", err) 581 } 582 583 go func() { 584 errChan <- pf.ForwardPorts() 585 }() 586 587 <-pf.Ready 588 589 // Simulate lost pod connection by closing streamConn, which should result in pf.ForwardPorts() returning an error. 590 pf.streamConn.Close() 591 592 err = <-errChan 593 if err == nil { 594 t.Fatalf("unexpected non-error from pf.ForwardPorts()") 595 } else if err != ErrLostConnectionToPod { 596 t.Fatalf("unexpected error from pf.ForwardPorts(): %s", err) 597 } 598 } 599 600 func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) { 601 dialer := &fakeDialer{ 602 conn: newFakeConnection(), 603 negotiatedProtocol: PortForwardProtocolV1Name, 604 } 605 606 stopChan := make(chan struct{}) 607 readyChan := make(chan struct{}) 608 errChan := make(chan error) 609 610 pf, err := New(dialer, []string{":5000"}, stopChan, readyChan, os.Stdout, os.Stderr) 611 if err != nil { 612 t.Fatalf("failed to create new PortForwarder: %s", err) 613 } 614 615 go func() { 616 errChan <- pf.ForwardPorts() 617 }() 618 619 <-pf.Ready 620 621 // Closing (or sending to) stopChan indicates a stop request by the caller, which should result in pf.ForwardPorts() 622 // returning nil. 623 close(stopChan) 624 625 err = <-errChan 626 if err != nil { 627 t.Fatalf("unexpected error from pf.ForwardPorts(): %s", err) 628 } 629 }