gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/runsc/cmd/portforward.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cmd
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"io/ioutil"
    22  	"math"
    23  	"net"
    24  	"os"
    25  	"os/signal"
    26  	"strconv"
    27  	"strings"
    28  	"sync"
    29  	"syscall"
    30  
    31  	"github.com/google/subcommands"
    32  	"gvisor.dev/gvisor/pkg/log"
    33  	"gvisor.dev/gvisor/pkg/urpc"
    34  	"gvisor.dev/gvisor/runsc/boot"
    35  	"gvisor.dev/gvisor/runsc/cmd/util"
    36  	"gvisor.dev/gvisor/runsc/config"
    37  	"gvisor.dev/gvisor/runsc/container"
    38  	"gvisor.dev/gvisor/runsc/flag"
    39  )
    40  
    41  // PortForward implements subcommands.Command for the "portforward" command.
    42  type PortForward struct {
    43  	portNum int
    44  	stream  string
    45  }
    46  
    47  // Name implements subcommands.Command.Name.
    48  func (*PortForward) Name() string {
    49  	return "port-forward"
    50  }
    51  
    52  // Synopsis implements subcommands.Command.Synopsis.
    53  func (*PortForward) Synopsis() string {
    54  	return "port forward to a secure container"
    55  }
    56  
    57  // Usage implements subcommands.Command.Usage.
    58  func (*PortForward) Usage() string {
    59  	return `port-forward CONTAINER_ID [LOCAL_PORT:]REMOTE_PORT - port forward to gvisor container.
    60  
    61  Port forwarding has two modes. Local mode opens a local port and forwards
    62  connections to another port inside the specified container. Stream mode
    63  forwards a single connection on a UDS to the specified port in the container.
    64  
    65  EXAMPLES:
    66  
    67  The following will forward connections on local port 8080 to port 80 in the
    68  container named 'nginx':
    69  
    70  	# runsc port-forward nginx 8080:80
    71  
    72  The following will forward a single new connection on the unix domain socket at
    73  /tmp/pipe to port 80 in the container named 'nginx':
    74  
    75  	# runsc port-forward --stream /tmp/pipe nginx 80
    76  
    77  OPTIONS:
    78  `
    79  }
    80  
    81  // SetFlags implements subcommands.Command.SetFlags.
    82  func (p *PortForward) SetFlags(f *flag.FlagSet) {
    83  	f.StringVar(&p.stream, "stream", "", "Stream mode - a Unix doman socket")
    84  }
    85  
    86  // Execute implements subcommands.Command.Execute.
    87  func (p *PortForward) Execute(ctx context.Context, f *flag.FlagSet, args ...any) subcommands.ExitStatus {
    88  	conf := args[0].(*config.Config)
    89  	// Requires at least the container id and port.
    90  	if f.NArg() != 2 {
    91  		f.Usage()
    92  		return subcommands.ExitUsageError
    93  	}
    94  
    95  	id := f.Arg(0)
    96  	portStr := f.Arg(1)
    97  
    98  	c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
    99  	if err != nil {
   100  		util.Fatalf("loading container: %v", err)
   101  	}
   102  
   103  	if p.stream != "" {
   104  		if err := p.doStream(ctx, portStr, c); err != nil {
   105  			util.Fatalf("doStream: %v", err)
   106  		}
   107  		return subcommands.ExitSuccess
   108  	}
   109  
   110  	// Allow forwarding to a local port.
   111  	ports := strings.Split(portStr, ":")
   112  	if len(ports) != 2 {
   113  		util.Fatalf("invalid port string %q", portStr)
   114  	}
   115  
   116  	localPort, err := strconv.Atoi(ports[0])
   117  	if err != nil {
   118  		util.Fatalf("invalid port string %q: %v", portStr, err)
   119  	}
   120  	portNum, err := strconv.Atoi(ports[1])
   121  	if err != nil {
   122  		util.Fatalf("invalid port string %q: %v", portStr, err)
   123  	}
   124  	if portNum <= 0 || portNum > math.MaxUint16 {
   125  		util.Fatalf("invalid port %d: %v", portNum, err)
   126  	}
   127  
   128  	// Start port forwarding with the local port.
   129  	var wg sync.WaitGroup
   130  	ctx, cancel := context.WithCancel(ctx)
   131  	wg.Add(3)
   132  	go func(localPort, portNum int) {
   133  		defer cancel()
   134  		defer wg.Done()
   135  		// Print message to local user.
   136  		fmt.Printf("Forwarding local port %d to %d...\n", localPort, portNum)
   137  		if err := localForward(ctx, c, localPort, uint16(portNum)); err != nil {
   138  			log.Warningf("port forwarding: %v", err)
   139  		}
   140  	}(localPort, portNum)
   141  
   142  	// Exit port forwarding if the container exits.
   143  	go func() {
   144  		defer wg.Done()
   145  		// Cancel port forwarding after Wait returns regardless of return
   146  		// value as err may indicate sandbox has terminated already.
   147  		_, _ = c.Wait()
   148  		fmt.Printf("Container %q stopped. Exiting...\n", c.ID)
   149  		cancel()
   150  	}()
   151  
   152  	// Wait for ^C from the user.
   153  	go func() {
   154  		defer wg.Done()
   155  		sig := waitSignal()
   156  		fmt.Printf("Got %v, Exiting...\n", sig)
   157  		cancel()
   158  	}()
   159  
   160  	// Wait on a WaitGroup for port forwarding to clean up before exiting.
   161  	wg.Wait()
   162  
   163  	return subcommands.ExitSuccess
   164  }
   165  
   166  // localForward starts port forwarding from the given local port.
   167  func localForward(ctx context.Context, c *container.Container, localPort int, containerPort uint16) error {
   168  	l, err := net.Listen("tcp", ":"+strconv.Itoa(localPort))
   169  	if err != nil {
   170  		return err
   171  	}
   172  	defer l.Close()
   173  
   174  	var localConnChan = make(chan net.Conn, 1)
   175  	var errChan = make(chan error, 1)
   176  	go func() {
   177  		for {
   178  			if ctx.Err() != nil {
   179  				return
   180  			}
   181  			localConn, err := l.Accept()
   182  			if err != nil {
   183  				errChan <- err
   184  				continue
   185  			}
   186  			localConnChan <- localConn
   187  		}
   188  	}()
   189  
   190  	for {
   191  		// Exit if the context is done.
   192  		select {
   193  		case <-ctx.Done():
   194  			return ctx.Err()
   195  		case err := <-errChan:
   196  			if err != nil {
   197  				log.Warningf("accepting local connection: %v", err)
   198  			}
   199  		case localConn := <-localConnChan:
   200  			// Dispatch a new goroutine to handle the new connection.
   201  			go func() {
   202  				defer localConn.Close()
   203  				fmt.Println("Forwarding new connection...")
   204  				err := portCopy(ctx, c, localConn, containerPort)
   205  				if err != nil {
   206  					log.Warningf("port forwarding: %v", err)
   207  				}
   208  				fmt.Println("Finished forwarding connection...")
   209  			}()
   210  		}
   211  	}
   212  }
   213  
   214  // doStream does the stream version of the port-forward command.
   215  func (p *PortForward) doStream(ctx context.Context, port string, c *container.Container) error {
   216  	var err error
   217  	p.portNum, err = strconv.Atoi(port)
   218  	if err != nil {
   219  		return fmt.Errorf("invalid port string %q: %v", port, err)
   220  	}
   221  
   222  	if p.portNum < 0 || p.portNum > math.MaxUint16 {
   223  		return fmt.Errorf("invalid port %d: %v", p.portNum, err)
   224  	}
   225  
   226  	f, err := openStream(p.stream)
   227  	if err != nil {
   228  		return fmt.Errorf("opening uds stream: %v", err)
   229  	}
   230  	defer f.Close()
   231  
   232  	if err := c.PortForward(&boot.PortForwardOpts{
   233  		Port:        uint16(p.portNum),
   234  		ContainerID: c.ID,
   235  		FilePayload: urpc.FilePayload{Files: []*os.File{f}},
   236  	}); err != nil {
   237  		return fmt.Errorf("PortForward: %v", err)
   238  	}
   239  
   240  	return nil
   241  }
   242  
   243  // portCopy creates a UDS and begins copying data to and from the local
   244  // connection.
   245  func portCopy(ctx context.Context, c *container.Container, localConn net.Conn, port uint16) error {
   246  	// Create a new path address for the UDS.
   247  	addr, err := tmpUDSAddr()
   248  	if err != nil {
   249  		return err
   250  	}
   251  
   252  	// Create the UDS and Listen on it.
   253  	l, err := net.Listen("unix", addr)
   254  	if err != nil {
   255  		return err
   256  	}
   257  	defer l.Close()
   258  
   259  	// Open the UDS as a File so it can be donated to the sentry.
   260  	streamFile, err := openStream(addr)
   261  	if err != nil {
   262  		return fmt.Errorf("opening uds stream: %v", err)
   263  	}
   264  	defer streamFile.Close()
   265  
   266  	// Request port forwarding from the sentry. This request will return
   267  	// immediately after port forwarding is started and connection state is
   268  	// handled via the UDS from then on.
   269  	if err := c.PortForward(&boot.PortForwardOpts{
   270  		Port:        port,
   271  		FilePayload: urpc.FilePayload{Files: []*os.File{streamFile}},
   272  	}); err != nil {
   273  		return fmt.Errorf("PortForward: %v", err)
   274  	}
   275  
   276  	// We have already opened a single connection on the UDS and passed the
   277  	// client end to the sentry. We accept the connection now in order to get
   278  	// the other half of the connection.
   279  	conn, err := l.Accept()
   280  	if err != nil {
   281  		return err
   282  	}
   283  
   284  	toErrCh := make(chan error)
   285  	fromErrCh := make(chan error)
   286  	// Copy data from the local port to the UDS.
   287  	go func() {
   288  		defer conn.Close()
   289  		defer localConn.Close()
   290  		log.Debugf("Start copying from %q to %q", localConn.LocalAddr().String(), conn.LocalAddr().String())
   291  		_, err := io.Copy(localConn, conn)
   292  		log.Debugf("Stopped copying from %q to %q", localConn.LocalAddr().String(), conn.LocalAddr().String())
   293  		toErrCh <- err
   294  		close(toErrCh)
   295  	}()
   296  
   297  	// Copy data from the UDS to the local port.
   298  	go func() {
   299  		defer conn.Close()
   300  		defer localConn.Close()
   301  		log.Debugf("Start copying from %q to %q", conn.LocalAddr().String(), localConn.LocalAddr().String())
   302  		_, err := io.Copy(conn, localConn)
   303  		log.Debugf("Stopped copying from %q to %q", conn.LocalAddr().String(), localConn.LocalAddr().String())
   304  		fromErrCh <- err
   305  		close(fromErrCh)
   306  	}()
   307  
   308  	errMap := map[string]error{}
   309  	for {
   310  		if len(errMap) == 2 {
   311  			return nil
   312  		}
   313  		select {
   314  		case e := <-toErrCh:
   315  			errMap["toChannel"] = e
   316  		case e := <-fromErrCh:
   317  			errMap["fromChannel"] = e
   318  		case <-ctx.Done():
   319  			log.Debugf("Port forwarding connection canceled for %q: %v", localConn.LocalAddr().String(), ctx.Err())
   320  			return ctx.Err()
   321  		}
   322  	}
   323  }
   324  
   325  // tmpUDS generates a temporary UDS addr.
   326  func tmpUDSAddr() (string, error) {
   327  	tmpFile, err := ioutil.TempFile("", "runsc-port-forward")
   328  	if err != nil {
   329  		return "", err
   330  	}
   331  	path := tmpFile.Name()
   332  	// Remove the tempfile and just use its name.
   333  	os.Remove(path)
   334  
   335  	return path, nil
   336  }
   337  
   338  // openStream opens a UDS as a socket and returns the file descriptor in an
   339  // os.File object.
   340  func openStream(name string) (*os.File, error) {
   341  	// The net package will abstract the fd, so we use raw syscalls.
   342  	fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
   343  	if err != nil {
   344  		return nil, err
   345  	}
   346  
   347  	// We are acting as a client so we will connect to the socket.
   348  	if err = syscall.Connect(fd, &syscall.SockaddrUnix{Name: name}); err != nil {
   349  		syscall.Close(fd)
   350  		return nil, err
   351  	}
   352  
   353  	// Return a File so that we can pass it to the Sentry.
   354  	return os.NewFile(uintptr(fd), name), nil
   355  }
   356  
   357  // waitSignal waits for SIGINT, SIGQUIT, or SIGTERM from the user.
   358  func waitSignal() os.Signal {
   359  	ch := make(chan os.Signal, 2)
   360  	signal.Notify(
   361  		ch,
   362  		syscall.SIGINT,
   363  		syscall.SIGQUIT,
   364  		syscall.SIGTERM,
   365  	)
   366  	return <-ch
   367  }