github.com/Serizao/go-winio@v0.0.0-20230906082528-f02f7f4ad6e8/pkg/bindfilter/bind_filter_test.go (about) 1 //go:build windows 2 // +build windows 3 4 package bindfilter 5 6 import ( 7 "errors" 8 "fmt" 9 "os" 10 "path/filepath" 11 "strings" 12 "testing" 13 14 "golang.org/x/sys/windows" 15 ) 16 17 func TestApplyFileBinding(t *testing.T) { 18 requireElevated(t) 19 20 source := t.TempDir() 21 destination := t.TempDir() 22 fileName := "testFile.txt" 23 srcFile := filepath.Join(source, fileName) 24 dstFile := filepath.Join(destination, fileName) 25 26 err := ApplyFileBinding(destination, source, false) 27 if err != nil { 28 t.Fatal(err) 29 } 30 defer removeFileBinding(t, destination) 31 32 data := []byte("bind filter test") 33 34 if err := os.WriteFile(srcFile, data, 0600); err != nil { 35 t.Fatal(err) 36 } 37 38 readData, err := os.ReadFile(dstFile) 39 if err != nil { 40 t.Fatal(err) 41 } 42 43 if string(readData) != string(data) { 44 t.Fatalf("source and destination file contents differ. Expected: %s, got: %s", string(data), string(readData)) 45 } 46 47 // Remove the file on the mount point. The mount is not read-only, this should work. 48 if err := os.Remove(dstFile); err != nil { 49 t.Fatalf("failed to remove file from mount point: %s", err) 50 } 51 52 // Check that it's gone from the source as well. 53 if _, err := os.Stat(srcFile); err == nil { 54 t.Fatalf("expected file %s to be gone but is not", srcFile) 55 } 56 } 57 58 func removeFileBinding(t *testing.T, mountpoint string) { 59 t.Helper() 60 if err := RemoveFileBinding(mountpoint); err != nil { 61 t.Logf("failed to remove file binding from %s: %q", mountpoint, err) 62 } 63 } 64 65 func TestApplyFileBindingReadOnly(t *testing.T) { 66 requireElevated(t) 67 68 source := t.TempDir() 69 destination := t.TempDir() 70 fileName := "testFile.txt" 71 srcFile := filepath.Join(source, fileName) 72 dstFile := filepath.Join(destination, fileName) 73 74 err := ApplyFileBinding(destination, source, true) 75 if err != nil { 76 t.Fatal(err) 77 } 78 defer removeFileBinding(t, destination) 79 80 data := []byte("bind filter test") 81 82 if err := os.WriteFile(srcFile, data, 0600); err != nil { 83 t.Fatal(err) 84 } 85 86 readData, err := os.ReadFile(dstFile) 87 if err != nil { 88 t.Fatal(err) 89 } 90 91 if string(readData) != string(data) { 92 t.Fatalf("source and destination file contents differ. Expected: %s, got: %s", string(data), string(readData)) 93 } 94 95 // Attempt to remove the file on the mount point 96 err = os.Remove(dstFile) 97 if err == nil { 98 t.Fatalf("should not be able to remove a file from a read-only mount") 99 } 100 if !errors.Is(err, os.ErrPermission) { 101 t.Fatalf("expected an access denied error, got: %q", err) 102 } 103 104 // Attempt to write on the read-only mount point. 105 err = os.WriteFile(dstFile, []byte("something else"), 0600) 106 if err == nil { 107 t.Fatalf("should not be able to overwrite a file from a read-only mount") 108 } 109 if !errors.Is(err, os.ErrPermission) { 110 t.Fatalf("expected an access denied error, got: %q", err) 111 } 112 } 113 114 func TestEnsureOnlyOneTargetCanBeMounted(t *testing.T) { 115 requireElevated(t) 116 requireBuild(t, RS5+1) // support added after RS5 117 118 source := t.TempDir() 119 secondarySource := t.TempDir() 120 destination := t.TempDir() 121 122 err := ApplyFileBinding(destination, source, false) 123 if err != nil { 124 t.Fatal(err) 125 } 126 127 defer removeFileBinding(t, destination) 128 129 err = ApplyFileBinding(destination, secondarySource, false) 130 if err == nil { 131 removeFileBinding(t, destination) 132 t.Fatalf("we should not be able to mount multiple targets in the same destination") 133 } 134 } 135 136 func checkSourceIsMountedOnDestination(src, dst string) (bool, error) { 137 mappings, err := GetBindMappings(dst) 138 if err != nil { 139 return false, err 140 } 141 142 found := false 143 // There may be pre-existing mappings on the system. 144 for _, mapping := range mappings { 145 if mapping.MountPoint == dst { 146 found = true 147 if len(mapping.Targets) != 1 { 148 return false, fmt.Errorf("expected only one target, got: %s", strings.Join(mapping.Targets, ", ")) 149 } 150 if mapping.Targets[0] != src { 151 return false, fmt.Errorf("expected target to be %s, got %s", src, mapping.Targets[0]) 152 } 153 break 154 } 155 } 156 157 return found, nil 158 } 159 160 func TestGetBindMappings(t *testing.T) { 161 requireElevated(t) 162 requireBuild(t, RS5+1) // support added after RS5 163 164 // GetBindMappings will expand short paths like ADMINI~1 and PROGRA~1 to their 165 // full names. In order to properly match the names later, we expand them here. 166 srcShort := t.TempDir() 167 source, err := getFinalPath(srcShort) 168 if err != nil { 169 t.Fatalf("failed to get long path") 170 } 171 172 dstShort := t.TempDir() 173 destination, err := getFinalPath(dstShort) 174 if err != nil { 175 t.Fatalf("failed to get long path") 176 } 177 178 err = ApplyFileBinding(destination, source, false) 179 if err != nil { 180 t.Fatal(err) 181 } 182 defer removeFileBinding(t, destination) 183 184 hasMapping, err := checkSourceIsMountedOnDestination(source, destination) 185 if err != nil { 186 t.Fatal(err) 187 } 188 189 if !hasMapping { 190 t.Fatalf("expected to find %s mounted on %s, but could not", source, destination) 191 } 192 } 193 194 func TestRemoveFileBinding(t *testing.T) { 195 requireElevated(t) 196 197 srcShort := t.TempDir() 198 source, err := getFinalPath(srcShort) 199 if err != nil { 200 t.Fatalf("failed to get long path") 201 } 202 203 dstShort := t.TempDir() 204 destination, err := getFinalPath(dstShort) 205 if err != nil { 206 t.Fatalf("failed to get long path") 207 } 208 209 fileName := "testFile.txt" 210 srcFile := filepath.Join(source, fileName) 211 dstFile := filepath.Join(destination, fileName) 212 data := []byte("bind filter test") 213 214 if err := os.WriteFile(srcFile, data, 0600); err != nil { 215 t.Fatal(err) 216 } 217 218 err = ApplyFileBinding(destination, source, false) 219 if err != nil { 220 t.Fatal(err) 221 } 222 223 if _, err := os.Stat(dstFile); err != nil { 224 removeFileBinding(t, destination) 225 t.Fatalf("expected to find %s, but could not", dstFile) 226 } 227 228 if err := RemoveFileBinding(destination); err != nil { 229 t.Fatal(err) 230 } 231 232 if _, err := os.Stat(dstFile); err == nil { 233 t.Fatalf("expected %s to be gone, but it is not", dstFile) 234 } 235 } 236 237 func TestGetBindMappingsSymlinks(t *testing.T) { 238 requireElevated(t) 239 requireBuild(t, RS5+1) // support added after RS5 240 241 srcShort := t.TempDir() 242 sourceNested := filepath.Join(srcShort, "source") 243 if err := os.MkdirAll(sourceNested, 0600); err != nil { 244 t.Fatalf("failed to create folder: %s", err) 245 } 246 simlinkSource := filepath.Join(srcShort, "symlink") 247 if err := os.Symlink(sourceNested, simlinkSource); err != nil { 248 t.Fatalf("failed to create symlink: %s", err) 249 } 250 251 // We'll need the long form of the source folder, as we expect bfSetupFilter() 252 // to resolve the symlink and create a mapping to the actual source the symlink 253 // points to. 254 source, err := getFinalPath(sourceNested) 255 if err != nil { 256 t.Fatalf("failed to get long path") 257 } 258 259 dstShort := t.TempDir() 260 destination, err := getFinalPath(dstShort) 261 if err != nil { 262 t.Fatalf("failed to get long path") 263 } 264 265 // Use the symlink as a source for the mapping. 266 err = ApplyFileBinding(destination, simlinkSource, false) 267 if err != nil { 268 t.Fatal(err) 269 } 270 defer removeFileBinding(t, destination) 271 272 // We expect the mapping to point to the folder the symlink points to, not to the 273 // actual symlink. 274 hasMapping, err := checkSourceIsMountedOnDestination(source, destination) 275 if err != nil { 276 t.Fatal(err) 277 } 278 279 if !hasMapping { 280 t.Fatalf("expected to find %s mounted on %s, but could not", source, destination) 281 } 282 } 283 284 func requireElevated(tb testing.TB) { 285 tb.Helper() 286 if !windows.GetCurrentProcessToken().IsElevated() { 287 tb.Skip("requires elevated privileges") 288 } 289 } 290 291 const RS5 = 17763 292 293 //todo: also check that `bindfltapi.dll` exists 294 295 // require current build to be >= build 296 func requireBuild(tb testing.TB, build uint32) { 297 tb.Helper() 298 _, _, b := windows.RtlGetNtVersionNumbers() 299 if b < build { 300 tb.Skipf("requires build %d+; current build is %d", build, b) 301 } 302 }