github.com/brass-software/os@v0.0.0-20240129060254-960f457a5dea/writeto_linux_test.go (about) 1 // Copyright 2023 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package os_test 6 7 import ( 8 "bytes" 9 "internal/poll" 10 "io" 11 "math/rand" 12 "net" 13 . "os" 14 "strconv" 15 "syscall" 16 "testing" 17 "time" 18 ) 19 20 func TestSendFile(t *testing.T) { 21 sizes := []int{ 22 1, 23 42, 24 1025, 25 syscall.Getpagesize() + 1, 26 32769, 27 } 28 t.Run("sendfile-to-unix", func(t *testing.T) { 29 for _, size := range sizes { 30 t.Run(strconv.Itoa(size), func(t *testing.T) { 31 testSendFile(t, "unix", int64(size)) 32 }) 33 } 34 }) 35 t.Run("sendfile-to-tcp", func(t *testing.T) { 36 for _, size := range sizes { 37 t.Run(strconv.Itoa(size), func(t *testing.T) { 38 testSendFile(t, "tcp", int64(size)) 39 }) 40 } 41 }) 42 } 43 44 func testSendFile(t *testing.T, proto string, size int64) { 45 dst, src, recv, data, hook := newSendFileTest(t, proto, size) 46 47 // Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile 48 n, err := io.Copy(dst, src) 49 if err != nil { 50 t.Fatalf("io.Copy error: %v", err) 51 } 52 53 // We should have called poll.Splice with the right file descriptor arguments. 54 if n > 0 && !hook.called { 55 t.Fatal("expected to called poll.SendFile") 56 } 57 if hook.called && hook.srcfd != int(src.Fd()) { 58 t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd()) 59 } 60 sc, ok := dst.(syscall.Conn) 61 if !ok { 62 t.Fatalf("destination is not a syscall.Conn") 63 } 64 rc, err := sc.SyscallConn() 65 if err != nil { 66 t.Fatalf("destination SyscallConn error: %v", err) 67 } 68 if err = rc.Control(func(fd uintptr) { 69 if hook.called && hook.dstfd != int(fd) { 70 t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd)) 71 } 72 }); err != nil { 73 t.Fatalf("destination Conn Control error: %v", err) 74 } 75 76 // Verify the data size and content. 77 dataSize := len(data) 78 dstData := make([]byte, dataSize) 79 m, err := io.ReadFull(recv, dstData) 80 if err != nil { 81 t.Fatalf("server Conn Read error: %v", err) 82 } 83 if n != int64(dataSize) { 84 t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize) 85 } 86 if m != dataSize { 87 t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize) 88 } 89 if !bytes.Equal(dstData, data) { 90 t.Errorf("data mismatch, got %s, want %s", dstData, data) 91 } 92 } 93 94 // newSendFileTest initializes a new test for sendfile. 95 // 96 // It creates source file and destination sockets, and populates the source file 97 // with random data of the specified size. It also hooks package os' call 98 // to poll.Sendfile and returns the hook so it can be inspected. 99 func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) { 100 t.Helper() 101 102 hook := hookSendFile(t) 103 104 client, server := createSocketPair(t, proto) 105 tempFile, data := createTempFile(t, size) 106 107 return client, tempFile, server, data, hook 108 } 109 110 func hookSendFile(t *testing.T) *sendFileHook { 111 h := new(sendFileHook) 112 h.install() 113 t.Cleanup(h.uninstall) 114 return h 115 } 116 117 type sendFileHook struct { 118 called bool 119 dstfd int 120 srcfd int 121 remain int64 122 123 written int64 124 handled bool 125 err error 126 127 original func(dst *poll.FD, src int, remain int64) (int64, error, bool) 128 } 129 130 func (h *sendFileHook) install() { 131 h.original = *PollSendFile 132 *PollSendFile = func(dst *poll.FD, src int, remain int64) (int64, error, bool) { 133 h.called = true 134 h.dstfd = dst.Sysfd 135 h.srcfd = src 136 h.remain = remain 137 h.written, h.err, h.handled = h.original(dst, src, remain) 138 return h.written, h.err, h.handled 139 } 140 } 141 142 func (h *sendFileHook) uninstall() { 143 *PollSendFile = h.original 144 } 145 146 func createTempFile(t *testing.T, size int64) (*File, []byte) { 147 f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket") 148 if err != nil { 149 t.Fatalf("failed to create temporary file: %v", err) 150 } 151 t.Cleanup(func() { 152 f.Close() 153 }) 154 155 randSeed := time.Now().Unix() 156 t.Logf("random data seed: %d\n", randSeed) 157 prng := rand.New(rand.NewSource(randSeed)) 158 data := make([]byte, size) 159 prng.Read(data) 160 if _, err := f.Write(data); err != nil { 161 t.Fatalf("failed to create and feed the file: %v", err) 162 } 163 if err := f.Sync(); err != nil { 164 t.Fatalf("failed to save the file: %v", err) 165 } 166 if _, err := f.Seek(0, io.SeekStart); err != nil { 167 t.Fatalf("failed to rewind the file: %v", err) 168 } 169 170 return f, data 171 }