github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/dnet/kpfconn.go (about) 1 package dnet 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net" 8 "net/http" 9 "sort" 10 "strconv" 11 "strings" 12 "sync" 13 "sync/atomic" 14 "time" 15 16 core "k8s.io/api/core/v1" 17 meta "k8s.io/apimachinery/pkg/apis/meta/v1" 18 "k8s.io/apimachinery/pkg/labels" 19 "k8s.io/apimachinery/pkg/util/httpstream" 20 "k8s.io/client-go/kubernetes" 21 "k8s.io/client-go/rest" 22 "k8s.io/client-go/tools/portforward" 23 "k8s.io/client-go/transport/spdy" 24 "k8s.io/kubectl/pkg/polymorphichelpers" 25 "k8s.io/kubectl/pkg/util" 26 "k8s.io/kubectl/pkg/util/podutils" 27 28 "github.com/datawire/dlib/dlog" 29 ) 30 31 type k8sPortForwardDialer struct { 32 // static 33 logCtx context.Context 34 k8sInterface kubernetes.Interface 35 spdyTransport http.RoundTripper 36 spdyUpgrader spdy.Upgrader 37 38 // state 39 nextRequestID int64 40 spdyStreamsMu sync.Mutex 41 spdyStreams map[string]httpstream.Connection // key is "podname.namespace" 42 } 43 44 type DialerFunc func(context.Context, string) (net.Conn, error) 45 46 type PortForwardDialer interface { 47 io.Closer 48 Dial(ctx context.Context, addr string) (net.Conn, error) 49 DialPod(ctx context.Context, name, namespace string, port uint16) (net.Conn, error) 50 } 51 52 // NewK8sPortForwardDialer returns a dialer function (matching the signature required by 53 // grpc.WithContextDialer) that dials to a port on a Kubernetes Pod, in the manor of `kubectl 54 // port-forward`. It returns the direct connection to the apiserver; it does not establish a local 55 // port being forwarded from or otherwise pump data over the connection. 56 func NewK8sPortForwardDialer(logCtx context.Context, kubeConfig *rest.Config, k8sInterface kubernetes.Interface) (PortForwardDialer, error) { 57 if err := setKubernetesDefaults(kubeConfig); err != nil { 58 return nil, err 59 } 60 spdyTransport, spdyUpgrader, err := spdy.RoundTripperFor(kubeConfig) 61 if err != nil { 62 return nil, err 63 } 64 dialer := &k8sPortForwardDialer{ 65 logCtx: logCtx, 66 k8sInterface: k8sInterface, 67 spdyTransport: spdyTransport, 68 spdyUpgrader: spdyUpgrader, 69 70 spdyStreams: make(map[string]httpstream.Connection), 71 } 72 return dialer, nil 73 } 74 75 type podAddress struct { 76 name string 77 namespace string 78 port uint16 79 } 80 81 // Dial dials a port of something in the cluster. The address format is 82 // "[objkind/]objname[.objnamespace]:port". 83 func (pf *k8sPortForwardDialer) Dial(ctx context.Context, addr string) (conn net.Conn, err error) { 84 var pod *podAddress 85 if pod, err = pf.resolve(ctx, addr); err == nil { 86 if conn, err = pf.dial(pod); err == nil { 87 return conn, nil 88 } 89 } 90 dlog.Errorf(pf.logCtx, "Error with k8sPortForwardDialer dial: %s", err) 91 return nil, err 92 } 93 94 func (pf *k8sPortForwardDialer) DialPod(_ context.Context, name, namespace string, podPortNumber uint16) (net.Conn, error) { 95 conn, err := pf.dial(&podAddress{name: name, namespace: namespace, port: podPortNumber}) 96 if err != nil { 97 dlog.Errorf(pf.logCtx, "Error with k8sPortForwardDialer dial: %s", err) 98 } 99 return conn, err 100 } 101 102 func (pf *k8sPortForwardDialer) Close() error { 103 pf.spdyStreamsMu.Lock() 104 defer pf.spdyStreamsMu.Unlock() 105 for k, s := range pf.spdyStreams { 106 dlog.Errorf(pf.logCtx, "closing spdyStream: %s", k) 107 if err := s.Close(); err != nil { 108 dlog.Errorf(pf.logCtx, "failed to close spdyStream: %v", err) 109 } 110 } 111 return nil 112 } 113 114 func (pf *k8sPortForwardDialer) resolve(ctx context.Context, addr string) (*podAddress, error) { 115 var hostName, portName string 116 hostName, portName, err := net.SplitHostPort(addr) 117 if err != nil { 118 return nil, err 119 } 120 121 var objKind, objQName string 122 if slash := strings.Index(hostName, "/"); slash < 0 { 123 objKind = "Pod." 124 objQName = hostName 125 } else { 126 objKind = hostName[:slash] 127 objQName = hostName[slash+1:] 128 } 129 var objName, objNamespace string 130 if dot := strings.LastIndex(objQName, "."); dot < 0 { 131 objName = objQName 132 objNamespace = "" 133 } else { 134 objName = objQName[:dot] 135 objNamespace = objQName[dot+1:] 136 } 137 138 coreV1 := pf.k8sInterface.CoreV1() 139 if objKind == "svc" { 140 // Get the service. 141 svc, err := coreV1.Services(objNamespace).Get(ctx, objName, meta.GetOptions{}) 142 if err != nil { 143 return nil, err 144 } 145 svcPortNumber, err := func() (int32, error) { 146 if svcPortNumber, err := strconv.Atoi(portName); err == nil { 147 return int32(svcPortNumber), nil 148 } 149 return util.LookupServicePortNumberByName(*svc, portName) 150 }() 151 if err != nil { 152 return nil, fmt.Errorf("cannot find service port in %s.%s: %v", objName, objNamespace, err) 153 } 154 155 // Resolve the Service to a Pod. 156 var selector labels.Selector 157 var podNS string 158 podNS, selector, err = polymorphichelpers.SelectorsForObject(svc) 159 if err != nil { 160 return nil, fmt.Errorf("cannot attach to %T: %v", svc, err) 161 } 162 timeout := func() time.Duration { 163 if deadline, ok := ctx.Deadline(); ok { 164 return time.Until(deadline) 165 } 166 // Fall back to the same default as --pod-running-timeout. 167 return time.Minute 168 }() 169 170 sortBy := func(pods []*core.Pod) sort.Interface { return sort.Reverse(podutils.ActivePods(pods)) } 171 pod, _, err := polymorphichelpers.GetFirstPod(coreV1, podNS, selector.String(), timeout, sortBy) 172 if err != nil { 173 return nil, fmt.Errorf("cannot find first pod for %s.%s: %v", objName, objNamespace, err) 174 } 175 containerPortNumber, err := util.LookupContainerPortNumberByServicePort(*svc, *pod, svcPortNumber) 176 if err != nil { 177 return nil, fmt.Errorf("cannot find first container port %s.%s: %v", pod.Name, pod.Namespace, err) 178 } 179 return &podAddress{name: pod.Name, namespace: pod.Namespace, port: uint16(containerPortNumber)}, nil 180 } 181 182 if p, err := strconv.Atoi(portName); err == nil { 183 return &podAddress{name: objName, namespace: objNamespace, port: uint16(p)}, nil 184 } 185 186 // Get the pod. 187 pod, err := coreV1.Pods(objNamespace).Get(ctx, objName, meta.GetOptions{}) 188 if err != nil { 189 return nil, fmt.Errorf("unable to get %s %s.%s: %w", objKind, objName, objNamespace, err) 190 } 191 pn, err := util.LookupContainerPortNumberByName(*pod, portName) 192 if err != nil { 193 return nil, err 194 } 195 return &podAddress{ 196 name: pod.Name, 197 namespace: pod.Namespace, 198 port: uint16(pn), 199 }, nil 200 } 201 202 func (pf *k8sPortForwardDialer) spdyStream(pod *podAddress) (httpstream.Connection, error) { 203 cacheKey := pod.name + "." + pod.namespace 204 pf.spdyStreamsMu.Lock() 205 defer pf.spdyStreamsMu.Unlock() 206 if spdyStream, ok := pf.spdyStreams[cacheKey]; ok { 207 return spdyStream, nil 208 } 209 210 // Most of the Kubernetes API is HTTP/2+gRPC, not SPDY; and so that's what client-go mostly 211 // helps us with. So in order to get the URL to use in the SPDY request, we're going to 212 // build a standard Kubernetes HTTP/2 *rest.Request and extract the URL from that, and 213 // discard the rest of the *rest.Request. 214 reqURL := pf.k8sInterface.CoreV1().RESTClient(). 215 Post(). 216 Resource("pods"). 217 Namespace(pod.namespace). 218 Name(pod.name). 219 SubResource("portforward"). 220 URL() 221 222 // Don't bother caching dialers in .pf, they're just stateless utility structures. 223 spdyDialer := spdy.NewDialer(pf.spdyUpgrader, &http.Client{Transport: pf.spdyTransport}, http.MethodPost, reqURL) 224 225 dlog.Debugf(pf.logCtx, "k8sPortForwardDialer.spdyDial(ctx, Pod./%s.%s)", pod.name, pod.namespace) 226 227 spdyStream, _, err := spdyDialer.Dial(portforward.PortForwardProtocolV1Name) 228 if err != nil { 229 return nil, err 230 } 231 232 pf.spdyStreams[cacheKey] = spdyStream 233 go func() { 234 <-spdyStream.CloseChan() 235 pf.spdyStreamsMu.Lock() 236 delete(pf.spdyStreams, cacheKey) 237 pf.spdyStreamsMu.Unlock() 238 }() 239 240 return spdyStream, nil 241 } 242 243 func (pf *k8sPortForwardDialer) dial(pod *podAddress) (conn *kpfConn, err error) { 244 dlog.Debugf(pf.logCtx, "k8sPortForwardDialer.dial(ctx, Pod./%s.%s, %d)", 245 pod.name, 246 pod.namespace, 247 pod.port) 248 249 // All port-forwards to the same Pod get multiplexed over the same SPDY stream. 250 spdyStream, err := pf.spdyStream(pod) 251 if err != nil { 252 return nil, err 253 } 254 defer func() { 255 if err != nil { 256 pf.spdyStreamsMu.Lock() 257 delete(pf.spdyStreams, pod.name+"."+pod.namespace) 258 pf.spdyStreamsMu.Unlock() 259 } 260 }() 261 262 requestID := atomic.AddInt64(&pf.nextRequestID, 1) - 1 263 264 headers := http.Header{} 265 headers.Set(core.PortHeader, strconv.FormatInt(int64(pod.port), 10)) 266 headers.Set(core.PortForwardRequestIDHeader, strconv.FormatInt(requestID, 10)) 267 268 // Quick note: spdyStream.CreateStream returns httpstream.Stream objects. These have 269 // confusing method names compared to net.Conn objects: 270 // 271 // | | net.Conn | httpstream.Stream | 272 // |----------------------------+--------------+-------------------| 273 // | close both ends | Close() | Reset() | 274 // | close just the 'read' end | CloseRead() | - | 275 // | close just the 'write' end | CloseWrite() | Close() | 276 277 headers.Set(core.StreamType, core.StreamTypeError) 278 errorStream, err := spdyStream.CreateStream(headers) 279 if err != nil { 280 return nil, fmt.Errorf("create port-forward error stream: %w", err) 281 } 282 // errorStream is read-only, we can go ahead and close the 'write' end. 283 _ = errorStream.Close() 284 285 headers.Set(core.StreamType, core.StreamTypeData) 286 dataStream, err := spdyStream.CreateStream(headers) 287 if err != nil { 288 return nil, fmt.Errorf("create port-forward data stream: %w", err) 289 } 290 291 conn = &kpfConn{ 292 Stream: dataStream, 293 remoteAddr: net.JoinHostPort(pod.name+"."+pod.namespace, strconv.FormatInt(int64(pod.port), 10)), 294 errorStream: errorStream, 295 } 296 conn.init() 297 return conn, nil 298 } 299 300 type kpfConn struct { 301 httpstream.Stream 302 303 // Configuration 304 305 remoteAddr string 306 // See the above comment about httpstream.Stream close semantics. 307 errorStream httpstream.Stream 308 309 // Internal data 310 311 oobErrCh chan struct{} 312 oobErr error // may only access .oobErr if .oobErrCh is closed (unless you're .oobWorker()). 313 314 readErr error 315 writeErr error 316 } 317 318 func (c *kpfConn) SetDeadline(t time.Time) error { 319 if dataConn, ok := c.Stream.(net.Conn); ok { 320 return dataConn.SetDeadline(t) 321 } 322 return nil 323 } 324 325 func (c *kpfConn) SetReadDeadline(t time.Time) error { 326 if dataConn, ok := c.Stream.(net.Conn); ok { 327 return dataConn.SetReadDeadline(t) 328 } 329 return nil 330 } 331 332 func (c *kpfConn) SetWriteDeadline(t time.Time) error { 333 if dataConn, ok := c.Stream.(net.Conn); ok { 334 return dataConn.SetWriteDeadline(t) 335 } 336 return nil 337 } 338 339 func (c *kpfConn) init() { 340 c.oobErrCh = make(chan struct{}) 341 go c.oobWorker() 342 } 343 344 func (c *kpfConn) oobWorker() { 345 msg, err := io.ReadAll(c.errorStream) 346 switch { 347 case err != nil: 348 c.oobErr = fmt.Errorf("reading error error stream: %w", err) 349 case len(msg) > 0: 350 c.oobErr = fmt.Errorf("error stream: %s", msg) 351 } 352 close(c.oobErrCh) 353 } 354 355 func (c *kpfConn) Read(data []byte) (int, error) { 356 switch { 357 case c.readErr != nil: 358 return 0, c.readErr 359 case isClosedChan(c.oobErrCh) && c.oobErr != nil: 360 return 0, c.oobErr 361 default: 362 n, err := c.Stream.Read(data) 363 if err != nil { 364 c.readErr = err 365 } 366 return n, err 367 } 368 } 369 370 func (c *kpfConn) Write(b []byte) (int, error) { 371 switch { 372 case c.writeErr != nil: 373 return 0, c.writeErr 374 case isClosedChan(c.oobErrCh) && c.oobErr != nil: 375 return 0, c.oobErr 376 default: 377 n, err := c.Stream.Write(b) 378 if err != nil { 379 c.writeErr = err 380 } 381 return n, err 382 } 383 } 384 385 func (c *kpfConn) Close() error { 386 closeErr := c.Reset() 387 <-c.oobErrCh 388 if c.oobErr != nil { 389 return c.oobErr 390 } 391 if closeErr != nil { 392 return closeErr 393 } 394 return nil 395 } 396 397 // LocalAddr implements UnbufferedConn. 398 func (c *kpfConn) LocalAddr() net.Addr { 399 return Addr{ 400 Net: "kubectl-port-forward", 401 Addr: "client", 402 } 403 } 404 405 // RemoteAddr implements UnbufferedConn. 406 func (c *kpfConn) RemoteAddr() net.Addr { 407 return Addr{ 408 Net: "kubectl-port-forward", 409 Addr: c.remoteAddr, 410 } 411 }