github.com/mtsmfm/go/src@v0.0.0-20221020090648-44bdcb9f8fde/os/readfrom_linux_test.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package os_test
     6  
     7  import (
     8  	"bytes"
     9  	"internal/poll"
    10  	"io"
    11  	"math/rand"
    12  	"os"
    13  	. "os"
    14  	"path/filepath"
    15  	"strconv"
    16  	"strings"
    17  	"syscall"
    18  	"testing"
    19  	"time"
    20  )
    21  
    22  func TestCopyFileRange(t *testing.T) {
    23  	sizes := []int{
    24  		1,
    25  		42,
    26  		1025,
    27  		syscall.Getpagesize() + 1,
    28  		32769,
    29  	}
    30  	t.Run("Basic", func(t *testing.T) {
    31  		for _, size := range sizes {
    32  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    33  				testCopyFileRange(t, int64(size), -1)
    34  			})
    35  		}
    36  	})
    37  	t.Run("Limited", func(t *testing.T) {
    38  		t.Run("OneLess", func(t *testing.T) {
    39  			for _, size := range sizes {
    40  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    41  					testCopyFileRange(t, int64(size), int64(size)-1)
    42  				})
    43  			}
    44  		})
    45  		t.Run("Half", func(t *testing.T) {
    46  			for _, size := range sizes {
    47  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    48  					testCopyFileRange(t, int64(size), int64(size)/2)
    49  				})
    50  			}
    51  		})
    52  		t.Run("More", func(t *testing.T) {
    53  			for _, size := range sizes {
    54  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    55  					testCopyFileRange(t, int64(size), int64(size)+7)
    56  				})
    57  			}
    58  		})
    59  	})
    60  	t.Run("DoesntTryInAppendMode", func(t *testing.T) {
    61  		dst, src, data, hook := newCopyFileRangeTest(t, 42)
    62  
    63  		dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755)
    64  		if err != nil {
    65  			t.Fatal(err)
    66  		}
    67  		defer dst2.Close()
    68  
    69  		if _, err := io.Copy(dst2, src); err != nil {
    70  			t.Fatal(err)
    71  		}
    72  		if hook.called {
    73  			t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode")
    74  		}
    75  		mustSeekStart(t, dst2)
    76  		mustContainData(t, dst2, data) // through traditional means
    77  	})
    78  	t.Run("CopyFileItself", func(t *testing.T) {
    79  		hook := hookCopyFileRange(t)
    80  
    81  		f, err := os.CreateTemp("", "file-readfrom-itself-test")
    82  		if err != nil {
    83  			t.Fatalf("failed to create tmp file: %v", err)
    84  		}
    85  		t.Cleanup(func() {
    86  			f.Close()
    87  			os.Remove(f.Name())
    88  		})
    89  
    90  		data := []byte("hello world!")
    91  		if _, err := f.Write(data); err != nil {
    92  			t.Fatalf("failed to create and feed the file: %v", err)
    93  		}
    94  
    95  		if err := f.Sync(); err != nil {
    96  			t.Fatalf("failed to save the file: %v", err)
    97  		}
    98  
    99  		// Rewind it.
   100  		if _, err := f.Seek(0, io.SeekStart); err != nil {
   101  			t.Fatalf("failed to rewind the file: %v", err)
   102  		}
   103  
   104  		// Read data from the file itself.
   105  		if _, err := io.Copy(f, f); err != nil {
   106  			t.Fatalf("failed to read from the file: %v", err)
   107  		}
   108  
   109  		if !hook.called || hook.written != 0 || hook.handled || hook.err != nil {
   110  			t.Fatalf("poll.CopyFileRange should be called and return the EINVAL error, but got hook.called=%t, hook.err=%v", hook.called, hook.err)
   111  		}
   112  
   113  		// Rewind it.
   114  		if _, err := f.Seek(0, io.SeekStart); err != nil {
   115  			t.Fatalf("failed to rewind the file: %v", err)
   116  		}
   117  
   118  		data2, err := io.ReadAll(f)
   119  		if err != nil {
   120  			t.Fatalf("failed to read from the file: %v", err)
   121  		}
   122  
   123  		// It should wind up a double of the original data.
   124  		if strings.Repeat(string(data), 2) != string(data2) {
   125  			t.Fatalf("data mismatch: %s != %s", string(data), string(data2))
   126  		}
   127  	})
   128  	t.Run("NotRegular", func(t *testing.T) {
   129  		t.Run("BothPipes", func(t *testing.T) {
   130  			hook := hookCopyFileRange(t)
   131  
   132  			pr1, pw1, err := Pipe()
   133  			if err != nil {
   134  				t.Fatal(err)
   135  			}
   136  			defer pr1.Close()
   137  			defer pw1.Close()
   138  
   139  			pr2, pw2, err := Pipe()
   140  			if err != nil {
   141  				t.Fatal(err)
   142  			}
   143  			defer pr2.Close()
   144  			defer pw2.Close()
   145  
   146  			// The pipe is empty, and PIPE_BUF is large enough
   147  			// for this, by (POSIX) definition, so there is no
   148  			// need for an additional goroutine.
   149  			data := []byte("hello")
   150  			if _, err := pw1.Write(data); err != nil {
   151  				t.Fatal(err)
   152  			}
   153  			pw1.Close()
   154  
   155  			n, err := io.Copy(pw2, pr1)
   156  			if err != nil {
   157  				t.Fatal(err)
   158  			}
   159  			if n != int64(len(data)) {
   160  				t.Fatalf("transferred %d, want %d", n, len(data))
   161  			}
   162  			if !hook.called {
   163  				t.Fatalf("should have called poll.CopyFileRange")
   164  			}
   165  			pw2.Close()
   166  			mustContainData(t, pr2, data)
   167  		})
   168  		t.Run("DstPipe", func(t *testing.T) {
   169  			dst, src, data, hook := newCopyFileRangeTest(t, 255)
   170  			dst.Close()
   171  
   172  			pr, pw, err := Pipe()
   173  			if err != nil {
   174  				t.Fatal(err)
   175  			}
   176  			defer pr.Close()
   177  			defer pw.Close()
   178  
   179  			n, err := io.Copy(pw, src)
   180  			if err != nil {
   181  				t.Fatal(err)
   182  			}
   183  			if n != int64(len(data)) {
   184  				t.Fatalf("transferred %d, want %d", n, len(data))
   185  			}
   186  			if !hook.called {
   187  				t.Fatalf("should have called poll.CopyFileRange")
   188  			}
   189  			pw.Close()
   190  			mustContainData(t, pr, data)
   191  		})
   192  		t.Run("SrcPipe", func(t *testing.T) {
   193  			dst, src, data, hook := newCopyFileRangeTest(t, 255)
   194  			src.Close()
   195  
   196  			pr, pw, err := Pipe()
   197  			if err != nil {
   198  				t.Fatal(err)
   199  			}
   200  			defer pr.Close()
   201  			defer pw.Close()
   202  
   203  			// The pipe is empty, and PIPE_BUF is large enough
   204  			// for this, by (POSIX) definition, so there is no
   205  			// need for an additional goroutine.
   206  			if _, err := pw.Write(data); err != nil {
   207  				t.Fatal(err)
   208  			}
   209  			pw.Close()
   210  
   211  			n, err := io.Copy(dst, pr)
   212  			if err != nil {
   213  				t.Fatal(err)
   214  			}
   215  			if n != int64(len(data)) {
   216  				t.Fatalf("transferred %d, want %d", n, len(data))
   217  			}
   218  			if !hook.called {
   219  				t.Fatalf("should have called poll.CopyFileRange")
   220  			}
   221  			mustSeekStart(t, dst)
   222  			mustContainData(t, dst, data)
   223  		})
   224  	})
   225  	t.Run("Nil", func(t *testing.T) {
   226  		var nilFile *File
   227  		anyFile, err := os.CreateTemp("", "")
   228  		if err != nil {
   229  			t.Fatal(err)
   230  		}
   231  		defer Remove(anyFile.Name())
   232  		defer anyFile.Close()
   233  
   234  		if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid {
   235  			t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid)
   236  		}
   237  		if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid {
   238  			t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid)
   239  		}
   240  		if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid {
   241  			t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid)
   242  		}
   243  
   244  		if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid {
   245  			t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
   246  		}
   247  		if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid {
   248  			t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
   249  		}
   250  		if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid {
   251  			t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid)
   252  		}
   253  	})
   254  }
   255  
   256  func testCopyFileRange(t *testing.T, size int64, limit int64) {
   257  	dst, src, data, hook := newCopyFileRangeTest(t, size)
   258  
   259  	// If we have a limit, wrap the reader.
   260  	var (
   261  		realsrc io.Reader
   262  		lr      *io.LimitedReader
   263  	)
   264  	if limit >= 0 {
   265  		lr = &io.LimitedReader{N: limit, R: src}
   266  		realsrc = lr
   267  		if limit < int64(len(data)) {
   268  			data = data[:limit]
   269  		}
   270  	} else {
   271  		realsrc = src
   272  	}
   273  
   274  	// Now call ReadFrom (through io.Copy), which will hopefully call
   275  	// poll.CopyFileRange.
   276  	n, err := io.Copy(dst, realsrc)
   277  	if err != nil {
   278  		t.Fatal(err)
   279  	}
   280  
   281  	// If we didn't have a limit, we should have called poll.CopyFileRange
   282  	// with the right file descriptor arguments.
   283  	if limit > 0 && !hook.called {
   284  		t.Fatal("never called poll.CopyFileRange")
   285  	}
   286  	if hook.called && hook.dstfd != int(dst.Fd()) {
   287  		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
   288  	}
   289  	if hook.called && hook.srcfd != int(src.Fd()) {
   290  		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
   291  	}
   292  
   293  	// Check that the offsets after the transfer make sense, that the size
   294  	// of the transfer was reported correctly, and that the destination
   295  	// file contains exactly the bytes we expect it to contain.
   296  	dstoff, err := dst.Seek(0, io.SeekCurrent)
   297  	if err != nil {
   298  		t.Fatal(err)
   299  	}
   300  	srcoff, err := src.Seek(0, io.SeekCurrent)
   301  	if err != nil {
   302  		t.Fatal(err)
   303  	}
   304  	if dstoff != srcoff {
   305  		t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff)
   306  	}
   307  	if dstoff != int64(len(data)) {
   308  		t.Errorf("dstoff = %d, want %d", dstoff, len(data))
   309  	}
   310  	if n != int64(len(data)) {
   311  		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
   312  	}
   313  	mustSeekStart(t, dst)
   314  	mustContainData(t, dst, data)
   315  
   316  	// If we had a limit, check that it was updated.
   317  	if lr != nil {
   318  		if want := limit - n; lr.N != want {
   319  			t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
   320  		}
   321  	}
   322  }
   323  
   324  // newCopyFileRangeTest initializes a new test for copy_file_range.
   325  //
   326  // It creates source and destination files, and populates the source file
   327  // with random data of the specified size. It also hooks package os' call
   328  // to poll.CopyFileRange and returns the hook so it can be inspected.
   329  func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) {
   330  	t.Helper()
   331  
   332  	hook = hookCopyFileRange(t)
   333  	tmp := t.TempDir()
   334  
   335  	src, err := Create(filepath.Join(tmp, "src"))
   336  	if err != nil {
   337  		t.Fatal(err)
   338  	}
   339  	t.Cleanup(func() { src.Close() })
   340  
   341  	dst, err = Create(filepath.Join(tmp, "dst"))
   342  	if err != nil {
   343  		t.Fatal(err)
   344  	}
   345  	t.Cleanup(func() { dst.Close() })
   346  
   347  	// Populate the source file with data, then rewind it, so it can be
   348  	// consumed by copy_file_range(2).
   349  	prng := rand.New(rand.NewSource(time.Now().Unix()))
   350  	data = make([]byte, size)
   351  	prng.Read(data)
   352  	if _, err := src.Write(data); err != nil {
   353  		t.Fatal(err)
   354  	}
   355  	if _, err := src.Seek(0, io.SeekStart); err != nil {
   356  		t.Fatal(err)
   357  	}
   358  
   359  	return dst, src, data, hook
   360  }
   361  
   362  // mustContainData ensures that the specified file contains exactly the
   363  // specified data.
   364  func mustContainData(t *testing.T, f *File, data []byte) {
   365  	t.Helper()
   366  
   367  	got := make([]byte, len(data))
   368  	if _, err := io.ReadFull(f, got); err != nil {
   369  		t.Fatal(err)
   370  	}
   371  	if !bytes.Equal(got, data) {
   372  		t.Fatalf("didn't get the same data back from %s", f.Name())
   373  	}
   374  	if _, err := f.Read(make([]byte, 1)); err != io.EOF {
   375  		t.Fatalf("not at EOF")
   376  	}
   377  }
   378  
   379  func mustSeekStart(t *testing.T, f *File) {
   380  	if _, err := f.Seek(0, io.SeekStart); err != nil {
   381  		t.Fatal(err)
   382  	}
   383  }
   384  
   385  func hookCopyFileRange(t *testing.T) *copyFileRangeHook {
   386  	h := new(copyFileRangeHook)
   387  	h.install()
   388  	t.Cleanup(h.uninstall)
   389  	return h
   390  }
   391  
   392  type copyFileRangeHook struct {
   393  	called bool
   394  	dstfd  int
   395  	srcfd  int
   396  	remain int64
   397  
   398  	written int64
   399  	handled bool
   400  	err     error
   401  
   402  	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
   403  }
   404  
   405  func (h *copyFileRangeHook) install() {
   406  	h.original = *PollCopyFileRangeP
   407  	*PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
   408  		h.called = true
   409  		h.dstfd = dst.Sysfd
   410  		h.srcfd = src.Sysfd
   411  		h.remain = remain
   412  		h.written, h.handled, h.err = h.original(dst, src, remain)
   413  		return h.written, h.handled, h.err
   414  	}
   415  }
   416  
   417  func (h *copyFileRangeHook) uninstall() {
   418  	*PollCopyFileRangeP = h.original
   419  }
   420  
   421  // On some kernels copy_file_range fails on files in /proc.
   422  func TestProcCopy(t *testing.T) {
   423  	const cmdlineFile = "/proc/self/cmdline"
   424  	cmdline, err := os.ReadFile(cmdlineFile)
   425  	if err != nil {
   426  		t.Skipf("can't read /proc file: %v", err)
   427  	}
   428  	in, err := os.Open(cmdlineFile)
   429  	if err != nil {
   430  		t.Fatal(err)
   431  	}
   432  	defer in.Close()
   433  	outFile := filepath.Join(t.TempDir(), "cmdline")
   434  	out, err := os.Create(outFile)
   435  	if err != nil {
   436  		t.Fatal(err)
   437  	}
   438  	if _, err := io.Copy(out, in); err != nil {
   439  		t.Fatal(err)
   440  	}
   441  	if err := out.Close(); err != nil {
   442  		t.Fatal(err)
   443  	}
   444  	copy, err := os.ReadFile(outFile)
   445  	if err != nil {
   446  		t.Fatal(err)
   447  	}
   448  	if !bytes.Equal(cmdline, copy) {
   449  		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, copy, cmdline)
   450  	}
   451  }