k8s.io/client-go@v0.22.2/tools/portforward/portforward.go (about)

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package portforward
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"io/ioutil"
    24  	"net"
    25  	"net/http"
    26  	"sort"
    27  	"strconv"
    28  	"strings"
    29  	"sync"
    30  
    31  	"k8s.io/api/core/v1"
    32  	"k8s.io/apimachinery/pkg/util/httpstream"
    33  	"k8s.io/apimachinery/pkg/util/runtime"
    34  )
    35  
    36  // PortForwardProtocolV1Name is the subprotocol used for port forwarding.
    37  // TODO move to API machinery and re-unify with kubelet/server/portfoward
    38  const PortForwardProtocolV1Name = "portforward.k8s.io"
    39  
    40  // PortForwarder knows how to listen for local connections and forward them to
    41  // a remote pod via an upgraded HTTP request.
    42  type PortForwarder struct {
    43  	addresses []listenAddress
    44  	ports     []ForwardedPort
    45  	stopChan  <-chan struct{}
    46  
    47  	dialer        httpstream.Dialer
    48  	streamConn    httpstream.Connection
    49  	listeners     []io.Closer
    50  	Ready         chan struct{}
    51  	requestIDLock sync.Mutex
    52  	requestID     int
    53  	out           io.Writer
    54  	errOut        io.Writer
    55  }
    56  
    57  // ForwardedPort contains a Local:Remote port pairing.
    58  type ForwardedPort struct {
    59  	Local  uint16
    60  	Remote uint16
    61  }
    62  
    63  /*
    64  	valid port specifications:
    65  
    66  	5000
    67  	- forwards from localhost:5000 to pod:5000
    68  
    69  	8888:5000
    70  	- forwards from localhost:8888 to pod:5000
    71  
    72  	0:5000
    73  	:5000
    74  	- selects a random available local port,
    75  	  forwards from localhost:<random port> to pod:5000
    76  */
    77  func parsePorts(ports []string) ([]ForwardedPort, error) {
    78  	var forwards []ForwardedPort
    79  	for _, portString := range ports {
    80  		parts := strings.Split(portString, ":")
    81  		var localString, remoteString string
    82  		if len(parts) == 1 {
    83  			localString = parts[0]
    84  			remoteString = parts[0]
    85  		} else if len(parts) == 2 {
    86  			localString = parts[0]
    87  			if localString == "" {
    88  				// support :5000
    89  				localString = "0"
    90  			}
    91  			remoteString = parts[1]
    92  		} else {
    93  			return nil, fmt.Errorf("invalid port format '%s'", portString)
    94  		}
    95  
    96  		localPort, err := strconv.ParseUint(localString, 10, 16)
    97  		if err != nil {
    98  			return nil, fmt.Errorf("error parsing local port '%s': %s", localString, err)
    99  		}
   100  
   101  		remotePort, err := strconv.ParseUint(remoteString, 10, 16)
   102  		if err != nil {
   103  			return nil, fmt.Errorf("error parsing remote port '%s': %s", remoteString, err)
   104  		}
   105  		if remotePort == 0 {
   106  			return nil, fmt.Errorf("remote port must be > 0")
   107  		}
   108  
   109  		forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
   110  	}
   111  
   112  	return forwards, nil
   113  }
   114  
   115  type listenAddress struct {
   116  	address     string
   117  	protocol    string
   118  	failureMode string
   119  }
   120  
   121  func parseAddresses(addressesToParse []string) ([]listenAddress, error) {
   122  	var addresses []listenAddress
   123  	parsed := make(map[string]listenAddress)
   124  	for _, address := range addressesToParse {
   125  		if address == "localhost" {
   126  			if _, exists := parsed["127.0.0.1"]; !exists {
   127  				ip := listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"}
   128  				parsed[ip.address] = ip
   129  			}
   130  			if _, exists := parsed["::1"]; !exists {
   131  				ip := listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"}
   132  				parsed[ip.address] = ip
   133  			}
   134  		} else if net.ParseIP(address).To4() != nil {
   135  			parsed[address] = listenAddress{address: address, protocol: "tcp4", failureMode: "any"}
   136  		} else if net.ParseIP(address) != nil {
   137  			parsed[address] = listenAddress{address: address, protocol: "tcp6", failureMode: "any"}
   138  		} else {
   139  			return nil, fmt.Errorf("%s is not a valid IP", address)
   140  		}
   141  	}
   142  	addresses = make([]listenAddress, len(parsed))
   143  	id := 0
   144  	for _, v := range parsed {
   145  		addresses[id] = v
   146  		id++
   147  	}
   148  	// Sort addresses before returning to get a stable order
   149  	sort.Slice(addresses, func(i, j int) bool { return addresses[i].address < addresses[j].address })
   150  
   151  	return addresses, nil
   152  }
   153  
   154  // New creates a new PortForwarder with localhost listen addresses.
   155  func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
   156  	return NewOnAddresses(dialer, []string{"localhost"}, ports, stopChan, readyChan, out, errOut)
   157  }
   158  
   159  // NewOnAddresses creates a new PortForwarder with custom listen addresses.
   160  func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
   161  	if len(addresses) == 0 {
   162  		return nil, errors.New("you must specify at least 1 address")
   163  	}
   164  	parsedAddresses, err := parseAddresses(addresses)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	if len(ports) == 0 {
   169  		return nil, errors.New("you must specify at least 1 port")
   170  	}
   171  	parsedPorts, err := parsePorts(ports)
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  	return &PortForwarder{
   176  		dialer:    dialer,
   177  		addresses: parsedAddresses,
   178  		ports:     parsedPorts,
   179  		stopChan:  stopChan,
   180  		Ready:     readyChan,
   181  		out:       out,
   182  		errOut:    errOut,
   183  	}, nil
   184  }
   185  
   186  // ForwardPorts formats and executes a port forwarding request. The connection will remain
   187  // open until stopChan is closed.
   188  func (pf *PortForwarder) ForwardPorts() error {
   189  	defer pf.Close()
   190  
   191  	var err error
   192  	pf.streamConn, _, err = pf.dialer.Dial(PortForwardProtocolV1Name)
   193  	if err != nil {
   194  		return fmt.Errorf("error upgrading connection: %s", err)
   195  	}
   196  	defer pf.streamConn.Close()
   197  
   198  	return pf.forward()
   199  }
   200  
   201  // forward dials the remote host specific in req, upgrades the request, starts
   202  // listeners for each port specified in ports, and forwards local connections
   203  // to the remote host via streams.
   204  func (pf *PortForwarder) forward() error {
   205  	var err error
   206  
   207  	listenSuccess := false
   208  	for i := range pf.ports {
   209  		port := &pf.ports[i]
   210  		err = pf.listenOnPort(port)
   211  		switch {
   212  		case err == nil:
   213  			listenSuccess = true
   214  		default:
   215  			if pf.errOut != nil {
   216  				fmt.Fprintf(pf.errOut, "Unable to listen on port %d: %v\n", port.Local, err)
   217  			}
   218  		}
   219  	}
   220  
   221  	if !listenSuccess {
   222  		return fmt.Errorf("unable to listen on any of the requested ports: %v", pf.ports)
   223  	}
   224  
   225  	if pf.Ready != nil {
   226  		close(pf.Ready)
   227  	}
   228  
   229  	// wait for interrupt or conn closure
   230  	select {
   231  	case <-pf.stopChan:
   232  	case <-pf.streamConn.CloseChan():
   233  		runtime.HandleError(errors.New("lost connection to pod"))
   234  	}
   235  
   236  	return nil
   237  }
   238  
   239  // listenOnPort delegates listener creation and waits for connections on requested bind addresses.
   240  // An error is raised based on address groups (default and localhost) and their failure modes
   241  func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
   242  	var errors []error
   243  	failCounters := make(map[string]int, 2)
   244  	successCounters := make(map[string]int, 2)
   245  	for _, addr := range pf.addresses {
   246  		err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address)
   247  		if err != nil {
   248  			errors = append(errors, err)
   249  			failCounters[addr.failureMode]++
   250  		} else {
   251  			successCounters[addr.failureMode]++
   252  		}
   253  	}
   254  	if successCounters["all"] == 0 && failCounters["all"] > 0 {
   255  		return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
   256  	}
   257  	if failCounters["any"] > 0 {
   258  		return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
   259  	}
   260  	return nil
   261  }
   262  
   263  // listenOnPortAndAddress delegates listener creation and waits for new connections
   264  // in the background f
   265  func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error {
   266  	listener, err := pf.getListener(protocol, address, port)
   267  	if err != nil {
   268  		return err
   269  	}
   270  	pf.listeners = append(pf.listeners, listener)
   271  	go pf.waitForConnection(listener, *port)
   272  	return nil
   273  }
   274  
   275  // getListener creates a listener on the interface targeted by the given hostname on the given port with
   276  // the given protocol. protocol is in net.Listen style which basically admits values like tcp, tcp4, tcp6
   277  func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
   278  	listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))))
   279  	if err != nil {
   280  		return nil, fmt.Errorf("unable to create listener: Error %s", err)
   281  	}
   282  	listenerAddress := listener.Addr().String()
   283  	host, localPort, _ := net.SplitHostPort(listenerAddress)
   284  	localPortUInt, err := strconv.ParseUint(localPort, 10, 16)
   285  
   286  	if err != nil {
   287  		fmt.Fprintf(pf.out, "Failed to forward from %s:%d -> %d\n", hostname, localPortUInt, port.Remote)
   288  		return nil, fmt.Errorf("error parsing local port: %s from %s (%s)", err, listenerAddress, host)
   289  	}
   290  	port.Local = uint16(localPortUInt)
   291  	if pf.out != nil {
   292  		fmt.Fprintf(pf.out, "Forwarding from %s -> %d\n", net.JoinHostPort(hostname, strconv.Itoa(int(localPortUInt))), port.Remote)
   293  	}
   294  
   295  	return listener, nil
   296  }
   297  
   298  // waitForConnection waits for new connections to listener and handles them in
   299  // the background.
   300  func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
   301  	for {
   302  		conn, err := listener.Accept()
   303  		if err != nil {
   304  			// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
   305  			if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
   306  				runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
   307  			}
   308  			return
   309  		}
   310  		go pf.handleConnection(conn, port)
   311  	}
   312  }
   313  
   314  func (pf *PortForwarder) nextRequestID() int {
   315  	pf.requestIDLock.Lock()
   316  	defer pf.requestIDLock.Unlock()
   317  	id := pf.requestID
   318  	pf.requestID++
   319  	return id
   320  }
   321  
   322  // handleConnection copies data between the local connection and the stream to
   323  // the remote server.
   324  func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
   325  	defer conn.Close()
   326  
   327  	if pf.out != nil {
   328  		fmt.Fprintf(pf.out, "Handling connection for %d\n", port.Local)
   329  	}
   330  
   331  	requestID := pf.nextRequestID()
   332  
   333  	// create error stream
   334  	headers := http.Header{}
   335  	headers.Set(v1.StreamType, v1.StreamTypeError)
   336  	headers.Set(v1.PortHeader, fmt.Sprintf("%d", port.Remote))
   337  	headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID))
   338  	errorStream, err := pf.streamConn.CreateStream(headers)
   339  	if err != nil {
   340  		runtime.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err))
   341  		return
   342  	}
   343  	// we're not writing to this stream
   344  	errorStream.Close()
   345  
   346  	errorChan := make(chan error)
   347  	go func() {
   348  		message, err := ioutil.ReadAll(errorStream)
   349  		switch {
   350  		case err != nil:
   351  			errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
   352  		case len(message) > 0:
   353  			errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
   354  		}
   355  		close(errorChan)
   356  	}()
   357  
   358  	// create data stream
   359  	headers.Set(v1.StreamType, v1.StreamTypeData)
   360  	dataStream, err := pf.streamConn.CreateStream(headers)
   361  	if err != nil {
   362  		runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
   363  		return
   364  	}
   365  
   366  	localError := make(chan struct{})
   367  	remoteDone := make(chan struct{})
   368  
   369  	go func() {
   370  		// Copy from the remote side to the local port.
   371  		if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
   372  			runtime.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err))
   373  		}
   374  
   375  		// inform the select below that the remote copy is done
   376  		close(remoteDone)
   377  	}()
   378  
   379  	go func() {
   380  		// inform server we're not sending any more data after copy unblocks
   381  		defer dataStream.Close()
   382  
   383  		// Copy from the local port to the remote side.
   384  		if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
   385  			runtime.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err))
   386  			// break out of the select below without waiting for the other copy to finish
   387  			close(localError)
   388  		}
   389  	}()
   390  
   391  	// wait for either a local->remote error or for copying from remote->local to finish
   392  	select {
   393  	case <-remoteDone:
   394  	case <-localError:
   395  	}
   396  
   397  	// always expect something on errorChan (it may be nil)
   398  	err = <-errorChan
   399  	if err != nil {
   400  		runtime.HandleError(err)
   401  	}
   402  }
   403  
   404  // Close stops all listeners of PortForwarder.
   405  func (pf *PortForwarder) Close() {
   406  	// stop all listeners
   407  	for _, l := range pf.listeners {
   408  		if err := l.Close(); err != nil {
   409  			runtime.HandleError(fmt.Errorf("error closing listener: %v", err))
   410  		}
   411  	}
   412  }
   413  
   414  // GetPorts will return the ports that were forwarded; this can be used to
   415  // retrieve the locally-bound port in cases where the input was port 0. This
   416  // function will signal an error if the Ready channel is nil or if the
   417  // listeners are not ready yet; this function will succeed after the Ready
   418  // channel has been closed.
   419  func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) {
   420  	if pf.Ready == nil {
   421  		return nil, fmt.Errorf("no Ready channel provided")
   422  	}
   423  	select {
   424  	case <-pf.Ready:
   425  		return pf.ports, nil
   426  	default:
   427  		return nil, fmt.Errorf("listeners not ready")
   428  	}
   429  }