github.com/Serizao/go-winio@v0.0.0-20230906082528-f02f7f4ad6e8/pkg/bindfilter/bind_filter_test.go (about)

     1  //go:build windows
     2  // +build windows
     3  
     4  package bindfilter
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  	"testing"
    13  
    14  	"golang.org/x/sys/windows"
    15  )
    16  
    17  func TestApplyFileBinding(t *testing.T) {
    18  	requireElevated(t)
    19  
    20  	source := t.TempDir()
    21  	destination := t.TempDir()
    22  	fileName := "testFile.txt"
    23  	srcFile := filepath.Join(source, fileName)
    24  	dstFile := filepath.Join(destination, fileName)
    25  
    26  	err := ApplyFileBinding(destination, source, false)
    27  	if err != nil {
    28  		t.Fatal(err)
    29  	}
    30  	defer removeFileBinding(t, destination)
    31  
    32  	data := []byte("bind filter test")
    33  
    34  	if err := os.WriteFile(srcFile, data, 0600); err != nil {
    35  		t.Fatal(err)
    36  	}
    37  
    38  	readData, err := os.ReadFile(dstFile)
    39  	if err != nil {
    40  		t.Fatal(err)
    41  	}
    42  
    43  	if string(readData) != string(data) {
    44  		t.Fatalf("source and destination file contents differ. Expected: %s, got: %s", string(data), string(readData))
    45  	}
    46  
    47  	// Remove the file on the mount point. The mount is not read-only, this should work.
    48  	if err := os.Remove(dstFile); err != nil {
    49  		t.Fatalf("failed to remove file from mount point: %s", err)
    50  	}
    51  
    52  	// Check that it's gone from the source as well.
    53  	if _, err := os.Stat(srcFile); err == nil {
    54  		t.Fatalf("expected file %s to be gone but is not", srcFile)
    55  	}
    56  }
    57  
    58  func removeFileBinding(t *testing.T, mountpoint string) {
    59  	t.Helper()
    60  	if err := RemoveFileBinding(mountpoint); err != nil {
    61  		t.Logf("failed to remove file binding from %s: %q", mountpoint, err)
    62  	}
    63  }
    64  
    65  func TestApplyFileBindingReadOnly(t *testing.T) {
    66  	requireElevated(t)
    67  
    68  	source := t.TempDir()
    69  	destination := t.TempDir()
    70  	fileName := "testFile.txt"
    71  	srcFile := filepath.Join(source, fileName)
    72  	dstFile := filepath.Join(destination, fileName)
    73  
    74  	err := ApplyFileBinding(destination, source, true)
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  	defer removeFileBinding(t, destination)
    79  
    80  	data := []byte("bind filter test")
    81  
    82  	if err := os.WriteFile(srcFile, data, 0600); err != nil {
    83  		t.Fatal(err)
    84  	}
    85  
    86  	readData, err := os.ReadFile(dstFile)
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  
    91  	if string(readData) != string(data) {
    92  		t.Fatalf("source and destination file contents differ. Expected: %s, got: %s", string(data), string(readData))
    93  	}
    94  
    95  	// Attempt to remove the file on the mount point
    96  	err = os.Remove(dstFile)
    97  	if err == nil {
    98  		t.Fatalf("should not be able to remove a file from a read-only mount")
    99  	}
   100  	if !errors.Is(err, os.ErrPermission) {
   101  		t.Fatalf("expected an access denied error, got: %q", err)
   102  	}
   103  
   104  	// Attempt to write on the read-only mount point.
   105  	err = os.WriteFile(dstFile, []byte("something else"), 0600)
   106  	if err == nil {
   107  		t.Fatalf("should not be able to overwrite a file from a read-only mount")
   108  	}
   109  	if !errors.Is(err, os.ErrPermission) {
   110  		t.Fatalf("expected an access denied error, got: %q", err)
   111  	}
   112  }
   113  
   114  func TestEnsureOnlyOneTargetCanBeMounted(t *testing.T) {
   115  	requireElevated(t)
   116  	requireBuild(t, RS5+1) // support added after RS5
   117  
   118  	source := t.TempDir()
   119  	secondarySource := t.TempDir()
   120  	destination := t.TempDir()
   121  
   122  	err := ApplyFileBinding(destination, source, false)
   123  	if err != nil {
   124  		t.Fatal(err)
   125  	}
   126  
   127  	defer removeFileBinding(t, destination)
   128  
   129  	err = ApplyFileBinding(destination, secondarySource, false)
   130  	if err == nil {
   131  		removeFileBinding(t, destination)
   132  		t.Fatalf("we should not be able to mount multiple targets in the same destination")
   133  	}
   134  }
   135  
   136  func checkSourceIsMountedOnDestination(src, dst string) (bool, error) {
   137  	mappings, err := GetBindMappings(dst)
   138  	if err != nil {
   139  		return false, err
   140  	}
   141  
   142  	found := false
   143  	// There may be pre-existing mappings on the system.
   144  	for _, mapping := range mappings {
   145  		if mapping.MountPoint == dst {
   146  			found = true
   147  			if len(mapping.Targets) != 1 {
   148  				return false, fmt.Errorf("expected only one target, got: %s", strings.Join(mapping.Targets, ", "))
   149  			}
   150  			if mapping.Targets[0] != src {
   151  				return false, fmt.Errorf("expected target to be %s, got %s", src, mapping.Targets[0])
   152  			}
   153  			break
   154  		}
   155  	}
   156  
   157  	return found, nil
   158  }
   159  
   160  func TestGetBindMappings(t *testing.T) {
   161  	requireElevated(t)
   162  	requireBuild(t, RS5+1) // support added after RS5
   163  
   164  	// GetBindMappings will expand short paths like ADMINI~1 and PROGRA~1 to their
   165  	// full names. In order to properly match the names later, we expand them here.
   166  	srcShort := t.TempDir()
   167  	source, err := getFinalPath(srcShort)
   168  	if err != nil {
   169  		t.Fatalf("failed to get long path")
   170  	}
   171  
   172  	dstShort := t.TempDir()
   173  	destination, err := getFinalPath(dstShort)
   174  	if err != nil {
   175  		t.Fatalf("failed to get long path")
   176  	}
   177  
   178  	err = ApplyFileBinding(destination, source, false)
   179  	if err != nil {
   180  		t.Fatal(err)
   181  	}
   182  	defer removeFileBinding(t, destination)
   183  
   184  	hasMapping, err := checkSourceIsMountedOnDestination(source, destination)
   185  	if err != nil {
   186  		t.Fatal(err)
   187  	}
   188  
   189  	if !hasMapping {
   190  		t.Fatalf("expected to find %s mounted on %s, but could not", source, destination)
   191  	}
   192  }
   193  
   194  func TestRemoveFileBinding(t *testing.T) {
   195  	requireElevated(t)
   196  
   197  	srcShort := t.TempDir()
   198  	source, err := getFinalPath(srcShort)
   199  	if err != nil {
   200  		t.Fatalf("failed to get long path")
   201  	}
   202  
   203  	dstShort := t.TempDir()
   204  	destination, err := getFinalPath(dstShort)
   205  	if err != nil {
   206  		t.Fatalf("failed to get long path")
   207  	}
   208  
   209  	fileName := "testFile.txt"
   210  	srcFile := filepath.Join(source, fileName)
   211  	dstFile := filepath.Join(destination, fileName)
   212  	data := []byte("bind filter test")
   213  
   214  	if err := os.WriteFile(srcFile, data, 0600); err != nil {
   215  		t.Fatal(err)
   216  	}
   217  
   218  	err = ApplyFileBinding(destination, source, false)
   219  	if err != nil {
   220  		t.Fatal(err)
   221  	}
   222  
   223  	if _, err := os.Stat(dstFile); err != nil {
   224  		removeFileBinding(t, destination)
   225  		t.Fatalf("expected to find %s, but could not", dstFile)
   226  	}
   227  
   228  	if err := RemoveFileBinding(destination); err != nil {
   229  		t.Fatal(err)
   230  	}
   231  
   232  	if _, err := os.Stat(dstFile); err == nil {
   233  		t.Fatalf("expected %s to be gone, but it is not", dstFile)
   234  	}
   235  }
   236  
   237  func TestGetBindMappingsSymlinks(t *testing.T) {
   238  	requireElevated(t)
   239  	requireBuild(t, RS5+1) // support added after RS5
   240  
   241  	srcShort := t.TempDir()
   242  	sourceNested := filepath.Join(srcShort, "source")
   243  	if err := os.MkdirAll(sourceNested, 0600); err != nil {
   244  		t.Fatalf("failed to create folder: %s", err)
   245  	}
   246  	simlinkSource := filepath.Join(srcShort, "symlink")
   247  	if err := os.Symlink(sourceNested, simlinkSource); err != nil {
   248  		t.Fatalf("failed to create symlink: %s", err)
   249  	}
   250  
   251  	// We'll need the long form of the source folder, as we expect bfSetupFilter()
   252  	// to resolve the symlink and create a mapping to the actual source the symlink
   253  	// points to.
   254  	source, err := getFinalPath(sourceNested)
   255  	if err != nil {
   256  		t.Fatalf("failed to get long path")
   257  	}
   258  
   259  	dstShort := t.TempDir()
   260  	destination, err := getFinalPath(dstShort)
   261  	if err != nil {
   262  		t.Fatalf("failed to get long path")
   263  	}
   264  
   265  	// Use the symlink as a source for the mapping.
   266  	err = ApplyFileBinding(destination, simlinkSource, false)
   267  	if err != nil {
   268  		t.Fatal(err)
   269  	}
   270  	defer removeFileBinding(t, destination)
   271  
   272  	// We expect the mapping to point to the folder the symlink points to, not to the
   273  	// actual symlink.
   274  	hasMapping, err := checkSourceIsMountedOnDestination(source, destination)
   275  	if err != nil {
   276  		t.Fatal(err)
   277  	}
   278  
   279  	if !hasMapping {
   280  		t.Fatalf("expected to find %s mounted on %s, but could not", source, destination)
   281  	}
   282  }
   283  
   284  func requireElevated(tb testing.TB) {
   285  	tb.Helper()
   286  	if !windows.GetCurrentProcessToken().IsElevated() {
   287  		tb.Skip("requires elevated privileges")
   288  	}
   289  }
   290  
// RS5 is the Windows build number used as the baseline for requireBuild checks
// in this file; tests that need features added after it call requireBuild(t, RS5+1).
// NOTE(review): 17763 is presumably the "Redstone 5" release (Windows 10 1809 /
// Server 2019) — confirm.
const RS5 = 17763
   292  
// TODO: also check that `bindfltapi.dll` exists.
   294  
   295  // require current build to be >= build
   296  func requireBuild(tb testing.TB, build uint32) {
   297  	tb.Helper()
   298  	_, _, b := windows.RtlGetNtVersionNumbers()
   299  	if b < build {
   300  		tb.Skipf("requires build %d+; current build is %d", build, b)
   301  	}
   302  }