github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/test/cmd/test_app/fds.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
    17  import (
    18  	"context"
    19  	"io"
    20  	"io/ioutil"
    21  	"log"
    22  	"os"
    23  	"time"
    24  
    25  	"github.com/google/subcommands"
    26  	"github.com/SagerNet/gvisor/pkg/test/testutil"
    27  	"github.com/SagerNet/gvisor/pkg/unet"
    28  	"github.com/SagerNet/gvisor/runsc/flag"
    29  )
    30  
    31  const fileContents = "foobarbaz"
    32  
    33  // fdSender will open a file and send the FD over a unix domain socket.
    34  type fdSender struct {
    35  	socketPath string
    36  }
    37  
    38  // Name implements subcommands.Command.Name.
    39  func (*fdSender) Name() string {
    40  	return "fd_sender"
    41  }
    42  
    43  // Synopsis implements subcommands.Command.Synopsys.
    44  func (*fdSender) Synopsis() string {
    45  	return "creates a file and sends the FD over the socket"
    46  }
    47  
    48  // Usage implements subcommands.Command.Usage.
    49  func (*fdSender) Usage() string {
    50  	return "fd_sender <flags>"
    51  }
    52  
    53  // SetFlags implements subcommands.Command.SetFlags.
    54  func (fds *fdSender) SetFlags(f *flag.FlagSet) {
    55  	f.StringVar(&fds.socketPath, "socket", "", "path to socket")
    56  }
    57  
    58  // Execute implements subcommands.Command.Execute.
    59  func (fds *fdSender) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
    60  	if fds.socketPath == "" {
    61  		log.Fatalf("socket flag must be set")
    62  	}
    63  
    64  	dir, err := ioutil.TempDir("", "")
    65  	if err != nil {
    66  		log.Fatalf("TempDir failed: %v", err)
    67  	}
    68  
    69  	fileToSend, err := ioutil.TempFile(dir, "")
    70  	if err != nil {
    71  		log.Fatalf("TempFile failed: %v", err)
    72  	}
    73  	defer fileToSend.Close()
    74  
    75  	if _, err := fileToSend.WriteString(fileContents); err != nil {
    76  		log.Fatalf("Write(%q) failed: %v", fileContents, err)
    77  	}
    78  
    79  	// Receiver may not be started yet, so try connecting in a poll loop.
    80  	var s *unet.Socket
    81  	if err := testutil.Poll(func() error {
    82  		var err error
    83  		s, err = unet.Connect(fds.socketPath, true /* SEQPACKET, so we can send empty message with FD */)
    84  		return err
    85  	}, 10*time.Second); err != nil {
    86  		log.Fatalf("Error connecting to socket %q: %v", fds.socketPath, err)
    87  	}
    88  	defer s.Close()
    89  
    90  	w := s.Writer(true)
    91  	w.ControlMessage.PackFDs(int(fileToSend.Fd()))
    92  	if _, err := w.WriteVec([][]byte{{'a'}}); err != nil {
    93  		log.Fatalf("Error sending FD %q over socket %q: %v", fileToSend.Fd(), fds.socketPath, err)
    94  	}
    95  
    96  	log.Print("FD SENDER exiting successfully")
    97  	return subcommands.ExitSuccess
    98  }
    99  
   100  // fdReceiver receives an FD from a unix domain socket and does things to it.
   101  type fdReceiver struct {
   102  	socketPath string
   103  }
   104  
   105  // Name implements subcommands.Command.Name.
   106  func (*fdReceiver) Name() string {
   107  	return "fd_receiver"
   108  }
   109  
   110  // Synopsis implements subcommands.Command.Synopsys.
   111  func (*fdReceiver) Synopsis() string {
   112  	return "reads an FD from a unix socket, and then does things to it"
   113  }
   114  
   115  // Usage implements subcommands.Command.Usage.
   116  func (*fdReceiver) Usage() string {
   117  	return "fd_receiver <flags>"
   118  }
   119  
   120  // SetFlags implements subcommands.Command.SetFlags.
   121  func (fdr *fdReceiver) SetFlags(f *flag.FlagSet) {
   122  	f.StringVar(&fdr.socketPath, "socket", "", "path to socket")
   123  }
   124  
   125  // Execute implements subcommands.Command.Execute.
   126  func (fdr *fdReceiver) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
   127  	if fdr.socketPath == "" {
   128  		log.Fatalf("Flags cannot be empty, given: socket: %q", fdr.socketPath)
   129  	}
   130  
   131  	ss, err := unet.BindAndListen(fdr.socketPath, true /* packet */)
   132  	if err != nil {
   133  		log.Fatalf("BindAndListen(%q) failed: %v", fdr.socketPath, err)
   134  	}
   135  	defer ss.Close()
   136  
   137  	var s *unet.Socket
   138  	c := make(chan error, 1)
   139  	go func() {
   140  		var err error
   141  		s, err = ss.Accept()
   142  		c <- err
   143  	}()
   144  
   145  	select {
   146  	case err := <-c:
   147  		if err != nil {
   148  			log.Fatalf("Accept() failed: %v", err)
   149  		}
   150  	case <-time.After(10 * time.Second):
   151  		log.Fatalf("Timeout waiting for accept")
   152  	}
   153  
   154  	r := s.Reader(true)
   155  	r.EnableFDs(1)
   156  	b := [][]byte{{'a'}}
   157  	if n, err := r.ReadVec(b); n != 1 || err != nil {
   158  		log.Fatalf("ReadVec got n=%d err %v (wanted 0, nil)", n, err)
   159  	}
   160  
   161  	fds, err := r.ExtractFDs()
   162  	if err != nil {
   163  		log.Fatalf("ExtractFD() got err %v", err)
   164  	}
   165  	if len(fds) != 1 {
   166  		log.Fatalf("ExtractFD() got %d FDs, wanted 1", len(fds))
   167  	}
   168  	fd := fds[0]
   169  
   170  	file := os.NewFile(uintptr(fd), "received file")
   171  	defer file.Close()
   172  	if _, err := file.Seek(0, io.SeekStart); err != nil {
   173  		log.Fatalf("Error from seek(0, 0): %v", err)
   174  	}
   175  
   176  	got, err := ioutil.ReadAll(file)
   177  	if err != nil {
   178  		log.Fatalf("ReadAll failed: %v", err)
   179  	}
   180  	if string(got) != fileContents {
   181  		log.Fatalf("ReadAll got %q want %q", string(got), fileContents)
   182  	}
   183  
   184  	log.Print("FD RECEIVER exiting successfully")
   185  	return subcommands.ExitSuccess
   186  }