github.com/tilt-dev/tilt@v0.33.15-0.20240515162809-0a22ed45d8a0/internal/k8s/portforward/portforward.go (about) 1 /* 2 Copyright 2015 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package portforward 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 "io" 24 "net" 25 "net/http" 26 "sort" 27 "strconv" 28 "strings" 29 "sync" 30 31 v1 "k8s.io/api/core/v1" 32 "k8s.io/apimachinery/pkg/util/httpstream" 33 netutils "k8s.io/utils/net" 34 35 "github.com/tilt-dev/tilt/pkg/logger" 36 ) 37 38 // PortForwardProtocolV1Name is the subprotocol used for port forwarding. 39 // TODO move to API machinery and re-unify with kubelet/server/portfoward 40 const PortForwardProtocolV1Name = "portforward.k8s.io" 41 42 // PortForwarder knows how to listen for local connections and forward them to 43 // a remote pod via an upgraded HTTP request. 44 type PortForwarder struct { 45 ctx context.Context 46 addresses []listenAddress 47 ports []ForwardedPort 48 49 dialer httpstream.Dialer 50 streamConn httpstream.Connection 51 errorHandler *errorHandler 52 listeners []io.Closer 53 Ready chan struct{} 54 requestIDLock sync.Mutex 55 requestID int 56 } 57 58 // ForwardedPort contains a Local:Remote port pairing. 59 type ForwardedPort struct { 60 Local uint16 61 Remote uint16 62 } 63 64 /* 65 valid port specifications: 66 67 5000 68 - forwards from localhost:5000 to pod:5000 69 70 8888:5000 71 - forwards from localhost:8888 to pod:5000 72 73 0:5000 74 :5000 75 - selects a random available local port, 76 forwards from localhost:<random port> to pod:5000 77 */ 78 func parsePorts(ports []string) ([]ForwardedPort, error) { 79 var forwards []ForwardedPort 80 for _, portString := range ports { 81 parts := strings.Split(portString, ":") 82 var localString, remoteString string 83 if len(parts) == 1 { 84 localString = parts[0] 85 remoteString = parts[0] 86 } else if len(parts) == 2 { 87 localString = parts[0] 88 if localString == "" { 89 // support :5000 90 localString = "0" 91 } 92 remoteString = parts[1] 93 } else { 94 return nil, fmt.Errorf("invalid port format '%s'", portString) 95 } 96 97 localPort, err := strconv.ParseUint(localString, 10, 16) 98 if err != nil { 99 return nil, fmt.Errorf("error parsing local port '%s': %s", localString, err) 100 } 101 102 remotePort, err := strconv.ParseUint(remoteString, 10, 16) 103 if err != nil { 104 return nil, fmt.Errorf("error parsing remote port '%s': %s", remoteString, err) 105 } 106 if remotePort == 0 { 107 return nil, fmt.Errorf("remote port must be > 0") 108 } 109 110 forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)}) 111 } 112 113 return forwards, nil 114 } 115 116 type listenAddress struct { 117 address string 118 protocol string 119 failureMode string 120 } 121 122 func parseAddresses(addressesToParse []string) ([]listenAddress, error) { 123 var addresses []listenAddress 124 parsed := make(map[string]listenAddress) 125 for _, address := range addressesToParse { 126 if address == "localhost" { 127 if _, exists := parsed["127.0.0.1"]; !exists { 128 ip := listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"} 129 parsed[ip.address] = ip 130 } 131 if _, exists := parsed["::1"]; !exists { 132 ip := listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"} 133 parsed[ip.address] = ip 134 } 135 } else if netutils.ParseIPSloppy(address).To4() != nil { 136 parsed[address] = listenAddress{address: address, protocol: "tcp4", failureMode: "any"} 137 } else if netutils.ParseIPSloppy(address) != nil { 138 parsed[address] = listenAddress{address: address, protocol: "tcp6", failureMode: "any"} 139 } else { 140 return nil, fmt.Errorf("%s is not a valid IP", address) 141 } 142 } 143 addresses = make([]listenAddress, len(parsed)) 144 id := 0 145 for _, v := range parsed { 146 addresses[id] = v 147 id++ 148 } 149 // Sort addresses before returning to get a stable order 150 sort.Slice(addresses, func(i, j int) bool { return addresses[i].address < addresses[j].address }) 151 152 return addresses, nil 153 } 154 155 // New creates a new PortForwarder with localhost listen addresses. 156 func New(ctx context.Context, dialer httpstream.Dialer, ports []string, readyChan chan struct{}) (*PortForwarder, error) { 157 return NewOnAddresses(ctx, dialer, []string{"localhost"}, ports, readyChan) 158 } 159 160 // NewOnAddresses creates a new PortForwarder with custom listen addresses. 161 func NewOnAddresses(ctx context.Context, dialer httpstream.Dialer, addresses []string, ports []string, readyChan chan struct{}) (*PortForwarder, error) { 162 if len(addresses) == 0 { 163 return nil, errors.New("you must specify at least 1 address") 164 } 165 parsedAddresses, err := parseAddresses(addresses) 166 if err != nil { 167 return nil, err 168 } 169 if len(ports) == 0 { 170 return nil, errors.New("you must specify at least 1 port") 171 } 172 parsedPorts, err := parsePorts(ports) 173 if err != nil { 174 return nil, err 175 } 176 return &PortForwarder{ 177 ctx: ctx, 178 dialer: dialer, 179 addresses: parsedAddresses, 180 ports: parsedPorts, 181 Ready: readyChan, 182 }, nil 183 } 184 185 func (pf *PortForwarder) Addresses() []string { 186 var addresses []string 187 for _, la := range pf.addresses { 188 addresses = append(addresses, la.address) 189 } 190 return addresses 191 } 192 193 // ForwardPorts formats and executes a port forwarding request. The connection will remain 194 // open until stopChan is closed. 195 func (pf *PortForwarder) ForwardPorts() error { 196 defer pf.Close() 197 198 var err error 199 pf.streamConn, _, err = pf.dialer.Dial(PortForwardProtocolV1Name) 200 if err != nil { 201 return fmt.Errorf("error upgrading connection: %s", err) 202 } 203 defer pf.streamConn.Close() 204 205 return pf.forward() 206 } 207 208 // forward dials the remote host specific in req, upgrades the request, starts 209 // listeners for each port specified in ports, and forwards local connections 210 // to the remote host via streams. 211 // 212 // Returns an error if any of the local ports aren't available. 213 func (pf *PortForwarder) forward() error { 214 var err error 215 pf.errorHandler = newErrorHandler() 216 defer pf.errorHandler.Close() 217 218 for i := range pf.ports { 219 port := &pf.ports[i] 220 err = pf.listenOnPort(port) 221 if err != nil { 222 return fmt.Errorf("Unable to listen on port %d: %v", port.Local, err) 223 } 224 } 225 226 if pf.Ready != nil { 227 close(pf.Ready) 228 } 229 230 // wait for interrupt or conn closure 231 select { 232 case err := <-pf.errorHandler.Done(): 233 return err 234 235 case <-pf.ctx.Done(): 236 case <-pf.streamConn.CloseChan(): 237 logger.Get(pf.ctx).Debugf("lost connection to pod") 238 } 239 240 return nil 241 } 242 243 // listenOnPort delegates listener creation and waits for connections on requested bind addresses. 244 // An error is raised based on address groups (default and localhost) and their failure modes 245 func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error { 246 var errors []error 247 failCounters := make(map[string]int, 2) 248 successCounters := make(map[string]int, 2) 249 for _, addr := range pf.addresses { 250 err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address) 251 if err != nil { 252 errors = append(errors, err) 253 failCounters[addr.failureMode]++ 254 } else { 255 successCounters[addr.failureMode]++ 256 } 257 } 258 if successCounters["all"] == 0 && failCounters["all"] > 0 { 259 return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors) 260 } 261 if failCounters["any"] > 0 { 262 return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors) 263 } 264 return nil 265 } 266 267 // listenOnPortAndAddress delegates listener creation and waits for new connections 268 // in the background f 269 func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error { 270 listener, err := pf.getListener(protocol, address, port) 271 if err != nil { 272 return err 273 } 274 pf.listeners = append(pf.listeners, listener) 275 go pf.waitForConnection(listener, *port) 276 return nil 277 } 278 279 // getListener creates a listener on the interface targeted by the given hostname on the given port with 280 // the given protocol. protocol is in net.Listen style which basically admits values like tcp, tcp4, tcp6 281 func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) { 282 listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local)))) 283 if err != nil { 284 return nil, fmt.Errorf("unable to create listener: Error %s", err) 285 } 286 listenerAddress := listener.Addr().String() 287 host, localPort, _ := net.SplitHostPort(listenerAddress) 288 localPortUInt, err := strconv.ParseUint(localPort, 10, 16) 289 290 if err != nil { 291 return nil, fmt.Errorf("error parsing local port: %s from %s (%s)", err, listenerAddress, host) 292 } 293 port.Local = uint16(localPortUInt) 294 295 return listener, nil 296 } 297 298 // waitForConnection waits for new connections to listener and handles them in 299 // the background. 300 func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) { 301 for { 302 select { 303 case <-pf.streamConn.CloseChan(): 304 return 305 default: 306 conn, err := listener.Accept() 307 if err != nil { 308 // TODO consider using something like https://github.com/hydrogen18/stoppableListener? 309 if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") { 310 logger.Get(pf.ctx).Debugf("error accepting connection on port %d: %v", port.Local, err) 311 } 312 return 313 } 314 go pf.handleConnection(conn, port) 315 } 316 } 317 } 318 319 func (pf *PortForwarder) nextRequestID() int { 320 pf.requestIDLock.Lock() 321 defer pf.requestIDLock.Unlock() 322 id := pf.requestID 323 pf.requestID++ 324 return id 325 } 326 327 // handleConnection copies data between the local connection and the stream to 328 // the remote server. 329 func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) { 330 defer conn.Close() 331 332 requestID := pf.nextRequestID() 333 334 // create error stream 335 headers := http.Header{} 336 headers.Set(v1.StreamType, v1.StreamTypeError) 337 headers.Set(v1.PortHeader, fmt.Sprintf("%d", port.Remote)) 338 headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID)) 339 errorStream, err := pf.streamConn.CreateStream(headers) 340 if err != nil { 341 // If CreateStream fails, stop the whole portforwarder, because this might 342 // mean the whole streamConn is wedged. The PortForward reconciler will backoff 343 // and re-create the connection. 344 pf.errorHandler.Stop(fmt.Errorf("creating stream: %v", err)) 345 return 346 } 347 // we're not writing to this stream 348 errorStream.Close() 349 350 errorChan := make(chan error) 351 go func() { 352 message, err := io.ReadAll(errorStream) 353 switch { 354 case err != nil: 355 errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err) 356 case len(message) > 0: 357 errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message)) 358 } 359 close(errorChan) 360 }() 361 362 // create data stream 363 headers.Set(v1.StreamType, v1.StreamTypeData) 364 dataStream, err := pf.streamConn.CreateStream(headers) 365 if err != nil { 366 // If CreateStream fails, stop the whole portforwarder, because this might 367 // mean the whole streamConn is wedged. The PortForward reconciler will backoff 368 // and re-create the connection. 369 pf.errorHandler.Stop(fmt.Errorf("creating stream: %v", err)) 370 return 371 } 372 373 localError := make(chan struct{}) 374 remoteDone := make(chan struct{}) 375 376 go func() { 377 // Copy from the remote side to the local port. 378 if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { 379 logger.Get(pf.ctx).Debugf("error copying from remote stream to local connection: %v", err) 380 } 381 382 // inform the select below that the remote copy is done 383 close(remoteDone) 384 }() 385 386 go func() { 387 // inform server we're not sending any more data after copy unblocks 388 defer dataStream.Close() 389 390 // Copy from the local port to the remote side. 391 if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { 392 logger.Get(pf.ctx).Debugf("error copying from local connection to remote stream: %v", err) 393 // break out of the select below without waiting for the other copy to finish 394 close(localError) 395 } 396 }() 397 398 // wait for either a local->remote error or for copying from remote->local to finish 399 select { 400 case <-remoteDone: 401 case <-localError: 402 } 403 404 // always expect something on errorChan (it may be nil) 405 err = <-errorChan 406 if err != nil { 407 logger.Get(pf.ctx).Debugf("%v", err) 408 pf.streamConn.Close() 409 } 410 } 411 412 // Close stops all listeners of PortForwarder. 413 func (pf *PortForwarder) Close() { 414 // stop all listeners 415 for _, l := range pf.listeners { 416 if err := l.Close(); err != nil { 417 logger.Get(pf.ctx).Debugf("error closing listener: %v", err) 418 } 419 } 420 } 421 422 // GetPorts will return the ports that were forwarded; this can be used to 423 // retrieve the locally-bound port in cases where the input was port 0. This 424 // function will signal an error if the Ready channel is nil or if the 425 // listeners are not ready yet; this function will succeed after the Ready 426 // channel has been closed. 427 func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) { 428 if pf.Ready == nil { 429 return nil, fmt.Errorf("no Ready channel provided") 430 } 431 select { 432 case <-pf.Ready: 433 return pf.ports, nil 434 default: 435 return nil, fmt.Errorf("listeners not ready") 436 } 437 }