github.com/vipernet-xyz/tm@v0.34.24/libs/os/os_test.go (about)

     1  package os
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
    13  func TestCopyFile(t *testing.T) {
    14  	tmpfile, err := os.CreateTemp("", "example")
    15  	if err != nil {
    16  		t.Fatal(err)
    17  	}
    18  	defer os.Remove(tmpfile.Name())
    19  	content := []byte("hello world")
    20  	if _, err := tmpfile.Write(content); err != nil {
    21  		t.Fatal(err)
    22  	}
    23  
    24  	copyfile := fmt.Sprintf("%s.copy", tmpfile.Name())
    25  	if err := CopyFile(tmpfile.Name(), copyfile); err != nil {
    26  		t.Fatal(err)
    27  	}
    28  	if _, err := os.Stat(copyfile); os.IsNotExist(err) {
    29  		t.Fatal("copy should exist")
    30  	}
    31  	data, err := os.ReadFile(copyfile)
    32  	if err != nil {
    33  		t.Fatal(err)
    34  	}
    35  	if !bytes.Equal(data, content) {
    36  		t.Fatalf("copy file content differs: expected %v, got %v", content, data)
    37  	}
    38  	os.Remove(copyfile)
    39  }
    40  
    41  func TestEnsureDir(t *testing.T) {
    42  	tmp, err := os.MkdirTemp("", "ensure-dir")
    43  	require.NoError(t, err)
    44  	defer os.RemoveAll(tmp)
    45  
    46  	// Should be possible to create a new directory.
    47  	err = EnsureDir(filepath.Join(tmp, "dir"), 0o755)
    48  	require.NoError(t, err)
    49  	require.DirExists(t, filepath.Join(tmp, "dir"))
    50  
    51  	// Should succeed on existing directory.
    52  	err = EnsureDir(filepath.Join(tmp, "dir"), 0o755)
    53  	require.NoError(t, err)
    54  
    55  	// Should fail on file.
    56  	err = os.WriteFile(filepath.Join(tmp, "file"), []byte{}, 0o644)
    57  	require.NoError(t, err)
    58  	err = EnsureDir(filepath.Join(tmp, "file"), 0o755)
    59  	require.Error(t, err)
    60  
    61  	// Should allow symlink to dir.
    62  	err = os.Symlink(filepath.Join(tmp, "dir"), filepath.Join(tmp, "linkdir"))
    63  	require.NoError(t, err)
    64  	err = EnsureDir(filepath.Join(tmp, "linkdir"), 0o755)
    65  	require.NoError(t, err)
    66  
    67  	// Should error on symlink to file.
    68  	err = os.Symlink(filepath.Join(tmp, "file"), filepath.Join(tmp, "linkfile"))
    69  	require.NoError(t, err)
    70  	err = EnsureDir(filepath.Join(tmp, "linkfile"), 0o755)
    71  	require.Error(t, err)
    72  }
    73  
    74  // Ensure that using CopyFile does not truncate the destination file before
    75  // the origin is positively a non-directory and that it is ready for copying.
    76  // See https://github.com/vipernet-xyz/tm/issues/6427
    77  func TestTrickedTruncation(t *testing.T) {
    78  	tmpDir, err := os.MkdirTemp(os.TempDir(), "pwn_truncate")
    79  	if err != nil {
    80  		t.Fatal(err)
    81  	}
    82  	defer os.Remove(tmpDir)
    83  
    84  	originalWALPath := filepath.Join(tmpDir, "wal")
    85  	originalWALContent := []byte("I AM BECOME DEATH, DESTROYER OF ALL WORLDS!")
    86  	if err := os.WriteFile(originalWALPath, originalWALContent, 0o755); err != nil {
    87  		t.Fatal(err)
    88  	}
    89  
    90  	// 1. Sanity check.
    91  	readWAL, err := os.ReadFile(originalWALPath)
    92  	if err != nil {
    93  		t.Fatal(err)
    94  	}
    95  	if !bytes.Equal(readWAL, originalWALContent) {
    96  		t.Fatalf("Cannot proceed as the content does not match\nGot:  %q\nWant: %q", readWAL, originalWALContent)
    97  	}
    98  
    99  	// 2. Now cause the truncation of the original file.
   100  	// It is absolutely legal to invoke os.Open on a directory.
   101  	if err := CopyFile(tmpDir, originalWALPath); err == nil {
   102  		t.Fatal("Expected an error")
   103  	}
   104  
   105  	// 3. Check the WAL's content
   106  	reReadWAL, err := os.ReadFile(originalWALPath)
   107  	if err != nil {
   108  		t.Fatal(err)
   109  	}
   110  	if !bytes.Equal(reReadWAL, originalWALContent) {
   111  		t.Fatalf("Oops, the WAL's content was changed :(\nGot:  %q\nWant: %q", reReadWAL, originalWALContent)
   112  	}
   113  }