github.com/spotify/syslog-redirector-golang@v0.0.0-20140320174030-4859f03d829a/src/pkg/syscall/passfd_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build linux dragonfly darwin freebsd netbsd openbsd
     6  
     7  package syscall_test
     8  
     9  import (
    10  	"flag"
    11  	"fmt"
    12  	"io/ioutil"
    13  	"net"
    14  	"os"
    15  	"os/exec"
    16  	"runtime"
    17  	"syscall"
    18  	"testing"
    19  	"time"
    20  )
    21  
    22  // TestPassFD tests passing a file descriptor over a Unix socket.
    23  //
    24  // This test involved both a parent and child process. The parent
    25  // process is invoked as a normal test, with "go test", which then
    26  // runs the child process by running the current test binary with args
    27  // "-test.run=^TestPassFD$" and an environment variable used to signal
    28  // that the test should become the child process instead.
    29  func TestPassFD(t *testing.T) {
    30  	if runtime.GOOS == "dragonfly" {
    31  		// TODO(jsing): Figure out why sendmsg is returning EINVAL.
    32  		t.Skip("Skipping test on dragonfly")
    33  	}
    34  	if os.Getenv("GO_WANT_HELPER_PROCESS") == "1" {
    35  		passFDChild()
    36  		return
    37  	}
    38  
    39  	tempDir, err := ioutil.TempDir("", "TestPassFD")
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  	defer os.RemoveAll(tempDir)
    44  
    45  	fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
    46  	if err != nil {
    47  		t.Fatalf("Socketpair: %v", err)
    48  	}
    49  	defer syscall.Close(fds[0])
    50  	defer syscall.Close(fds[1])
    51  	writeFile := os.NewFile(uintptr(fds[0]), "child-writes")
    52  	readFile := os.NewFile(uintptr(fds[1]), "parent-reads")
    53  	defer writeFile.Close()
    54  	defer readFile.Close()
    55  
    56  	cmd := exec.Command(os.Args[0], "-test.run=^TestPassFD$", "--", tempDir)
    57  	cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"}
    58  	cmd.ExtraFiles = []*os.File{writeFile}
    59  
    60  	out, err := cmd.CombinedOutput()
    61  	if len(out) > 0 || err != nil {
    62  		t.Fatalf("child process: %q, %v", out, err)
    63  	}
    64  
    65  	c, err := net.FileConn(readFile)
    66  	if err != nil {
    67  		t.Fatalf("FileConn: %v", err)
    68  	}
    69  	defer c.Close()
    70  
    71  	uc, ok := c.(*net.UnixConn)
    72  	if !ok {
    73  		t.Fatalf("unexpected FileConn type; expected UnixConn, got %T", c)
    74  	}
    75  
    76  	buf := make([]byte, 32) // expect 1 byte
    77  	oob := make([]byte, 32) // expect 24 bytes
    78  	closeUnix := time.AfterFunc(5*time.Second, func() {
    79  		t.Logf("timeout reading from unix socket")
    80  		uc.Close()
    81  	})
    82  	_, oobn, _, _, err := uc.ReadMsgUnix(buf, oob)
    83  	closeUnix.Stop()
    84  
    85  	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
    86  	if err != nil {
    87  		t.Fatalf("ParseSocketControlMessage: %v", err)
    88  	}
    89  	if len(scms) != 1 {
    90  		t.Fatalf("expected 1 SocketControlMessage; got scms = %#v", scms)
    91  	}
    92  	scm := scms[0]
    93  	gotFds, err := syscall.ParseUnixRights(&scm)
    94  	if err != nil {
    95  		t.Fatalf("syscall.ParseUnixRights: %v", err)
    96  	}
    97  	if len(gotFds) != 1 {
    98  		t.Fatalf("wanted 1 fd; got %#v", gotFds)
    99  	}
   100  
   101  	f := os.NewFile(uintptr(gotFds[0]), "fd-from-child")
   102  	defer f.Close()
   103  
   104  	got, err := ioutil.ReadAll(f)
   105  	want := "Hello from child process!\n"
   106  	if string(got) != want {
   107  		t.Errorf("child process ReadAll: %q, %v; want %q", got, err, want)
   108  	}
   109  }
   110  
   111  // passFDChild is the child process used by TestPassFD.
   112  func passFDChild() {
   113  	defer os.Exit(0)
   114  
   115  	// Look for our fd. It should be fd 3, but we work around an fd leak
   116  	// bug here (http://golang.org/issue/2603) to let it be elsewhere.
   117  	var uc *net.UnixConn
   118  	for fd := uintptr(3); fd <= 10; fd++ {
   119  		f := os.NewFile(fd, "unix-conn")
   120  		var ok bool
   121  		netc, _ := net.FileConn(f)
   122  		uc, ok = netc.(*net.UnixConn)
   123  		if ok {
   124  			break
   125  		}
   126  	}
   127  	if uc == nil {
   128  		fmt.Println("failed to find unix fd")
   129  		return
   130  	}
   131  
   132  	// Make a file f to send to our parent process on uc.
   133  	// We make it in tempDir, which our parent will clean up.
   134  	flag.Parse()
   135  	tempDir := flag.Arg(0)
   136  	f, err := ioutil.TempFile(tempDir, "")
   137  	if err != nil {
   138  		fmt.Printf("TempFile: %v", err)
   139  		return
   140  	}
   141  
   142  	f.Write([]byte("Hello from child process!\n"))
   143  	f.Seek(0, 0)
   144  
   145  	rights := syscall.UnixRights(int(f.Fd()))
   146  	dummyByte := []byte("x")
   147  	n, oobn, err := uc.WriteMsgUnix(dummyByte, rights, nil)
   148  	if err != nil {
   149  		fmt.Printf("WriteMsgUnix: %v", err)
   150  		return
   151  	}
   152  	if n != 1 || oobn != len(rights) {
   153  		fmt.Printf("WriteMsgUnix = %d, %d; want 1, %d", n, oobn, len(rights))
   154  		return
   155  	}
   156  }
   157  
   158  // TestUnixRightsRoundtrip tests that UnixRights, ParseSocketControlMessage,
   159  // and ParseUnixRights are able to successfully round-trip lists of file descriptors.
   160  func TestUnixRightsRoundtrip(t *testing.T) {
   161  	testCases := [...][][]int{
   162  		{{42}},
   163  		{{1, 2}},
   164  		{{3, 4, 5}},
   165  		{{}},
   166  		{{1, 2}, {3, 4, 5}, {}, {7}},
   167  	}
   168  	for _, testCase := range testCases {
   169  		b := []byte{}
   170  		var n int
   171  		for _, fds := range testCase {
   172  			// Last assignment to n wins
   173  			n = len(b) + syscall.CmsgLen(4*len(fds))
   174  			b = append(b, syscall.UnixRights(fds...)...)
   175  		}
   176  		// Truncate b
   177  		b = b[:n]
   178  
   179  		scms, err := syscall.ParseSocketControlMessage(b)
   180  		if err != nil {
   181  			t.Fatalf("ParseSocketControlMessage: %v", err)
   182  		}
   183  		if len(scms) != len(testCase) {
   184  			t.Fatalf("expected %v SocketControlMessage; got scms = %#v", len(testCase), scms)
   185  		}
   186  		for i, scm := range scms {
   187  			gotFds, err := syscall.ParseUnixRights(&scm)
   188  			if err != nil {
   189  				t.Fatalf("ParseUnixRights: %v", err)
   190  			}
   191  			wantFds := testCase[i]
   192  			if len(gotFds) != len(wantFds) {
   193  				t.Fatalf("expected %v fds, got %#v", len(wantFds), gotFds)
   194  			}
   195  			for j, fd := range gotFds {
   196  				if fd != wantFds[j] {
   197  					t.Fatalf("expected fd %v, got %v", wantFds[j], fd)
   198  				}
   199  			}
   200  		}
   201  	}
   202  }