github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/dnet/kpfconn.go (about)

     1  package dnet
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"sort"
    10  	"strconv"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	core "k8s.io/api/core/v1"
    17  	meta "k8s.io/apimachinery/pkg/apis/meta/v1"
    18  	"k8s.io/apimachinery/pkg/labels"
    19  	"k8s.io/apimachinery/pkg/util/httpstream"
    20  	"k8s.io/client-go/kubernetes"
    21  	"k8s.io/client-go/rest"
    22  	"k8s.io/client-go/tools/portforward"
    23  	"k8s.io/client-go/transport/spdy"
    24  	"k8s.io/kubectl/pkg/polymorphichelpers"
    25  	"k8s.io/kubectl/pkg/util"
    26  	"k8s.io/kubectl/pkg/util/podutils"
    27  
    28  	"github.com/datawire/dlib/dlog"
    29  )
    30  
    31  type k8sPortForwardDialer struct {
    32  	// static
    33  	logCtx        context.Context
    34  	k8sInterface  kubernetes.Interface
    35  	spdyTransport http.RoundTripper
    36  	spdyUpgrader  spdy.Upgrader
    37  
    38  	// state
    39  	nextRequestID int64
    40  	spdyStreamsMu sync.Mutex
    41  	spdyStreams   map[string]httpstream.Connection // key is "podname.namespace"
    42  }
    43  
    44  type DialerFunc func(context.Context, string) (net.Conn, error)
    45  
    46  type PortForwardDialer interface {
    47  	io.Closer
    48  	Dial(ctx context.Context, addr string) (net.Conn, error)
    49  	DialPod(ctx context.Context, name, namespace string, port uint16) (net.Conn, error)
    50  }
    51  
    52  // NewK8sPortForwardDialer returns a dialer function (matching the signature required by
    53  // grpc.WithContextDialer) that dials to a port on a Kubernetes Pod, in the manor of `kubectl
    54  // port-forward`.  It returns the direct connection to the apiserver; it does not establish a local
    55  // port being forwarded from or otherwise pump data over the connection.
    56  func NewK8sPortForwardDialer(logCtx context.Context, kubeConfig *rest.Config, k8sInterface kubernetes.Interface) (PortForwardDialer, error) {
    57  	if err := setKubernetesDefaults(kubeConfig); err != nil {
    58  		return nil, err
    59  	}
    60  	spdyTransport, spdyUpgrader, err := spdy.RoundTripperFor(kubeConfig)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	dialer := &k8sPortForwardDialer{
    65  		logCtx:        logCtx,
    66  		k8sInterface:  k8sInterface,
    67  		spdyTransport: spdyTransport,
    68  		spdyUpgrader:  spdyUpgrader,
    69  
    70  		spdyStreams: make(map[string]httpstream.Connection),
    71  	}
    72  	return dialer, nil
    73  }
    74  
    75  type podAddress struct {
    76  	name      string
    77  	namespace string
    78  	port      uint16
    79  }
    80  
    81  // Dial dials a port of something in the cluster.  The address format is
    82  // "[objkind/]objname[.objnamespace]:port".
    83  func (pf *k8sPortForwardDialer) Dial(ctx context.Context, addr string) (conn net.Conn, err error) {
    84  	var pod *podAddress
    85  	if pod, err = pf.resolve(ctx, addr); err == nil {
    86  		if conn, err = pf.dial(pod); err == nil {
    87  			return conn, nil
    88  		}
    89  	}
    90  	dlog.Errorf(pf.logCtx, "Error with k8sPortForwardDialer dial: %s", err)
    91  	return nil, err
    92  }
    93  
    94  func (pf *k8sPortForwardDialer) DialPod(_ context.Context, name, namespace string, podPortNumber uint16) (net.Conn, error) {
    95  	conn, err := pf.dial(&podAddress{name: name, namespace: namespace, port: podPortNumber})
    96  	if err != nil {
    97  		dlog.Errorf(pf.logCtx, "Error with k8sPortForwardDialer dial: %s", err)
    98  	}
    99  	return conn, err
   100  }
   101  
   102  func (pf *k8sPortForwardDialer) Close() error {
   103  	pf.spdyStreamsMu.Lock()
   104  	defer pf.spdyStreamsMu.Unlock()
   105  	for k, s := range pf.spdyStreams {
   106  		dlog.Errorf(pf.logCtx, "closing spdyStream: %s", k)
   107  		if err := s.Close(); err != nil {
   108  			dlog.Errorf(pf.logCtx, "failed to close spdyStream: %v", err)
   109  		}
   110  	}
   111  	return nil
   112  }
   113  
   114  func (pf *k8sPortForwardDialer) resolve(ctx context.Context, addr string) (*podAddress, error) {
   115  	var hostName, portName string
   116  	hostName, portName, err := net.SplitHostPort(addr)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  
   121  	var objKind, objQName string
   122  	if slash := strings.Index(hostName, "/"); slash < 0 {
   123  		objKind = "Pod."
   124  		objQName = hostName
   125  	} else {
   126  		objKind = hostName[:slash]
   127  		objQName = hostName[slash+1:]
   128  	}
   129  	var objName, objNamespace string
   130  	if dot := strings.LastIndex(objQName, "."); dot < 0 {
   131  		objName = objQName
   132  		objNamespace = ""
   133  	} else {
   134  		objName = objQName[:dot]
   135  		objNamespace = objQName[dot+1:]
   136  	}
   137  
   138  	coreV1 := pf.k8sInterface.CoreV1()
   139  	if objKind == "svc" {
   140  		// Get the service.
   141  		svc, err := coreV1.Services(objNamespace).Get(ctx, objName, meta.GetOptions{})
   142  		if err != nil {
   143  			return nil, err
   144  		}
   145  		svcPortNumber, err := func() (int32, error) {
   146  			if svcPortNumber, err := strconv.Atoi(portName); err == nil {
   147  				return int32(svcPortNumber), nil
   148  			}
   149  			return util.LookupServicePortNumberByName(*svc, portName)
   150  		}()
   151  		if err != nil {
   152  			return nil, fmt.Errorf("cannot find service port in %s.%s: %v", objName, objNamespace, err)
   153  		}
   154  
   155  		// Resolve the Service to a Pod.
   156  		var selector labels.Selector
   157  		var podNS string
   158  		podNS, selector, err = polymorphichelpers.SelectorsForObject(svc)
   159  		if err != nil {
   160  			return nil, fmt.Errorf("cannot attach to %T: %v", svc, err)
   161  		}
   162  		timeout := func() time.Duration {
   163  			if deadline, ok := ctx.Deadline(); ok {
   164  				return time.Until(deadline)
   165  			}
   166  			// Fall back to the same default as --pod-running-timeout.
   167  			return time.Minute
   168  		}()
   169  
   170  		sortBy := func(pods []*core.Pod) sort.Interface { return sort.Reverse(podutils.ActivePods(pods)) }
   171  		pod, _, err := polymorphichelpers.GetFirstPod(coreV1, podNS, selector.String(), timeout, sortBy)
   172  		if err != nil {
   173  			return nil, fmt.Errorf("cannot find first pod for %s.%s: %v", objName, objNamespace, err)
   174  		}
   175  		containerPortNumber, err := util.LookupContainerPortNumberByServicePort(*svc, *pod, svcPortNumber)
   176  		if err != nil {
   177  			return nil, fmt.Errorf("cannot find first container port %s.%s: %v", pod.Name, pod.Namespace, err)
   178  		}
   179  		return &podAddress{name: pod.Name, namespace: pod.Namespace, port: uint16(containerPortNumber)}, nil
   180  	}
   181  
   182  	if p, err := strconv.Atoi(portName); err == nil {
   183  		return &podAddress{name: objName, namespace: objNamespace, port: uint16(p)}, nil
   184  	}
   185  
   186  	// Get the pod.
   187  	pod, err := coreV1.Pods(objNamespace).Get(ctx, objName, meta.GetOptions{})
   188  	if err != nil {
   189  		return nil, fmt.Errorf("unable to get %s %s.%s: %w", objKind, objName, objNamespace, err)
   190  	}
   191  	pn, err := util.LookupContainerPortNumberByName(*pod, portName)
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  	return &podAddress{
   196  		name:      pod.Name,
   197  		namespace: pod.Namespace,
   198  		port:      uint16(pn),
   199  	}, nil
   200  }
   201  
   202  func (pf *k8sPortForwardDialer) spdyStream(pod *podAddress) (httpstream.Connection, error) {
   203  	cacheKey := pod.name + "." + pod.namespace
   204  	pf.spdyStreamsMu.Lock()
   205  	defer pf.spdyStreamsMu.Unlock()
   206  	if spdyStream, ok := pf.spdyStreams[cacheKey]; ok {
   207  		return spdyStream, nil
   208  	}
   209  
   210  	// Most of the Kubernetes API is HTTP/2+gRPC, not SPDY; and so that's what client-go mostly
   211  	// helps us with.  So in order to get the URL to use in the SPDY request, we're going to
   212  	// build a standard Kubernetes HTTP/2 *rest.Request and extract the URL from that, and
   213  	// discard the rest of the *rest.Request.
   214  	reqURL := pf.k8sInterface.CoreV1().RESTClient().
   215  		Post().
   216  		Resource("pods").
   217  		Namespace(pod.namespace).
   218  		Name(pod.name).
   219  		SubResource("portforward").
   220  		URL()
   221  
   222  	// Don't bother caching dialers in .pf, they're just stateless utility structures.
   223  	spdyDialer := spdy.NewDialer(pf.spdyUpgrader, &http.Client{Transport: pf.spdyTransport}, http.MethodPost, reqURL)
   224  
   225  	dlog.Debugf(pf.logCtx, "k8sPortForwardDialer.spdyDial(ctx, Pod./%s.%s)", pod.name, pod.namespace)
   226  
   227  	spdyStream, _, err := spdyDialer.Dial(portforward.PortForwardProtocolV1Name)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  
   232  	pf.spdyStreams[cacheKey] = spdyStream
   233  	go func() {
   234  		<-spdyStream.CloseChan()
   235  		pf.spdyStreamsMu.Lock()
   236  		delete(pf.spdyStreams, cacheKey)
   237  		pf.spdyStreamsMu.Unlock()
   238  	}()
   239  
   240  	return spdyStream, nil
   241  }
   242  
   243  func (pf *k8sPortForwardDialer) dial(pod *podAddress) (conn *kpfConn, err error) {
   244  	dlog.Debugf(pf.logCtx, "k8sPortForwardDialer.dial(ctx, Pod./%s.%s, %d)",
   245  		pod.name,
   246  		pod.namespace,
   247  		pod.port)
   248  
   249  	// All port-forwards to the same Pod get multiplexed over the same SPDY stream.
   250  	spdyStream, err := pf.spdyStream(pod)
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  	defer func() {
   255  		if err != nil {
   256  			pf.spdyStreamsMu.Lock()
   257  			delete(pf.spdyStreams, pod.name+"."+pod.namespace)
   258  			pf.spdyStreamsMu.Unlock()
   259  		}
   260  	}()
   261  
   262  	requestID := atomic.AddInt64(&pf.nextRequestID, 1) - 1
   263  
   264  	headers := http.Header{}
   265  	headers.Set(core.PortHeader, strconv.FormatInt(int64(pod.port), 10))
   266  	headers.Set(core.PortForwardRequestIDHeader, strconv.FormatInt(requestID, 10))
   267  
   268  	// Quick note: spdyStream.CreateStream returns httpstream.Stream objects.  These have
   269  	// confusing method names compared to net.Conn objects:
   270  	//
   271  	//   |                            | net.Conn     | httpstream.Stream |
   272  	//   |----------------------------+--------------+-------------------|
   273  	//   | close both ends            | Close()      | Reset()           |
   274  	//   | close just the 'read' end  | CloseRead()  | -                 |
   275  	//   | close just the 'write' end | CloseWrite() | Close()           |
   276  
   277  	headers.Set(core.StreamType, core.StreamTypeError)
   278  	errorStream, err := spdyStream.CreateStream(headers)
   279  	if err != nil {
   280  		return nil, fmt.Errorf("create port-forward error stream: %w", err)
   281  	}
   282  	// errorStream is read-only, we can go ahead and close the 'write' end.
   283  	_ = errorStream.Close()
   284  
   285  	headers.Set(core.StreamType, core.StreamTypeData)
   286  	dataStream, err := spdyStream.CreateStream(headers)
   287  	if err != nil {
   288  		return nil, fmt.Errorf("create port-forward data stream: %w", err)
   289  	}
   290  
   291  	conn = &kpfConn{
   292  		Stream:      dataStream,
   293  		remoteAddr:  net.JoinHostPort(pod.name+"."+pod.namespace, strconv.FormatInt(int64(pod.port), 10)),
   294  		errorStream: errorStream,
   295  	}
   296  	conn.init()
   297  	return conn, nil
   298  }
   299  
   300  type kpfConn struct {
   301  	httpstream.Stream
   302  
   303  	// Configuration
   304  
   305  	remoteAddr string
   306  	// See the above comment about httpstream.Stream close semantics.
   307  	errorStream httpstream.Stream
   308  
   309  	// Internal data
   310  
   311  	oobErrCh chan struct{}
   312  	oobErr   error // may only access .oobErr if .oobErrCh is closed (unless you're .oobWorker()).
   313  
   314  	readErr  error
   315  	writeErr error
   316  }
   317  
   318  func (c *kpfConn) SetDeadline(t time.Time) error {
   319  	if dataConn, ok := c.Stream.(net.Conn); ok {
   320  		return dataConn.SetDeadline(t)
   321  	}
   322  	return nil
   323  }
   324  
   325  func (c *kpfConn) SetReadDeadline(t time.Time) error {
   326  	if dataConn, ok := c.Stream.(net.Conn); ok {
   327  		return dataConn.SetReadDeadline(t)
   328  	}
   329  	return nil
   330  }
   331  
   332  func (c *kpfConn) SetWriteDeadline(t time.Time) error {
   333  	if dataConn, ok := c.Stream.(net.Conn); ok {
   334  		return dataConn.SetWriteDeadline(t)
   335  	}
   336  	return nil
   337  }
   338  
   339  func (c *kpfConn) init() {
   340  	c.oobErrCh = make(chan struct{})
   341  	go c.oobWorker()
   342  }
   343  
   344  func (c *kpfConn) oobWorker() {
   345  	msg, err := io.ReadAll(c.errorStream)
   346  	switch {
   347  	case err != nil:
   348  		c.oobErr = fmt.Errorf("reading error error stream: %w", err)
   349  	case len(msg) > 0:
   350  		c.oobErr = fmt.Errorf("error stream: %s", msg)
   351  	}
   352  	close(c.oobErrCh)
   353  }
   354  
   355  func (c *kpfConn) Read(data []byte) (int, error) {
   356  	switch {
   357  	case c.readErr != nil:
   358  		return 0, c.readErr
   359  	case isClosedChan(c.oobErrCh) && c.oobErr != nil:
   360  		return 0, c.oobErr
   361  	default:
   362  		n, err := c.Stream.Read(data)
   363  		if err != nil {
   364  			c.readErr = err
   365  		}
   366  		return n, err
   367  	}
   368  }
   369  
   370  func (c *kpfConn) Write(b []byte) (int, error) {
   371  	switch {
   372  	case c.writeErr != nil:
   373  		return 0, c.writeErr
   374  	case isClosedChan(c.oobErrCh) && c.oobErr != nil:
   375  		return 0, c.oobErr
   376  	default:
   377  		n, err := c.Stream.Write(b)
   378  		if err != nil {
   379  			c.writeErr = err
   380  		}
   381  		return n, err
   382  	}
   383  }
   384  
   385  func (c *kpfConn) Close() error {
   386  	closeErr := c.Reset()
   387  	<-c.oobErrCh
   388  	if c.oobErr != nil {
   389  		return c.oobErr
   390  	}
   391  	if closeErr != nil {
   392  		return closeErr
   393  	}
   394  	return nil
   395  }
   396  
   397  // LocalAddr implements UnbufferedConn.
   398  func (c *kpfConn) LocalAddr() net.Addr {
   399  	return Addr{
   400  		Net:  "kubectl-port-forward",
   401  		Addr: "client",
   402  	}
   403  }
   404  
   405  // RemoteAddr implements UnbufferedConn.
   406  func (c *kpfConn) RemoteAddr() net.Addr {
   407  	return Addr{
   408  		Net:  "kubectl-port-forward",
   409  		Addr: c.remoteAddr,
   410  	}
   411  }