github.com/tilt-dev/tilt@v0.33.15-0.20240515162809-0a22ed45d8a0/internal/k8s/portforward/portforward.go (about)

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package portforward
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"net"
    25  	"net/http"
    26  	"sort"
    27  	"strconv"
    28  	"strings"
    29  	"sync"
    30  
    31  	v1 "k8s.io/api/core/v1"
    32  	"k8s.io/apimachinery/pkg/util/httpstream"
    33  	netutils "k8s.io/utils/net"
    34  
    35  	"github.com/tilt-dev/tilt/pkg/logger"
    36  )
    37  
    38  // PortForwardProtocolV1Name is the subprotocol used for port forwarding.
    39  // TODO move to API machinery and re-unify with kubelet/server/portfoward
    40  const PortForwardProtocolV1Name = "portforward.k8s.io"
    41  
    42  // PortForwarder knows how to listen for local connections and forward them to
    43  // a remote pod via an upgraded HTTP request.
    44  type PortForwarder struct {
    45  	ctx       context.Context
    46  	addresses []listenAddress
    47  	ports     []ForwardedPort
    48  
    49  	dialer        httpstream.Dialer
    50  	streamConn    httpstream.Connection
    51  	errorHandler  *errorHandler
    52  	listeners     []io.Closer
    53  	Ready         chan struct{}
    54  	requestIDLock sync.Mutex
    55  	requestID     int
    56  }
    57  
    58  // ForwardedPort contains a Local:Remote port pairing.
    59  type ForwardedPort struct {
    60  	Local  uint16
    61  	Remote uint16
    62  }
    63  
    64  /*
    65  valid port specifications:
    66  
    67  5000
    68  - forwards from localhost:5000 to pod:5000
    69  
    70  8888:5000
    71  - forwards from localhost:8888 to pod:5000
    72  
    73  0:5000
    74  :5000
    75    - selects a random available local port,
    76      forwards from localhost:<random port> to pod:5000
    77  */
    78  func parsePorts(ports []string) ([]ForwardedPort, error) {
    79  	var forwards []ForwardedPort
    80  	for _, portString := range ports {
    81  		parts := strings.Split(portString, ":")
    82  		var localString, remoteString string
    83  		if len(parts) == 1 {
    84  			localString = parts[0]
    85  			remoteString = parts[0]
    86  		} else if len(parts) == 2 {
    87  			localString = parts[0]
    88  			if localString == "" {
    89  				// support :5000
    90  				localString = "0"
    91  			}
    92  			remoteString = parts[1]
    93  		} else {
    94  			return nil, fmt.Errorf("invalid port format '%s'", portString)
    95  		}
    96  
    97  		localPort, err := strconv.ParseUint(localString, 10, 16)
    98  		if err != nil {
    99  			return nil, fmt.Errorf("error parsing local port '%s': %s", localString, err)
   100  		}
   101  
   102  		remotePort, err := strconv.ParseUint(remoteString, 10, 16)
   103  		if err != nil {
   104  			return nil, fmt.Errorf("error parsing remote port '%s': %s", remoteString, err)
   105  		}
   106  		if remotePort == 0 {
   107  			return nil, fmt.Errorf("remote port must be > 0")
   108  		}
   109  
   110  		forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
   111  	}
   112  
   113  	return forwards, nil
   114  }
   115  
   116  type listenAddress struct {
   117  	address     string
   118  	protocol    string
   119  	failureMode string
   120  }
   121  
   122  func parseAddresses(addressesToParse []string) ([]listenAddress, error) {
   123  	var addresses []listenAddress
   124  	parsed := make(map[string]listenAddress)
   125  	for _, address := range addressesToParse {
   126  		if address == "localhost" {
   127  			if _, exists := parsed["127.0.0.1"]; !exists {
   128  				ip := listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"}
   129  				parsed[ip.address] = ip
   130  			}
   131  			if _, exists := parsed["::1"]; !exists {
   132  				ip := listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"}
   133  				parsed[ip.address] = ip
   134  			}
   135  		} else if netutils.ParseIPSloppy(address).To4() != nil {
   136  			parsed[address] = listenAddress{address: address, protocol: "tcp4", failureMode: "any"}
   137  		} else if netutils.ParseIPSloppy(address) != nil {
   138  			parsed[address] = listenAddress{address: address, protocol: "tcp6", failureMode: "any"}
   139  		} else {
   140  			return nil, fmt.Errorf("%s is not a valid IP", address)
   141  		}
   142  	}
   143  	addresses = make([]listenAddress, len(parsed))
   144  	id := 0
   145  	for _, v := range parsed {
   146  		addresses[id] = v
   147  		id++
   148  	}
   149  	// Sort addresses before returning to get a stable order
   150  	sort.Slice(addresses, func(i, j int) bool { return addresses[i].address < addresses[j].address })
   151  
   152  	return addresses, nil
   153  }
   154  
   155  // New creates a new PortForwarder with localhost listen addresses.
   156  func New(ctx context.Context, dialer httpstream.Dialer, ports []string, readyChan chan struct{}) (*PortForwarder, error) {
   157  	return NewOnAddresses(ctx, dialer, []string{"localhost"}, ports, readyChan)
   158  }
   159  
   160  // NewOnAddresses creates a new PortForwarder with custom listen addresses.
   161  func NewOnAddresses(ctx context.Context, dialer httpstream.Dialer, addresses []string, ports []string, readyChan chan struct{}) (*PortForwarder, error) {
   162  	if len(addresses) == 0 {
   163  		return nil, errors.New("you must specify at least 1 address")
   164  	}
   165  	parsedAddresses, err := parseAddresses(addresses)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  	if len(ports) == 0 {
   170  		return nil, errors.New("you must specify at least 1 port")
   171  	}
   172  	parsedPorts, err := parsePorts(ports)
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	return &PortForwarder{
   177  		ctx:       ctx,
   178  		dialer:    dialer,
   179  		addresses: parsedAddresses,
   180  		ports:     parsedPorts,
   181  		Ready:     readyChan,
   182  	}, nil
   183  }
   184  
   185  func (pf *PortForwarder) Addresses() []string {
   186  	var addresses []string
   187  	for _, la := range pf.addresses {
   188  		addresses = append(addresses, la.address)
   189  	}
   190  	return addresses
   191  }
   192  
   193  // ForwardPorts formats and executes a port forwarding request. The connection will remain
   194  // open until stopChan is closed.
   195  func (pf *PortForwarder) ForwardPorts() error {
   196  	defer pf.Close()
   197  
   198  	var err error
   199  	pf.streamConn, _, err = pf.dialer.Dial(PortForwardProtocolV1Name)
   200  	if err != nil {
   201  		return fmt.Errorf("error upgrading connection: %s", err)
   202  	}
   203  	defer pf.streamConn.Close()
   204  
   205  	return pf.forward()
   206  }
   207  
   208  // forward dials the remote host specific in req, upgrades the request, starts
   209  // listeners for each port specified in ports, and forwards local connections
   210  // to the remote host via streams.
   211  //
   212  // Returns an error if any of the local ports aren't available.
   213  func (pf *PortForwarder) forward() error {
   214  	var err error
   215  	pf.errorHandler = newErrorHandler()
   216  	defer pf.errorHandler.Close()
   217  
   218  	for i := range pf.ports {
   219  		port := &pf.ports[i]
   220  		err = pf.listenOnPort(port)
   221  		if err != nil {
   222  			return fmt.Errorf("Unable to listen on port %d: %v", port.Local, err)
   223  		}
   224  	}
   225  
   226  	if pf.Ready != nil {
   227  		close(pf.Ready)
   228  	}
   229  
   230  	// wait for interrupt or conn closure
   231  	select {
   232  	case err := <-pf.errorHandler.Done():
   233  		return err
   234  
   235  	case <-pf.ctx.Done():
   236  	case <-pf.streamConn.CloseChan():
   237  		logger.Get(pf.ctx).Debugf("lost connection to pod")
   238  	}
   239  
   240  	return nil
   241  }
   242  
   243  // listenOnPort delegates listener creation and waits for connections on requested bind addresses.
   244  // An error is raised based on address groups (default and localhost) and their failure modes
   245  func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
   246  	var errors []error
   247  	failCounters := make(map[string]int, 2)
   248  	successCounters := make(map[string]int, 2)
   249  	for _, addr := range pf.addresses {
   250  		err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address)
   251  		if err != nil {
   252  			errors = append(errors, err)
   253  			failCounters[addr.failureMode]++
   254  		} else {
   255  			successCounters[addr.failureMode]++
   256  		}
   257  	}
   258  	if successCounters["all"] == 0 && failCounters["all"] > 0 {
   259  		return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
   260  	}
   261  	if failCounters["any"] > 0 {
   262  		return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
   263  	}
   264  	return nil
   265  }
   266  
   267  // listenOnPortAndAddress delegates listener creation and waits for new connections
   268  // in the background f
   269  func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error {
   270  	listener, err := pf.getListener(protocol, address, port)
   271  	if err != nil {
   272  		return err
   273  	}
   274  	pf.listeners = append(pf.listeners, listener)
   275  	go pf.waitForConnection(listener, *port)
   276  	return nil
   277  }
   278  
   279  // getListener creates a listener on the interface targeted by the given hostname on the given port with
   280  // the given protocol. protocol is in net.Listen style which basically admits values like tcp, tcp4, tcp6
   281  func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
   282  	listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))))
   283  	if err != nil {
   284  		return nil, fmt.Errorf("unable to create listener: Error %s", err)
   285  	}
   286  	listenerAddress := listener.Addr().String()
   287  	host, localPort, _ := net.SplitHostPort(listenerAddress)
   288  	localPortUInt, err := strconv.ParseUint(localPort, 10, 16)
   289  
   290  	if err != nil {
   291  		return nil, fmt.Errorf("error parsing local port: %s from %s (%s)", err, listenerAddress, host)
   292  	}
   293  	port.Local = uint16(localPortUInt)
   294  
   295  	return listener, nil
   296  }
   297  
   298  // waitForConnection waits for new connections to listener and handles them in
   299  // the background.
   300  func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
   301  	for {
   302  		select {
   303  		case <-pf.streamConn.CloseChan():
   304  			return
   305  		default:
   306  			conn, err := listener.Accept()
   307  			if err != nil {
   308  				// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
   309  				if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
   310  					logger.Get(pf.ctx).Debugf("error accepting connection on port %d: %v", port.Local, err)
   311  				}
   312  				return
   313  			}
   314  			go pf.handleConnection(conn, port)
   315  		}
   316  	}
   317  }
   318  
   319  func (pf *PortForwarder) nextRequestID() int {
   320  	pf.requestIDLock.Lock()
   321  	defer pf.requestIDLock.Unlock()
   322  	id := pf.requestID
   323  	pf.requestID++
   324  	return id
   325  }
   326  
   327  // handleConnection copies data between the local connection and the stream to
   328  // the remote server.
   329  func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
   330  	defer conn.Close()
   331  
   332  	requestID := pf.nextRequestID()
   333  
   334  	// create error stream
   335  	headers := http.Header{}
   336  	headers.Set(v1.StreamType, v1.StreamTypeError)
   337  	headers.Set(v1.PortHeader, fmt.Sprintf("%d", port.Remote))
   338  	headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID))
   339  	errorStream, err := pf.streamConn.CreateStream(headers)
   340  	if err != nil {
   341  		// If CreateStream fails, stop the whole portforwarder, because this might
   342  		// mean the whole streamConn is wedged. The PortForward reconciler will backoff
   343  		// and re-create the connection.
   344  		pf.errorHandler.Stop(fmt.Errorf("creating stream: %v", err))
   345  		return
   346  	}
   347  	// we're not writing to this stream
   348  	errorStream.Close()
   349  
   350  	errorChan := make(chan error)
   351  	go func() {
   352  		message, err := io.ReadAll(errorStream)
   353  		switch {
   354  		case err != nil:
   355  			errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
   356  		case len(message) > 0:
   357  			errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
   358  		}
   359  		close(errorChan)
   360  	}()
   361  
   362  	// create data stream
   363  	headers.Set(v1.StreamType, v1.StreamTypeData)
   364  	dataStream, err := pf.streamConn.CreateStream(headers)
   365  	if err != nil {
   366  		// If CreateStream fails, stop the whole portforwarder, because this might
   367  		// mean the whole streamConn is wedged. The PortForward reconciler will backoff
   368  		// and re-create the connection.
   369  		pf.errorHandler.Stop(fmt.Errorf("creating stream: %v", err))
   370  		return
   371  	}
   372  
   373  	localError := make(chan struct{})
   374  	remoteDone := make(chan struct{})
   375  
   376  	go func() {
   377  		// Copy from the remote side to the local port.
   378  		if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
   379  			logger.Get(pf.ctx).Debugf("error copying from remote stream to local connection: %v", err)
   380  		}
   381  
   382  		// inform the select below that the remote copy is done
   383  		close(remoteDone)
   384  	}()
   385  
   386  	go func() {
   387  		// inform server we're not sending any more data after copy unblocks
   388  		defer dataStream.Close()
   389  
   390  		// Copy from the local port to the remote side.
   391  		if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
   392  			logger.Get(pf.ctx).Debugf("error copying from local connection to remote stream: %v", err)
   393  			// break out of the select below without waiting for the other copy to finish
   394  			close(localError)
   395  		}
   396  	}()
   397  
   398  	// wait for either a local->remote error or for copying from remote->local to finish
   399  	select {
   400  	case <-remoteDone:
   401  	case <-localError:
   402  	}
   403  
   404  	// always expect something on errorChan (it may be nil)
   405  	err = <-errorChan
   406  	if err != nil {
   407  		logger.Get(pf.ctx).Debugf("%v", err)
   408  		pf.streamConn.Close()
   409  	}
   410  }
   411  
   412  // Close stops all listeners of PortForwarder.
   413  func (pf *PortForwarder) Close() {
   414  	// stop all listeners
   415  	for _, l := range pf.listeners {
   416  		if err := l.Close(); err != nil {
   417  			logger.Get(pf.ctx).Debugf("error closing listener: %v", err)
   418  		}
   419  	}
   420  }
   421  
   422  // GetPorts will return the ports that were forwarded; this can be used to
   423  // retrieve the locally-bound port in cases where the input was port 0. This
   424  // function will signal an error if the Ready channel is nil or if the
   425  // listeners are not ready yet; this function will succeed after the Ready
   426  // channel has been closed.
   427  func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) {
   428  	if pf.Ready == nil {
   429  		return nil, fmt.Errorf("no Ready channel provided")
   430  	}
   431  	select {
   432  	case <-pf.Ready:
   433  		return pf.ports, nil
   434  	default:
   435  		return nil, fmt.Errorf("listeners not ready")
   436  	}
   437  }