github.com/grahambrereton-form3/tilt@v0.10.18/internal/k8s/portforward.go (about)

     1  package k8s
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"strconv"
     9  
    10  	v1 "k8s.io/client-go/kubernetes/typed/core/v1"
    11  	_ "k8s.io/client-go/plugin/pkg/client/auth/gcp" // registers gcp auth provider
    12  	"k8s.io/client-go/rest"
    13  	"k8s.io/client-go/tools/portforward"
    14  	"k8s.io/client-go/transport/spdy"
    15  
    16  	"github.com/windmilleng/tilt/pkg/logger"
    17  
    18  	"github.com/pkg/errors"
    19  )
    20  
    21  type PortForwardClient interface {
    22  	// Creates a new port-forwarder that's bound to the given context's lifecycle.
    23  	// When the context is canceled, the port-forwarder will close.
    24  	CreatePortForwarder(ctx context.Context, namespace Namespace, podID PodID, localPort int, remotePort int, host string) (PortForwarder, error)
    25  }
    26  
    27  type PortForwarder interface {
    28  	// The local port we're listening on.
    29  	LocalPort() int
    30  
    31  	// Listens on the configured port and forward all traffic to the container.
    32  	// Returns when the port-forwarder sees an unrecoverable error or
    33  	// when the context passed at creation is canceled.
    34  	ForwardPorts() error
    35  }
    36  
    37  type portForwarder struct {
    38  	*portforward.PortForwarder
    39  	localPort int
    40  }
    41  
    42  func (pf portForwarder) LocalPort() int {
    43  	return pf.localPort
    44  }
    45  
    46  func (k K8sClient) CreatePortForwarder(ctx context.Context, namespace Namespace, podID PodID, optionalLocalPort, remotePort int, host string) (PortForwarder, error) {
    47  	localPort := optionalLocalPort
    48  	if localPort == 0 {
    49  		// preferably, we'd set the localport to 0, and let the underlying function pick a port for us,
    50  		// to avoid the race condition potential of something else grabbing this port between
    51  		// the call to `getAvailablePort` and whenever `portForwarder` actually binds the port.
    52  		// the k8s client supports a local port of 0, and stores the actual local port assigned in a field,
    53  		// but unfortunately does not export that field, so there is no way for the caller to know which
    54  		// local port to talk to.
    55  		var err error
    56  		localPort, err = getAvailablePort()
    57  		if err != nil {
    58  			return nil, errors.Wrap(err, "failed to find an available local port")
    59  		}
    60  	}
    61  
    62  	return k.portForwardClient.CreatePortForwarder(ctx, namespace, podID, localPort, remotePort, host)
    63  }
    64  
    65  type portForwardClient struct {
    66  	config *rest.Config
    67  	core   v1.CoreV1Interface
    68  }
    69  
    70  func ProvidePortForwardClient(
    71  	maybeRESTConfig RESTConfigOrError,
    72  	maybeClientset ClientsetOrError) PortForwardClient {
    73  	if maybeRESTConfig.Error != nil {
    74  		return explodingPortForwardClient{error: maybeRESTConfig.Error}
    75  	}
    76  	if maybeClientset.Error != nil {
    77  		return explodingPortForwardClient{error: maybeClientset.Error}
    78  	}
    79  	return portForwardClient{
    80  		maybeRESTConfig.Config,
    81  		maybeClientset.Clientset.CoreV1(),
    82  	}
    83  }
    84  
    85  func (c portForwardClient) CreatePortForwarder(ctx context.Context, namespace Namespace, podID PodID, localPort int, remotePort int, host string) (PortForwarder, error) {
    86  	transport, upgrader, err := spdy.RoundTripperFor(c.config)
    87  	if err != nil {
    88  		return nil, errors.Wrap(err, "error getting roundtripper")
    89  	}
    90  
    91  	req := c.core.RESTClient().Post().
    92  		Resource("pods").
    93  		Namespace(namespace.String()).
    94  		Name(podID.String()).
    95  		SubResource("portforward")
    96  
    97  	dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", req.URL())
    98  	if err != nil {
    99  		return nil, errors.Wrap(err, "error creating dialer")
   100  	}
   101  
   102  	stopChan := make(chan struct{}, 1)
   103  	readyChan := make(chan struct{}, 1)
   104  
   105  	ports := []string{fmt.Sprintf("%d:%d", localPort, remotePort)}
   106  
   107  	var pf *portforward.PortForwarder
   108  	if host == "" {
   109  		pf, err = portforward.New(
   110  			dialer,
   111  			ports,
   112  			stopChan,
   113  			readyChan,
   114  			logger.Get(ctx).Writer(logger.DebugLvl),
   115  			logger.Get(ctx).Writer(logger.DebugLvl))
   116  	} else {
   117  		addresses := []string{host}
   118  		pf, err = portforward.NewOnAddresses(
   119  			dialer,
   120  			addresses,
   121  			ports,
   122  			stopChan,
   123  			readyChan,
   124  			logger.Get(ctx).Writer(logger.DebugLvl),
   125  			logger.Get(ctx).Writer(logger.DebugLvl))
   126  	}
   127  	if err != nil {
   128  		return nil, errors.Wrap(err, "error forwarding port")
   129  	}
   130  
   131  	go func() {
   132  		<-ctx.Done()
   133  		close(stopChan)
   134  	}()
   135  	return portForwarder{
   136  		PortForwarder: pf,
   137  		localPort:     localPort,
   138  	}, nil
   139  }
   140  
   141  func getAvailablePort() (int, error) {
   142  	l, err := net.Listen("tcp", ":0")
   143  	if err != nil {
   144  		return 0, err
   145  	}
   146  	defer func() {
   147  		e := l.Close()
   148  		if err == nil {
   149  			err = e
   150  		}
   151  	}()
   152  
   153  	_, p, err := net.SplitHostPort(l.Addr().String())
   154  	if err != nil {
   155  		return 0, err
   156  	}
   157  	port, err := strconv.Atoi(p)
   158  	if err != nil {
   159  		return 0, err
   160  	}
   161  	return port, err
   162  }
   163  
   164  type explodingPortForwardClient struct {
   165  	error error
   166  }
   167  
   168  func (c explodingPortForwardClient) CreatePortForwarder(ctx context.Context, namespace Namespace, podID PodID, localPort int, remotePort int, host string) (PortForwarder, error) {
   169  	return nil, c.error
   170  }