github.com/Serizao/go-winio@v0.0.0-20230906082528-f02f7f4ad6e8/pkg/bindfilter/bind_filter.go (about)

     1  //go:build windows
     2  // +build windows
     3  
     4  package bindfilter
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/binary"
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"unsafe"
    15  
    16  	"golang.org/x/sys/windows"
    17  )
    18  
    19  //go:generate go run github.com/Serizao/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./bind_filter.go
    20  //sys bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath string, virtTargetPath string, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter?
    21  //sys bfRemoveMapping(jobHandle windows.Handle, virtRootPath string)  (hr error) = bindfltapi.BfRemoveMapping?
    22  //sys bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer *byte)  (hr error) = bindfltapi.BfGetMappings?
    23  
    24  // BfSetupFilter flags. See:
    25  // https://github.com/microsoft/BuildXL/blob/a6dce509f0d4f774255e5fbfb75fa6d5290ed163/Public/Src/Utilities/Native/Processes/Windows/NativeContainerUtilities.cs#L193-L240
    26  //
    27  //nolint:revive // var-naming: ALL_CAPS
    28  const (
    29  	BINDFLT_FLAG_READ_ONLY_MAPPING uint32 = 0x00000001
    30  	// Tells bindflt to fail mapping with STATUS_INVALID_PARAMETER if a mapping produces
    31  	// multiple targets.
    32  	BINDFLT_FLAG_NO_MULTIPLE_TARGETS uint32 = 0x00000040
    33  )
    34  
    35  //nolint:revive // var-naming: ALL_CAPS
    36  const (
    37  	BINDFLT_GET_MAPPINGS_FLAG_VOLUME uint32 = 0x00000001
    38  	BINDFLT_GET_MAPPINGS_FLAG_SILO   uint32 = 0x00000002
    39  	BINDFLT_GET_MAPPINGS_FLAG_USER   uint32 = 0x00000004
    40  )
    41  
    42  // ApplyFileBinding creates a global mount of the source in root, with an optional
    43  // read only flag.
    44  // The bind filter allows us to create mounts of directories and volumes. By default it allows
    45  // us to mount multiple sources inside a single root, acting as an overlay. Files from the
    46  // second source will superscede the first source that was mounted.
    47  // This function disables this behavior and sets the BINDFLT_FLAG_NO_MULTIPLE_TARGETS flag
    48  // on the mount.
    49  func ApplyFileBinding(root, source string, readOnly bool) error {
    50  	// The parent directory needs to exist for the bind to work. MkdirAll stats and
    51  	// returns nil if the directory exists internally so we should be fine to mkdirall
    52  	// every time.
    53  	if err := os.MkdirAll(filepath.Dir(root), 0); err != nil {
    54  		return err
    55  	}
    56  
    57  	if strings.Contains(source, "Volume{") && !strings.HasSuffix(source, "\\") {
    58  		// Add trailing slash to volumes, otherwise we get an error when binding it to
    59  		// a folder.
    60  		source = source + "\\"
    61  	}
    62  
    63  	flags := BINDFLT_FLAG_NO_MULTIPLE_TARGETS
    64  	if readOnly {
    65  		flags |= BINDFLT_FLAG_READ_ONLY_MAPPING
    66  	}
    67  
    68  	// Set the job handle to 0 to create a global mount.
    69  	if err := bfSetupFilter(
    70  		0,
    71  		flags,
    72  		root,
    73  		source,
    74  		nil,
    75  		0,
    76  	); err != nil {
    77  		return fmt.Errorf("failed to bind target %q to root %q: %w", source, root, err)
    78  	}
    79  	return nil
    80  }
    81  
    82  // RemoveFileBinding removes a mount from the root path.
    83  func RemoveFileBinding(root string) error {
    84  	if err := bfRemoveMapping(0, root); err != nil {
    85  		return fmt.Errorf("removing file binding: %w", err)
    86  	}
    87  	return nil
    88  }
    89  
    90  // GetBindMappings returns a list of bind mappings that have their root on a
    91  // particular volume. The volumePath parameter can be any path that exists on
    92  // a volume. For example, if a number of mappings are created in C:\ProgramData\test,
    93  // to get a list of those mappings, the volumePath parameter would have to be set to
    94  // C:\ or the VOLUME_NAME_GUID notation of C:\ (\\?\Volume{GUID}\), or any child
    95  // path that exists.
    96  func GetBindMappings(volumePath string) ([]BindMapping, error) {
    97  	rootPtr, err := windows.UTF16PtrFromString(volumePath)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  
   102  	flags := BINDFLT_GET_MAPPINGS_FLAG_VOLUME
   103  	// allocate a large buffer for results
   104  	var outBuffSize uint32 = 256 * 1024
   105  	buf := make([]byte, outBuffSize)
   106  
   107  	if err := bfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, &buf[0]); err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	if outBuffSize < 12 {
   112  		return nil, fmt.Errorf("invalid buffer returned")
   113  	}
   114  
   115  	result := buf[:outBuffSize]
   116  
   117  	// The first 12 bytes are the three uint32 fields in getMappingsResponseHeader{}
   118  	headerBuffer := result[:12]
   119  	// The alternative to using unsafe and casting it to the above defined structures, is to manually
   120  	// parse the fields. Not too terrible, but not sure it'd worth the trouble.
   121  	header := *(*getMappingsResponseHeader)(unsafe.Pointer(&headerBuffer[0]))
   122  
   123  	if header.MappingCount == 0 {
   124  		// no mappings
   125  		return []BindMapping{}, nil
   126  	}
   127  
   128  	mappingsBuffer := result[12 : int(unsafe.Sizeof(mappingEntry{}))*int(header.MappingCount)]
   129  	// Get a pointer to the first mapping in the slice
   130  	mappingsPointer := (*mappingEntry)(unsafe.Pointer(&mappingsBuffer[0]))
   131  	// Get slice of mappings
   132  	mappings := unsafe.Slice(mappingsPointer, header.MappingCount)
   133  
   134  	mappingEntries := make([]BindMapping, header.MappingCount)
   135  	for i := 0; i < int(header.MappingCount); i++ {
   136  		bindMapping, err := getBindMappingFromBuffer(result, mappings[i])
   137  		if err != nil {
   138  			return nil, fmt.Errorf("fetching bind mappings: %w", err)
   139  		}
   140  		mappingEntries[i] = bindMapping
   141  	}
   142  
   143  	return mappingEntries, nil
   144  }
   145  
   146  // mappingEntry holds information about where in the response buffer we can
   147  // find information about the virtual root (the mount point) and the targets (sources)
   148  // that get mounted, as well as the flags used to bind the targets to the virtual root.
   149  type mappingEntry struct {
   150  	VirtRootLength      uint32
   151  	VirtRootOffset      uint32
   152  	Flags               uint32
   153  	NumberOfTargets     uint32
   154  	TargetEntriesOffset uint32
   155  }
   156  
   157  type mappingTargetEntry struct {
   158  	TargetRootLength uint32
   159  	TargetRootOffset uint32
   160  }
   161  
   162  // getMappingsResponseHeader represents the first 12 bytes of the BfGetMappings() response.
   163  // It gives us the size of the buffer, the status of the call and the number of mappings.
   164  // A response
   165  type getMappingsResponseHeader struct {
   166  	Size         uint32
   167  	Status       uint32
   168  	MappingCount uint32
   169  }
   170  
   171  type BindMapping struct {
   172  	MountPoint string
   173  	Flags      uint32
   174  	Targets    []string
   175  }
   176  
   177  func decodeEntry(buffer []byte) (string, error) {
   178  	name := make([]uint16, len(buffer)/2)
   179  	err := binary.Read(bytes.NewReader(buffer), binary.LittleEndian, &name)
   180  	if err != nil {
   181  		return "", fmt.Errorf("decoding name: %w", err)
   182  	}
   183  	return windows.UTF16ToString(name), nil
   184  }
   185  
   186  func getTargetsFromBuffer(buffer []byte, offset, count int) ([]string, error) {
   187  	if len(buffer) < offset+count*6 {
   188  		return nil, fmt.Errorf("invalid buffer")
   189  	}
   190  
   191  	targets := make([]string, count)
   192  	for i := 0; i < count; i++ {
   193  		entryBuf := buffer[offset+i*8 : offset+i*8+8]
   194  		tgt := *(*mappingTargetEntry)(unsafe.Pointer(&entryBuf[0]))
   195  		if len(buffer) < int(tgt.TargetRootOffset)+int(tgt.TargetRootLength) {
   196  			return nil, fmt.Errorf("invalid buffer")
   197  		}
   198  		decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : tgt.TargetRootOffset+tgt.TargetRootLength])
   199  		if err != nil {
   200  			return nil, fmt.Errorf("decoding name: %w", err)
   201  		}
   202  		decoded, err = getFinalPath(decoded)
   203  		if err != nil {
   204  			return nil, fmt.Errorf("fetching final path: %w", err)
   205  		}
   206  
   207  		targets[i] = decoded
   208  	}
   209  	return targets, nil
   210  }
   211  
   212  func getFinalPath(pth string) (string, error) {
   213  	// BfGetMappings returns VOLUME_NAME_NT paths like \Device\HarddiskVolume2\ProgramData.
   214  	// These can be accessed by prepending \\.\GLOBALROOT to the path. We use this to get the
   215  	// DOS paths for these files.
   216  	if strings.HasPrefix(pth, `\Device`) {
   217  		pth = `\\.\GLOBALROOT` + pth
   218  	}
   219  
   220  	han, err := openPath(pth)
   221  	if err != nil {
   222  		return "", fmt.Errorf("fetching file handle: %w", err)
   223  	}
   224  	defer func() {
   225  		_ = windows.CloseHandle(han)
   226  	}()
   227  
   228  	buf := make([]uint16, 100)
   229  	var flags uint32 = 0x0
   230  	for {
   231  		n, err := windows.GetFinalPathNameByHandle(han, &buf[0], uint32(len(buf)), flags)
   232  		if err != nil {
   233  			// if we mounted a volume that does not also have a drive letter assigned, attempting to
   234  			// fetch the VOLUME_NAME_DOS will fail with os.ErrNotExist. Attempt to get the VOLUME_NAME_GUID.
   235  			if errors.Is(err, os.ErrNotExist) && flags != 0x1 {
   236  				flags = 0x1
   237  				continue
   238  			}
   239  			return "", fmt.Errorf("getting final path name: %w", err)
   240  		}
   241  		if n < uint32(len(buf)) {
   242  			break
   243  		}
   244  		buf = make([]uint16, n)
   245  	}
   246  	finalPath := windows.UTF16ToString(buf)
   247  	// We got VOLUME_NAME_DOS, we need to strip away some leading slashes.
   248  	// Leave unchanged if we ended up requesting VOLUME_NAME_GUID
   249  	if len(finalPath) > 4 && finalPath[:4] == `\\?\` && flags == 0x0 {
   250  		finalPath = finalPath[4:]
   251  		if len(finalPath) > 3 && finalPath[:3] == `UNC` {
   252  			// return path like \\server\share\...
   253  			finalPath = `\` + finalPath[3:]
   254  		}
   255  	}
   256  
   257  	return finalPath, nil
   258  }
   259  
   260  func getBindMappingFromBuffer(buffer []byte, entry mappingEntry) (BindMapping, error) {
   261  	if len(buffer) < int(entry.VirtRootOffset)+int(entry.VirtRootLength) {
   262  		return BindMapping{}, fmt.Errorf("invalid buffer")
   263  	}
   264  
   265  	src, err := decodeEntry(buffer[entry.VirtRootOffset : entry.VirtRootOffset+entry.VirtRootLength])
   266  	if err != nil {
   267  		return BindMapping{}, fmt.Errorf("decoding entry: %w", err)
   268  	}
   269  	targets, err := getTargetsFromBuffer(buffer, int(entry.TargetEntriesOffset), int(entry.NumberOfTargets))
   270  	if err != nil {
   271  		return BindMapping{}, fmt.Errorf("fetching targets: %w", err)
   272  	}
   273  
   274  	src, err = getFinalPath(src)
   275  	if err != nil {
   276  		return BindMapping{}, fmt.Errorf("fetching final path: %w", err)
   277  	}
   278  
   279  	return BindMapping{
   280  		Flags:      entry.Flags,
   281  		Targets:    targets,
   282  		MountPoint: src,
   283  	}, nil
   284  }
   285  
   286  func openPath(path string) (windows.Handle, error) {
   287  	u16, err := windows.UTF16PtrFromString(path)
   288  	if err != nil {
   289  		return 0, err
   290  	}
   291  	h, err := windows.CreateFile(
   292  		u16,
   293  		0,
   294  		windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE,
   295  		nil,
   296  		windows.OPEN_EXISTING,
   297  		windows.FILE_FLAG_BACKUP_SEMANTICS, // Needed to open a directory handle.
   298  		0)
   299  	if err != nil {
   300  		return 0, &os.PathError{
   301  			Op:   "CreateFile",
   302  			Path: path,
   303  			Err:  err,
   304  		}
   305  	}
   306  	return h, nil
   307  }