github.com/tilt-dev/tilt@v0.36.0/internal/k8s/portforward/portforward.go (about)

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package portforward
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"net"
    24  	"net/http"
    25  	"sort"
    26  	"strconv"
    27  	"strings"
    28  	"sync"
    29  
    30  	v1 "k8s.io/api/core/v1"
    31  	"k8s.io/apimachinery/pkg/util/httpstream"
    32  	netutils "k8s.io/utils/net"
    33  )
    34  
    35  // PortForwardProtocolV1Name is the subprotocol used for port forwarding.
    36  // TODO move to API machinery and re-unify with kubelet/server/portfoward
    37  const PortForwardProtocolV1Name = "portforward.k8s.io"
    38  
    39  var (
    40  	// error returned whenever we lost connection to a pod
    41  	ErrLostConnectionToPod = errors.New("lost connection to pod")
    42  
    43  	// set of error we're expecting during port-forwarding
    44  	networkClosedError = "use of closed network connection"
    45  )
    46  
    47  // PortForwarder knows how to listen for local connections and forward them to
    48  // a remote pod via an upgraded HTTP request.
    49  type PortForwarder struct {
    50  	addresses []listenAddress
    51  	ports     []ForwardedPort
    52  	stopChan  <-chan struct{}
    53  
    54  	dialer        httpstream.Dialer
    55  	streamConn    httpstream.Connection
    56  	errorHandler  *errorHandler
    57  	listeners     []io.Closer
    58  	Ready         chan struct{}
    59  	requestIDLock sync.Mutex
    60  	requestID     int
    61  	out           io.Writer
    62  	errOut        io.Writer
    63  }
    64  
    65  // ForwardedPort contains a Local:Remote port pairing.
    66  type ForwardedPort struct {
    67  	Local  uint16
    68  	Remote uint16
    69  }
    70  
    71  /*
    72  valid port specifications:
    73  
    74  5000
    75  - forwards from localhost:5000 to pod:5000
    76  
    77  8888:5000
    78  - forwards from localhost:8888 to pod:5000
    79  
    80  0:5000
    81  :5000
    82    - selects a random available local port,
    83      forwards from localhost:<random port> to pod:5000
    84  */
    85  func parsePorts(ports []string) ([]ForwardedPort, error) {
    86  	var forwards []ForwardedPort
    87  	for _, portString := range ports {
    88  		parts := strings.Split(portString, ":")
    89  		var localString, remoteString string
    90  		if len(parts) == 1 {
    91  			localString = parts[0]
    92  			remoteString = parts[0]
    93  		} else if len(parts) == 2 {
    94  			localString = parts[0]
    95  			if localString == "" {
    96  				// support :5000
    97  				localString = "0"
    98  			}
    99  			remoteString = parts[1]
   100  		} else {
   101  			return nil, fmt.Errorf("invalid port format '%s'", portString)
   102  		}
   103  
   104  		localPort, err := strconv.ParseUint(localString, 10, 16)
   105  		if err != nil {
   106  			return nil, fmt.Errorf("error parsing local port '%s': %s", localString, err)
   107  		}
   108  
   109  		remotePort, err := strconv.ParseUint(remoteString, 10, 16)
   110  		if err != nil {
   111  			return nil, fmt.Errorf("error parsing remote port '%s': %s", remoteString, err)
   112  		}
   113  		if remotePort == 0 {
   114  			return nil, fmt.Errorf("remote port must be > 0")
   115  		}
   116  
   117  		forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
   118  	}
   119  
   120  	return forwards, nil
   121  }
   122  
   123  type listenAddress struct {
   124  	address     string
   125  	protocol    string
   126  	failureMode string
   127  }
   128  
   129  func parseAddresses(addressesToParse []string) ([]listenAddress, error) {
   130  	var addresses []listenAddress
   131  	parsed := make(map[string]listenAddress)
   132  	for _, address := range addressesToParse {
   133  		if address == "localhost" {
   134  			if _, exists := parsed["127.0.0.1"]; !exists {
   135  				ip := listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"}
   136  				parsed[ip.address] = ip
   137  			}
   138  			if _, exists := parsed["::1"]; !exists {
   139  				ip := listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"}
   140  				parsed[ip.address] = ip
   141  			}
   142  		} else if netutils.ParseIPSloppy(address).To4() != nil {
   143  			parsed[address] = listenAddress{address: address, protocol: "tcp4", failureMode: "any"}
   144  		} else if netutils.ParseIPSloppy(address) != nil {
   145  			parsed[address] = listenAddress{address: address, protocol: "tcp6", failureMode: "any"}
   146  		} else {
   147  			return nil, fmt.Errorf("%s is not a valid IP", address)
   148  		}
   149  	}
   150  	addresses = make([]listenAddress, len(parsed))
   151  	id := 0
   152  	for _, v := range parsed {
   153  		addresses[id] = v
   154  		id++
   155  	}
   156  	// Sort addresses before returning to get a stable order
   157  	sort.Slice(addresses, func(i, j int) bool { return addresses[i].address < addresses[j].address })
   158  
   159  	return addresses, nil
   160  }
   161  
   162  // New creates a new PortForwarder with localhost listen addresses.
   163  func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
   164  	return NewOnAddresses(dialer, []string{"localhost"}, ports, stopChan, readyChan, out, errOut)
   165  }
   166  
   167  // NewOnAddresses creates a new PortForwarder with custom listen addresses.
   168  func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
   169  	if len(addresses) == 0 {
   170  		return nil, errors.New("you must specify at least 1 address")
   171  	}
   172  	parsedAddresses, err := parseAddresses(addresses)
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	if len(ports) == 0 {
   177  		return nil, errors.New("you must specify at least 1 port")
   178  	}
   179  	parsedPorts, err := parsePorts(ports)
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  	return &PortForwarder{
   184  		dialer:    dialer,
   185  		addresses: parsedAddresses,
   186  		ports:     parsedPorts,
   187  		stopChan:  stopChan,
   188  		Ready:     readyChan,
   189  		out:       out,
   190  		errOut:    errOut,
   191  	}, nil
   192  }
   193  
   194  func (pf *PortForwarder) Addresses() []string {
   195  	var addresses []string
   196  	for _, la := range pf.addresses {
   197  		addresses = append(addresses, la.address)
   198  	}
   199  	return addresses
   200  }
   201  
   202  // ForwardPorts formats and executes a port forwarding request. The connection will remain
   203  // open until stopChan is closed.
   204  func (pf *PortForwarder) ForwardPorts() error {
   205  	defer pf.Close()
   206  
   207  	var err error
   208  	var protocol string
   209  	pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name)
   210  	if err != nil {
   211  		return fmt.Errorf("error upgrading connection: %s", err)
   212  	}
   213  	defer pf.streamConn.Close()
   214  	if protocol != PortForwardProtocolV1Name {
   215  		return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol)
   216  	}
   217  
   218  	return pf.forward()
   219  }
   220  
   221  // forward dials the remote host specific in req, upgrades the request, starts
   222  // listeners for each port specified in ports, and forwards local connections
   223  // to the remote host via streams.
   224  //
   225  // Returns an error if any of the local ports aren't available.
   226  func (pf *PortForwarder) forward() error {
   227  	var err error
   228  	pf.errorHandler = newErrorHandler()
   229  	defer pf.errorHandler.Close()
   230  
   231  	for i := range pf.ports {
   232  		port := &pf.ports[i]
   233  		err = pf.listenOnPort(port)
   234  		if err != nil {
   235  			return fmt.Errorf("Unable to listen on port %d: %v", port.Local, err)
   236  		}
   237  	}
   238  
   239  	if pf.Ready != nil {
   240  		close(pf.Ready)
   241  	}
   242  
   243  	// wait for interrupt or conn closure
   244  	select {
   245  	case err := <-pf.errorHandler.Done():
   246  		return err
   247  	case <-pf.stopChan:
   248  	case <-pf.streamConn.CloseChan():
   249  		return ErrLostConnectionToPod
   250  	}
   251  
   252  	return nil
   253  }
   254  
   255  // listenOnPort delegates listener creation and waits for connections on requested bind addresses.
   256  // An error is raised based on address groups (default and localhost) and their failure modes
   257  func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
   258  	var errors []error
   259  	failCounters := make(map[string]int, 2)
   260  	successCounters := make(map[string]int, 2)
   261  	for _, addr := range pf.addresses {
   262  		err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address)
   263  		if err != nil {
   264  			errors = append(errors, err)
   265  			failCounters[addr.failureMode]++
   266  		} else {
   267  			successCounters[addr.failureMode]++
   268  		}
   269  	}
   270  	if successCounters["all"] == 0 && failCounters["all"] > 0 {
   271  		return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
   272  	}
   273  	if failCounters["any"] > 0 {
   274  		return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
   275  	}
   276  	return nil
   277  }
   278  
   279  // listenOnPortAndAddress delegates listener creation and waits for new connections
   280  // in the background f
   281  func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error {
   282  	listener, err := pf.getListener(protocol, address, port)
   283  	if err != nil {
   284  		return err
   285  	}
   286  	pf.listeners = append(pf.listeners, listener)
   287  	go pf.waitForConnection(listener, *port)
   288  	return nil
   289  }
   290  
   291  // getListener creates a listener on the interface targeted by the given hostname on the given port with
   292  // the given protocol. protocol is in net.Listen style which basically admits values like tcp, tcp4, tcp6
   293  func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
   294  	listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))))
   295  	if err != nil {
   296  		return nil, fmt.Errorf("unable to create listener: Error %s", err)
   297  	}
   298  	listenerAddress := listener.Addr().String()
   299  	host, localPort, _ := net.SplitHostPort(listenerAddress)
   300  	localPortUInt, err := strconv.ParseUint(localPort, 10, 16)
   301  
   302  	if err != nil {
   303  		return nil, fmt.Errorf("error parsing local port: %s from %s (%s)", err, listenerAddress, host)
   304  	}
   305  	port.Local = uint16(localPortUInt)
   306  
   307  	return listener, nil
   308  }
   309  
   310  // waitForConnection waits for new connections to listener and handles them in
   311  // the background.
   312  func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
   313  	for {
   314  		select {
   315  		case <-pf.streamConn.CloseChan():
   316  			return
   317  		default:
   318  			conn, err := listener.Accept()
   319  			if err != nil {
   320  				// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
   321  				if !strings.Contains(strings.ToLower(err.Error()), networkClosedError) {
   322  					_, _ = fmt.Fprintf(pf.out, "error accepting connection on port %d: %v", port.Local, err)
   323  				}
   324  				return
   325  			}
   326  			go pf.handleConnection(conn, port)
   327  		}
   328  	}
   329  }
   330  
   331  func (pf *PortForwarder) nextRequestID() int {
   332  	pf.requestIDLock.Lock()
   333  	defer pf.requestIDLock.Unlock()
   334  	id := pf.requestID
   335  	pf.requestID++
   336  	return id
   337  }
   338  
   339  // handleConnection copies data between the local connection and the stream to
   340  // the remote server.
   341  func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
   342  	defer conn.Close()
   343  
   344  	requestID := pf.nextRequestID()
   345  
   346  	// create error stream
   347  	headers := http.Header{}
   348  	headers.Set(v1.StreamType, v1.StreamTypeError)
   349  	headers.Set(v1.PortHeader, fmt.Sprintf("%d", port.Remote))
   350  	headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID))
   351  	errorStream, err := pf.streamConn.CreateStream(headers)
   352  	if err != nil {
   353  		// If CreateStream fails, stop the whole portforwarder, because this might
   354  		// mean the whole streamConn is wedged. The PortForward reconciler will backoff
   355  		// and re-create the connection.
   356  		pf.errorHandler.Stop(fmt.Errorf("creating stream: %v", err))
   357  		return
   358  	}
   359  	// we're not writing to this stream
   360  	errorStream.Close()
   361  	defer pf.streamConn.RemoveStreams(errorStream)
   362  
   363  	errorChan := make(chan error)
   364  	go func() {
   365  		message, err := io.ReadAll(errorStream)
   366  		switch {
   367  		case err != nil:
   368  			errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
   369  		case len(message) > 0:
   370  			errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
   371  		}
   372  		close(errorChan)
   373  	}()
   374  
   375  	// create data stream
   376  	headers.Set(v1.StreamType, v1.StreamTypeData)
   377  	dataStream, err := pf.streamConn.CreateStream(headers)
   378  	if err != nil {
   379  		// If CreateStream fails, stop the whole portforwarder, because this might
   380  		// mean the whole streamConn is wedged. The PortForward reconciler will backoff
   381  		// and re-create the connection.
   382  		pf.errorHandler.Stop(fmt.Errorf("creating stream: %v", err))
   383  		return
   384  	}
   385  	defer pf.streamConn.RemoveStreams(dataStream)
   386  
   387  	localError := make(chan struct{})
   388  	remoteDone := make(chan struct{})
   389  
   390  	go func() {
   391  		// Copy from the remote side to the local port.
   392  		if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(strings.ToLower(err.Error()), networkClosedError) {
   393  			_, _ = fmt.Fprintf(pf.out, "error copying from remote stream to local connection: %v", err)
   394  		}
   395  
   396  		// inform the select below that the remote copy is done
   397  		close(remoteDone)
   398  	}()
   399  
   400  	go func() {
   401  		// inform server we're not sending any more data after copy unblocks
   402  		defer dataStream.Close()
   403  
   404  		// Copy from the local port to the remote side.
   405  		if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(strings.ToLower(err.Error()), networkClosedError) {
   406  			_, _ = fmt.Fprintf(pf.out, "error copying from local connection to remote stream: %v", err)
   407  			// break out of the select below without waiting for the other copy to finish
   408  			close(localError)
   409  		}
   410  	}()
   411  
   412  	// wait for either a local->remote error or for copying from remote->local to finish
   413  	select {
   414  	case <-remoteDone:
   415  	case <-localError:
   416  	}
   417  
   418  	// reset dataStream to discard any unsent data, preventing port forwarding from being blocked.
   419  	// we must reset dataStream before waiting on errorChan, otherwise,
   420  	// the blocking data will affect errorStream and cause <-errorChan to block indefinitely.
   421  	_ = dataStream.Reset()
   422  
   423  	// always expect something on errorChan (it may be nil)
   424  	err = <-errorChan
   425  	if err != nil {
   426  		_, _ = fmt.Fprintf(pf.out, "%v", err)
   427  		pf.streamConn.Close()
   428  	}
   429  }
   430  
   431  // Close stops all listeners of PortForwarder.
   432  func (pf *PortForwarder) Close() {
   433  	// stop all listeners
   434  	for _, l := range pf.listeners {
   435  		if err := l.Close(); err != nil {
   436  			_, _ = fmt.Fprintf(pf.out, "error closing listener: %v", err)
   437  		}
   438  	}
   439  }
   440  
   441  // GetPorts will return the ports that were forwarded; this can be used to
   442  // retrieve the locally-bound port in cases where the input was port 0. This
   443  // function will signal an error if the Ready channel is nil or if the
   444  // listeners are not ready yet; this function will succeed after the Ready
   445  // channel has been closed.
   446  func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) {
   447  	if pf.Ready == nil {
   448  		return nil, fmt.Errorf("no Ready channel provided")
   449  	}
   450  	select {
   451  	case <-pf.Ready:
   452  		return pf.ports, nil
   453  	default:
   454  		return nil, fmt.Errorf("listeners not ready")
   455  	}
   456  }