github.com/nycdavid/zeus@v0.0.0-20201208104106-9ba439429e03/go/unixsocket/unixsocket_test.go (about) 1 package unixsocket 2 3 import ( 4 "io/ioutil" 5 "os" 6 "strings" 7 "syscall" 8 "testing" 9 ) 10 11 func TestLongMessage(t *testing.T) { 12 message := strings.Repeat("abcdefghijklmonpqrstuvwxyz", 1000) 13 14 a, b := makeUsockPair(t) 15 16 go sendMessage(t, a, message) 17 expectMessage(t, b, message) 18 } 19 20 func TestMessagesAndFDs(t *testing.T) { 21 var msg string 22 a, b := makeUsockPair(t) 23 24 tempFile := makeTempFile(t) 25 defer os.Remove(tempFile.Name()) 26 27 messages := []string{"zomg", "wtf", "lol"} 28 29 sendFD(t, a, tempFile.Fd()) 30 for _, msg = range messages { 31 sendMessage(t, a, msg) 32 } 33 34 for _, msg = range messages { 35 expectMessage(t, b, msg) 36 } 37 expectFD(t, b, tempFile.Fd()) 38 } 39 40 func makeUsockPair(t *testing.T) (sockA, sockB *Usock) { 41 a, b, err := Socketpair(syscall.SOCK_STREAM) 42 if err != nil { 43 t.Fatal(err) 44 } 45 46 sockA, err = NewFromFile(a) 47 if err != nil { 48 t.Fatal(err) 49 } 50 51 sockB, err = NewFromFile(b) 52 if err != nil { 53 t.Fatal(err) 54 } 55 56 return 57 } 58 59 func makeTempFile(t *testing.T) (tempFile *os.File) { 60 tempFile, err := ioutil.TempFile("/tmp", "unixsocket_test") 61 if err != nil { 62 t.Fatal(err) 63 } 64 return 65 } 66 67 func expectMessage(t *testing.T, b *Usock, msg string) { 68 readMsg, err := b.ReadMessage() 69 if err != nil { 70 t.Error(err) 71 } 72 if readMsg != msg { 73 t.Errorf("Expected \"%s\", but read \"%s\"\n", msg, readMsg) 74 } 75 } 76 77 func sendMessage(t *testing.T, a *Usock, msg string) { 78 n, err := a.WriteMessage(msg) 79 if err != nil { 80 t.Error(err) 81 } 82 if n != len(msg) { 83 t.Errorf("Expected %d bytes written, but was %d\n", len(msg), n) 84 } 85 } 86 87 func sendFD(t *testing.T, a *Usock, fd uintptr) { 88 if err := a.WriteFD(int(fd)); err != nil { 89 t.Error(err) 90 } 91 } 92 93 func expectFD(t *testing.T, b *Usock, compareFd uintptr) { 94 fd, err := b.ReadFD() 95 if err != nil { 96 t.Error(err) 97 } 98 if fd <= int(compareFd) { 99 t.Errorf("Expected new FD, but got %d\n", fd) 100 } 101 }