github.com/tilt-dev/tilt@v0.36.0/internal/k8s/portforward/portforward.go (about) 1 /* 2 Copyright 2015 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package portforward 18 19 import ( 20 "errors" 21 "fmt" 22 "io" 23 "net" 24 "net/http" 25 "sort" 26 "strconv" 27 "strings" 28 "sync" 29 30 v1 "k8s.io/api/core/v1" 31 "k8s.io/apimachinery/pkg/util/httpstream" 32 netutils "k8s.io/utils/net" 33 ) 34 35 // PortForwardProtocolV1Name is the subprotocol used for port forwarding. 36 // TODO move to API machinery and re-unify with kubelet/server/portfoward 37 const PortForwardProtocolV1Name = "portforward.k8s.io" 38 39 var ( 40 // error returned whenever we lost connection to a pod 41 ErrLostConnectionToPod = errors.New("lost connection to pod") 42 43 // set of error we're expecting during port-forwarding 44 networkClosedError = "use of closed network connection" 45 ) 46 47 // PortForwarder knows how to listen for local connections and forward them to 48 // a remote pod via an upgraded HTTP request. 49 type PortForwarder struct { 50 addresses []listenAddress 51 ports []ForwardedPort 52 stopChan <-chan struct{} 53 54 dialer httpstream.Dialer 55 streamConn httpstream.Connection 56 errorHandler *errorHandler 57 listeners []io.Closer 58 Ready chan struct{} 59 requestIDLock sync.Mutex 60 requestID int 61 out io.Writer 62 errOut io.Writer 63 } 64 65 // ForwardedPort contains a Local:Remote port pairing. 66 type ForwardedPort struct { 67 Local uint16 68 Remote uint16 69 } 70 71 /* 72 valid port specifications: 73 74 5000 75 - forwards from localhost:5000 to pod:5000 76 77 8888:5000 78 - forwards from localhost:8888 to pod:5000 79 80 0:5000 81 :5000 82 - selects a random available local port, 83 forwards from localhost:<random port> to pod:5000 84 */ 85 func parsePorts(ports []string) ([]ForwardedPort, error) { 86 var forwards []ForwardedPort 87 for _, portString := range ports { 88 parts := strings.Split(portString, ":") 89 var localString, remoteString string 90 if len(parts) == 1 { 91 localString = parts[0] 92 remoteString = parts[0] 93 } else if len(parts) == 2 { 94 localString = parts[0] 95 if localString == "" { 96 // support :5000 97 localString = "0" 98 } 99 remoteString = parts[1] 100 } else { 101 return nil, fmt.Errorf("invalid port format '%s'", portString) 102 } 103 104 localPort, err := strconv.ParseUint(localString, 10, 16) 105 if err != nil { 106 return nil, fmt.Errorf("error parsing local port '%s': %s", localString, err) 107 } 108 109 remotePort, err := strconv.ParseUint(remoteString, 10, 16) 110 if err != nil { 111 return nil, fmt.Errorf("error parsing remote port '%s': %s", remoteString, err) 112 } 113 if remotePort == 0 { 114 return nil, fmt.Errorf("remote port must be > 0") 115 } 116 117 forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)}) 118 } 119 120 return forwards, nil 121 } 122 123 type listenAddress struct { 124 address string 125 protocol string 126 failureMode string 127 } 128 129 func parseAddresses(addressesToParse []string) ([]listenAddress, error) { 130 var addresses []listenAddress 131 parsed := make(map[string]listenAddress) 132 for _, address := range addressesToParse { 133 if address == "localhost" { 134 if _, exists := parsed["127.0.0.1"]; !exists { 135 ip := listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"} 136 parsed[ip.address] = ip 137 } 138 if _, exists := parsed["::1"]; !exists { 139 ip := listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"} 140 parsed[ip.address] = ip 141 } 142 } else if netutils.ParseIPSloppy(address).To4() != nil { 143 parsed[address] = listenAddress{address: address, protocol: "tcp4", failureMode: "any"} 144 } else if netutils.ParseIPSloppy(address) != nil { 145 parsed[address] = listenAddress{address: address, protocol: "tcp6", failureMode: "any"} 146 } else { 147 return nil, fmt.Errorf("%s is not a valid IP", address) 148 } 149 } 150 addresses = make([]listenAddress, len(parsed)) 151 id := 0 152 for _, v := range parsed { 153 addresses[id] = v 154 id++ 155 } 156 // Sort addresses before returning to get a stable order 157 sort.Slice(addresses, func(i, j int) bool { return addresses[i].address < addresses[j].address }) 158 159 return addresses, nil 160 } 161 162 // New creates a new PortForwarder with localhost listen addresses. 163 func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) { 164 return NewOnAddresses(dialer, []string{"localhost"}, ports, stopChan, readyChan, out, errOut) 165 } 166 167 // NewOnAddresses creates a new PortForwarder with custom listen addresses. 168 func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) { 169 if len(addresses) == 0 { 170 return nil, errors.New("you must specify at least 1 address") 171 } 172 parsedAddresses, err := parseAddresses(addresses) 173 if err != nil { 174 return nil, err 175 } 176 if len(ports) == 0 { 177 return nil, errors.New("you must specify at least 1 port") 178 } 179 parsedPorts, err := parsePorts(ports) 180 if err != nil { 181 return nil, err 182 } 183 return &PortForwarder{ 184 dialer: dialer, 185 addresses: parsedAddresses, 186 ports: parsedPorts, 187 stopChan: stopChan, 188 Ready: readyChan, 189 out: out, 190 errOut: errOut, 191 }, nil 192 } 193 194 func (pf *PortForwarder) Addresses() []string { 195 var addresses []string 196 for _, la := range pf.addresses { 197 addresses = append(addresses, la.address) 198 } 199 return addresses 200 } 201 202 // ForwardPorts formats and executes a port forwarding request. The connection will remain 203 // open until stopChan is closed. 204 func (pf *PortForwarder) ForwardPorts() error { 205 defer pf.Close() 206 207 var err error 208 var protocol string 209 pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name) 210 if err != nil { 211 return fmt.Errorf("error upgrading connection: %s", err) 212 } 213 defer pf.streamConn.Close() 214 if protocol != PortForwardProtocolV1Name { 215 return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol) 216 } 217 218 return pf.forward() 219 } 220 221 // forward dials the remote host specific in req, upgrades the request, starts 222 // listeners for each port specified in ports, and forwards local connections 223 // to the remote host via streams. 224 // 225 // Returns an error if any of the local ports aren't available. 226 func (pf *PortForwarder) forward() error { 227 var err error 228 pf.errorHandler = newErrorHandler() 229 defer pf.errorHandler.Close() 230 231 for i := range pf.ports { 232 port := &pf.ports[i] 233 err = pf.listenOnPort(port) 234 if err != nil { 235 return fmt.Errorf("Unable to listen on port %d: %v", port.Local, err) 236 } 237 } 238 239 if pf.Ready != nil { 240 close(pf.Ready) 241 } 242 243 // wait for interrupt or conn closure 244 select { 245 case err := <-pf.errorHandler.Done(): 246 return err 247 case <-pf.stopChan: 248 case <-pf.streamConn.CloseChan(): 249 return ErrLostConnectionToPod 250 } 251 252 return nil 253 } 254 255 // listenOnPort delegates listener creation and waits for connections on requested bind addresses. 256 // An error is raised based on address groups (default and localhost) and their failure modes 257 func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error { 258 var errors []error 259 failCounters := make(map[string]int, 2) 260 successCounters := make(map[string]int, 2) 261 for _, addr := range pf.addresses { 262 err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address) 263 if err != nil { 264 errors = append(errors, err) 265 failCounters[addr.failureMode]++ 266 } else { 267 successCounters[addr.failureMode]++ 268 } 269 } 270 if successCounters["all"] == 0 && failCounters["all"] > 0 { 271 return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors) 272 } 273 if failCounters["any"] > 0 { 274 return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors) 275 } 276 return nil 277 } 278 279 // listenOnPortAndAddress delegates listener creation and waits for new connections 280 // in the background f 281 func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error { 282 listener, err := pf.getListener(protocol, address, port) 283 if err != nil { 284 return err 285 } 286 pf.listeners = append(pf.listeners, listener) 287 go pf.waitForConnection(listener, *port) 288 return nil 289 } 290 291 // getListener creates a listener on the interface targeted by the given hostname on the given port with 292 // the given protocol. protocol is in net.Listen style which basically admits values like tcp, tcp4, tcp6 293 func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) { 294 listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local)))) 295 if err != nil { 296 return nil, fmt.Errorf("unable to create listener: Error %s", err) 297 } 298 listenerAddress := listener.Addr().String() 299 host, localPort, _ := net.SplitHostPort(listenerAddress) 300 localPortUInt, err := strconv.ParseUint(localPort, 10, 16) 301 302 if err != nil { 303 return nil, fmt.Errorf("error parsing local port: %s from %s (%s)", err, listenerAddress, host) 304 } 305 port.Local = uint16(localPortUInt) 306 307 return listener, nil 308 } 309 310 // waitForConnection waits for new connections to listener and handles them in 311 // the background. 312 func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) { 313 for { 314 select { 315 case <-pf.streamConn.CloseChan(): 316 return 317 default: 318 conn, err := listener.Accept() 319 if err != nil { 320 // TODO consider using something like https://github.com/hydrogen18/stoppableListener? 321 if !strings.Contains(strings.ToLower(err.Error()), networkClosedError) { 322 _, _ = fmt.Fprintf(pf.out, "error accepting connection on port %d: %v", port.Local, err) 323 } 324 return 325 } 326 go pf.handleConnection(conn, port) 327 } 328 } 329 } 330 331 func (pf *PortForwarder) nextRequestID() int { 332 pf.requestIDLock.Lock() 333 defer pf.requestIDLock.Unlock() 334 id := pf.requestID 335 pf.requestID++ 336 return id 337 } 338 339 // handleConnection copies data between the local connection and the stream to 340 // the remote server. 341 func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) { 342 defer conn.Close() 343 344 requestID := pf.nextRequestID() 345 346 // create error stream 347 headers := http.Header{} 348 headers.Set(v1.StreamType, v1.StreamTypeError) 349 headers.Set(v1.PortHeader, fmt.Sprintf("%d", port.Remote)) 350 headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID)) 351 errorStream, err := pf.streamConn.CreateStream(headers) 352 if err != nil { 353 // If CreateStream fails, stop the whole portforwarder, because this might 354 // mean the whole streamConn is wedged. The PortForward reconciler will backoff 355 // and re-create the connection. 356 pf.errorHandler.Stop(fmt.Errorf("creating stream: %v", err)) 357 return 358 } 359 // we're not writing to this stream 360 errorStream.Close() 361 defer pf.streamConn.RemoveStreams(errorStream) 362 363 errorChan := make(chan error) 364 go func() { 365 message, err := io.ReadAll(errorStream) 366 switch { 367 case err != nil: 368 errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err) 369 case len(message) > 0: 370 errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message)) 371 } 372 close(errorChan) 373 }() 374 375 // create data stream 376 headers.Set(v1.StreamType, v1.StreamTypeData) 377 dataStream, err := pf.streamConn.CreateStream(headers) 378 if err != nil { 379 // If CreateStream fails, stop the whole portforwarder, because this might 380 // mean the whole streamConn is wedged. The PortForward reconciler will backoff 381 // and re-create the connection. 382 pf.errorHandler.Stop(fmt.Errorf("creating stream: %v", err)) 383 return 384 } 385 defer pf.streamConn.RemoveStreams(dataStream) 386 387 localError := make(chan struct{}) 388 remoteDone := make(chan struct{}) 389 390 go func() { 391 // Copy from the remote side to the local port. 392 if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(strings.ToLower(err.Error()), networkClosedError) { 393 _, _ = fmt.Fprintf(pf.out, "error copying from remote stream to local connection: %v", err) 394 } 395 396 // inform the select below that the remote copy is done 397 close(remoteDone) 398 }() 399 400 go func() { 401 // inform server we're not sending any more data after copy unblocks 402 defer dataStream.Close() 403 404 // Copy from the local port to the remote side. 405 if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(strings.ToLower(err.Error()), networkClosedError) { 406 _, _ = fmt.Fprintf(pf.out, "error copying from local connection to remote stream: %v", err) 407 // break out of the select below without waiting for the other copy to finish 408 close(localError) 409 } 410 }() 411 412 // wait for either a local->remote error or for copying from remote->local to finish 413 select { 414 case <-remoteDone: 415 case <-localError: 416 } 417 418 // reset dataStream to discard any unsent data, preventing port forwarding from being blocked. 419 // we must reset dataStream before waiting on errorChan, otherwise, 420 // the blocking data will affect errorStream and cause <-errorChan to block indefinitely. 421 _ = dataStream.Reset() 422 423 // always expect something on errorChan (it may be nil) 424 err = <-errorChan 425 if err != nil { 426 _, _ = fmt.Fprintf(pf.out, "%v", err) 427 pf.streamConn.Close() 428 } 429 } 430 431 // Close stops all listeners of PortForwarder. 432 func (pf *PortForwarder) Close() { 433 // stop all listeners 434 for _, l := range pf.listeners { 435 if err := l.Close(); err != nil { 436 _, _ = fmt.Fprintf(pf.out, "error closing listener: %v", err) 437 } 438 } 439 } 440 441 // GetPorts will return the ports that were forwarded; this can be used to 442 // retrieve the locally-bound port in cases where the input was port 0. This 443 // function will signal an error if the Ready channel is nil or if the 444 // listeners are not ready yet; this function will succeed after the Ready 445 // channel has been closed. 446 func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) { 447 if pf.Ready == nil { 448 return nil, fmt.Errorf("no Ready channel provided") 449 } 450 select { 451 case <-pf.Ready: 452 return pf.ports, nil 453 default: 454 return nil, fmt.Errorf("listeners not ready") 455 } 456 }