github.com/code-reading/golang@v0.0.0-20220303082512-ba5bc0e589a3/go/src/os/readfrom_linux_test.go (about) 1 // Copyright 2020 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package os_test 6 7 import ( 8 "bytes" 9 "internal/poll" 10 "io" 11 "math/rand" 12 "os" 13 . "os" 14 "path/filepath" 15 "strconv" 16 "syscall" 17 "testing" 18 "time" 19 ) 20 21 func TestCopyFileRange(t *testing.T) { 22 sizes := []int{ 23 1, 24 42, 25 1025, 26 syscall.Getpagesize() + 1, 27 32769, 28 } 29 t.Run("Basic", func(t *testing.T) { 30 for _, size := range sizes { 31 t.Run(strconv.Itoa(size), func(t *testing.T) { 32 testCopyFileRange(t, int64(size), -1) 33 }) 34 } 35 }) 36 t.Run("Limited", func(t *testing.T) { 37 t.Run("OneLess", func(t *testing.T) { 38 for _, size := range sizes { 39 t.Run(strconv.Itoa(size), func(t *testing.T) { 40 testCopyFileRange(t, int64(size), int64(size)-1) 41 }) 42 } 43 }) 44 t.Run("Half", func(t *testing.T) { 45 for _, size := range sizes { 46 t.Run(strconv.Itoa(size), func(t *testing.T) { 47 testCopyFileRange(t, int64(size), int64(size)/2) 48 }) 49 } 50 }) 51 t.Run("More", func(t *testing.T) { 52 for _, size := range sizes { 53 t.Run(strconv.Itoa(size), func(t *testing.T) { 54 testCopyFileRange(t, int64(size), int64(size)+7) 55 }) 56 } 57 }) 58 }) 59 t.Run("DoesntTryInAppendMode", func(t *testing.T) { 60 dst, src, data, hook := newCopyFileRangeTest(t, 42) 61 62 dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755) 63 if err != nil { 64 t.Fatal(err) 65 } 66 defer dst2.Close() 67 68 if _, err := io.Copy(dst2, src); err != nil { 69 t.Fatal(err) 70 } 71 if hook.called { 72 t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode") 73 } 74 mustSeekStart(t, dst2) 75 mustContainData(t, dst2, data) // through traditional means 76 }) 77 t.Run("NotRegular", func(t *testing.T) { 78 t.Run("BothPipes", func(t *testing.T) { 79 hook := hookCopyFileRange(t) 80 81 
pr1, pw1, err := Pipe() 82 if err != nil { 83 t.Fatal(err) 84 } 85 defer pr1.Close() 86 defer pw1.Close() 87 88 pr2, pw2, err := Pipe() 89 if err != nil { 90 t.Fatal(err) 91 } 92 defer pr2.Close() 93 defer pw2.Close() 94 95 // The pipe is empty, and PIPE_BUF is large enough 96 // for this, by (POSIX) definition, so there is no 97 // need for an additional goroutine. 98 data := []byte("hello") 99 if _, err := pw1.Write(data); err != nil { 100 t.Fatal(err) 101 } 102 pw1.Close() 103 104 n, err := io.Copy(pw2, pr1) 105 if err != nil { 106 t.Fatal(err) 107 } 108 if n != int64(len(data)) { 109 t.Fatalf("transferred %d, want %d", n, len(data)) 110 } 111 if !hook.called { 112 t.Fatalf("should have called poll.CopyFileRange") 113 } 114 pw2.Close() 115 mustContainData(t, pr2, data) 116 }) 117 t.Run("DstPipe", func(t *testing.T) { 118 dst, src, data, hook := newCopyFileRangeTest(t, 255) 119 dst.Close() 120 121 pr, pw, err := Pipe() 122 if err != nil { 123 t.Fatal(err) 124 } 125 defer pr.Close() 126 defer pw.Close() 127 128 n, err := io.Copy(pw, src) 129 if err != nil { 130 t.Fatal(err) 131 } 132 if n != int64(len(data)) { 133 t.Fatalf("transferred %d, want %d", n, len(data)) 134 } 135 if !hook.called { 136 t.Fatalf("should have called poll.CopyFileRange") 137 } 138 pw.Close() 139 mustContainData(t, pr, data) 140 }) 141 t.Run("SrcPipe", func(t *testing.T) { 142 dst, src, data, hook := newCopyFileRangeTest(t, 255) 143 src.Close() 144 145 pr, pw, err := Pipe() 146 if err != nil { 147 t.Fatal(err) 148 } 149 defer pr.Close() 150 defer pw.Close() 151 152 // The pipe is empty, and PIPE_BUF is large enough 153 // for this, by (POSIX) definition, so there is no 154 // need for an additional goroutine. 
155 if _, err := pw.Write(data); err != nil { 156 t.Fatal(err) 157 } 158 pw.Close() 159 160 n, err := io.Copy(dst, pr) 161 if err != nil { 162 t.Fatal(err) 163 } 164 if n != int64(len(data)) { 165 t.Fatalf("transferred %d, want %d", n, len(data)) 166 } 167 if !hook.called { 168 t.Fatalf("should have called poll.CopyFileRange") 169 } 170 mustSeekStart(t, dst) 171 mustContainData(t, dst, data) 172 }) 173 }) 174 t.Run("Nil", func(t *testing.T) { 175 var nilFile *File 176 anyFile, err := os.CreateTemp("", "") 177 if err != nil { 178 t.Fatal(err) 179 } 180 defer Remove(anyFile.Name()) 181 defer anyFile.Close() 182 183 if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid { 184 t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid) 185 } 186 if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid { 187 t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid) 188 } 189 if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid { 190 t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid) 191 } 192 193 if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid { 194 t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid) 195 } 196 if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid { 197 t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid) 198 } 199 if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid { 200 t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid) 201 } 202 }) 203 } 204 205 func testCopyFileRange(t *testing.T, size int64, limit int64) { 206 dst, src, data, hook := newCopyFileRangeTest(t, size) 207 208 // If we have a limit, wrap the reader. 
209 var ( 210 realsrc io.Reader 211 lr *io.LimitedReader 212 ) 213 if limit >= 0 { 214 lr = &io.LimitedReader{N: limit, R: src} 215 realsrc = lr 216 if limit < int64(len(data)) { 217 data = data[:limit] 218 } 219 } else { 220 realsrc = src 221 } 222 223 // Now call ReadFrom (through io.Copy), which will hopefully call 224 // poll.CopyFileRange. 225 n, err := io.Copy(dst, realsrc) 226 if err != nil { 227 t.Fatal(err) 228 } 229 230 // If we didn't have a limit, we should have called poll.CopyFileRange 231 // with the right file descriptor arguments. 232 if limit > 0 && !hook.called { 233 t.Fatal("never called poll.CopyFileRange") 234 } 235 if hook.called && hook.dstfd != int(dst.Fd()) { 236 t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd()) 237 } 238 if hook.called && hook.srcfd != int(src.Fd()) { 239 t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd()) 240 } 241 242 // Check that the offsets after the transfer make sense, that the size 243 // of the transfer was reported correctly, and that the destination 244 // file contains exactly the bytes we expect it to contain. 245 dstoff, err := dst.Seek(0, io.SeekCurrent) 246 if err != nil { 247 t.Fatal(err) 248 } 249 srcoff, err := src.Seek(0, io.SeekCurrent) 250 if err != nil { 251 t.Fatal(err) 252 } 253 if dstoff != srcoff { 254 t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff) 255 } 256 if dstoff != int64(len(data)) { 257 t.Errorf("dstoff = %d, want %d", dstoff, len(data)) 258 } 259 if n != int64(len(data)) { 260 t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data)) 261 } 262 mustSeekStart(t, dst) 263 mustContainData(t, dst, data) 264 265 // If we had a limit, check that it was updated. 266 if lr != nil { 267 if want := limit - n; lr.N != want { 268 t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want) 269 } 270 } 271 } 272 273 // newCopyFileRangeTest initializes a new test for copy_file_range. 
274 // 275 // It creates source and destination files, and populates the source file 276 // with random data of the specified size. It also hooks package os' call 277 // to poll.CopyFileRange and returns the hook so it can be inspected. 278 func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) { 279 t.Helper() 280 281 hook = hookCopyFileRange(t) 282 tmp := t.TempDir() 283 284 src, err := Create(filepath.Join(tmp, "src")) 285 if err != nil { 286 t.Fatal(err) 287 } 288 t.Cleanup(func() { src.Close() }) 289 290 dst, err = Create(filepath.Join(tmp, "dst")) 291 if err != nil { 292 t.Fatal(err) 293 } 294 t.Cleanup(func() { dst.Close() }) 295 296 // Populate the source file with data, then rewind it, so it can be 297 // consumed by copy_file_range(2). 298 prng := rand.New(rand.NewSource(time.Now().Unix())) 299 data = make([]byte, size) 300 prng.Read(data) 301 if _, err := src.Write(data); err != nil { 302 t.Fatal(err) 303 } 304 if _, err := src.Seek(0, io.SeekStart); err != nil { 305 t.Fatal(err) 306 } 307 308 return dst, src, data, hook 309 } 310 311 // mustContainData ensures that the specified file contains exactly the 312 // specified data. 
313 func mustContainData(t *testing.T, f *File, data []byte) { 314 t.Helper() 315 316 got := make([]byte, len(data)) 317 if _, err := io.ReadFull(f, got); err != nil { 318 t.Fatal(err) 319 } 320 if !bytes.Equal(got, data) { 321 t.Fatalf("didn't get the same data back from %s", f.Name()) 322 } 323 if _, err := f.Read(make([]byte, 1)); err != io.EOF { 324 t.Fatalf("not at EOF") 325 } 326 } 327 328 func mustSeekStart(t *testing.T, f *File) { 329 if _, err := f.Seek(0, io.SeekStart); err != nil { 330 t.Fatal(err) 331 } 332 } 333 334 func hookCopyFileRange(t *testing.T) *copyFileRangeHook { 335 h := new(copyFileRangeHook) 336 h.install() 337 t.Cleanup(h.uninstall) 338 return h 339 } 340 341 type copyFileRangeHook struct { 342 called bool 343 dstfd int 344 srcfd int 345 remain int64 346 347 original func(dst, src *poll.FD, remain int64) (int64, bool, error) 348 } 349 350 func (h *copyFileRangeHook) install() { 351 h.original = *PollCopyFileRangeP 352 *PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) { 353 h.called = true 354 h.dstfd = dst.Sysfd 355 h.srcfd = src.Sysfd 356 h.remain = remain 357 return h.original(dst, src, remain) 358 } 359 } 360 361 func (h *copyFileRangeHook) uninstall() { 362 *PollCopyFileRangeP = h.original 363 } 364 365 // On some kernels copy_file_range fails on files in /proc. 
func TestProcCopy(t *testing.T) {
	// Copy /proc/self/cmdline into a regular file and check that the copy
	// matches a direct read of the same /proc file.
	const cmdlineFile = "/proc/self/cmdline"
	cmdline, err := os.ReadFile(cmdlineFile)
	if err != nil {
		t.Skipf("can't read /proc file: %v", err)
	}
	in, err := os.Open(cmdlineFile)
	if err != nil {
		t.Fatal(err)
	}
	defer in.Close()
	outFile := filepath.Join(t.TempDir(), "cmdline")
	out, err := os.Create(outFile)
	if err != nil {
		t.Fatal(err)
	}
	// Release the descriptor even if the copy fails. The explicit Close
	// below is still required to surface write errors before reading the
	// file back; the deferred second Close just returns an ignored error.
	defer out.Close()
	if _, err := io.Copy(out, in); err != nil {
		t.Fatal(err)
	}
	if err := out.Close(); err != nil {
		t.Fatal(err)
	}
	// Named got (not copy) to avoid shadowing the builtin.
	got, err := os.ReadFile(outFile)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(cmdline, got) {
		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, got, cmdline)
	}
}