github.com/Serizao/go-winio@v0.0.0-20230906082528-f02f7f4ad6e8/pkg/bindfilter/bind_filter.go (about) 1 //go:build windows 2 // +build windows 3 4 package bindfilter 5 6 import ( 7 "bytes" 8 "encoding/binary" 9 "errors" 10 "fmt" 11 "os" 12 "path/filepath" 13 "strings" 14 "unsafe" 15 16 "golang.org/x/sys/windows" 17 ) 18 19 //go:generate go run github.com/Serizao/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./bind_filter.go 20 //sys bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath string, virtTargetPath string, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter? 21 //sys bfRemoveMapping(jobHandle windows.Handle, virtRootPath string) (hr error) = bindfltapi.BfRemoveMapping? 22 //sys bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer *byte) (hr error) = bindfltapi.BfGetMappings? 23 24 // BfSetupFilter flags. See: 25 // https://github.com/microsoft/BuildXL/blob/a6dce509f0d4f774255e5fbfb75fa6d5290ed163/Public/Src/Utilities/Native/Processes/Windows/NativeContainerUtilities.cs#L193-L240 26 // 27 //nolint:revive // var-naming: ALL_CAPS 28 const ( 29 BINDFLT_FLAG_READ_ONLY_MAPPING uint32 = 0x00000001 30 // Tells bindflt to fail mapping with STATUS_INVALID_PARAMETER if a mapping produces 31 // multiple targets. 32 BINDFLT_FLAG_NO_MULTIPLE_TARGETS uint32 = 0x00000040 33 ) 34 35 //nolint:revive // var-naming: ALL_CAPS 36 const ( 37 BINDFLT_GET_MAPPINGS_FLAG_VOLUME uint32 = 0x00000001 38 BINDFLT_GET_MAPPINGS_FLAG_SILO uint32 = 0x00000002 39 BINDFLT_GET_MAPPINGS_FLAG_USER uint32 = 0x00000004 40 ) 41 42 // ApplyFileBinding creates a global mount of the source in root, with an optional 43 // read only flag. 44 // The bind filter allows us to create mounts of directories and volumes. By default it allows 45 // us to mount multiple sources inside a single root, acting as an overlay. Files from the 46 // second source will superscede the first source that was mounted. 47 // This function disables this behavior and sets the BINDFLT_FLAG_NO_MULTIPLE_TARGETS flag 48 // on the mount. 49 func ApplyFileBinding(root, source string, readOnly bool) error { 50 // The parent directory needs to exist for the bind to work. MkdirAll stats and 51 // returns nil if the directory exists internally so we should be fine to mkdirall 52 // every time. 53 if err := os.MkdirAll(filepath.Dir(root), 0); err != nil { 54 return err 55 } 56 57 if strings.Contains(source, "Volume{") && !strings.HasSuffix(source, "\\") { 58 // Add trailing slash to volumes, otherwise we get an error when binding it to 59 // a folder. 60 source = source + "\\" 61 } 62 63 flags := BINDFLT_FLAG_NO_MULTIPLE_TARGETS 64 if readOnly { 65 flags |= BINDFLT_FLAG_READ_ONLY_MAPPING 66 } 67 68 // Set the job handle to 0 to create a global mount. 69 if err := bfSetupFilter( 70 0, 71 flags, 72 root, 73 source, 74 nil, 75 0, 76 ); err != nil { 77 return fmt.Errorf("failed to bind target %q to root %q: %w", source, root, err) 78 } 79 return nil 80 } 81 82 // RemoveFileBinding removes a mount from the root path. 83 func RemoveFileBinding(root string) error { 84 if err := bfRemoveMapping(0, root); err != nil { 85 return fmt.Errorf("removing file binding: %w", err) 86 } 87 return nil 88 } 89 90 // GetBindMappings returns a list of bind mappings that have their root on a 91 // particular volume. The volumePath parameter can be any path that exists on 92 // a volume. For example, if a number of mappings are created in C:\ProgramData\test, 93 // to get a list of those mappings, the volumePath parameter would have to be set to 94 // C:\ or the VOLUME_NAME_GUID notation of C:\ (\\?\Volume{GUID}\), or any child 95 // path that exists. 96 func GetBindMappings(volumePath string) ([]BindMapping, error) { 97 rootPtr, err := windows.UTF16PtrFromString(volumePath) 98 if err != nil { 99 return nil, err 100 } 101 102 flags := BINDFLT_GET_MAPPINGS_FLAG_VOLUME 103 // allocate a large buffer for results 104 var outBuffSize uint32 = 256 * 1024 105 buf := make([]byte, outBuffSize) 106 107 if err := bfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, &buf[0]); err != nil { 108 return nil, err 109 } 110 111 if outBuffSize < 12 { 112 return nil, fmt.Errorf("invalid buffer returned") 113 } 114 115 result := buf[:outBuffSize] 116 117 // The first 12 bytes are the three uint32 fields in getMappingsResponseHeader{} 118 headerBuffer := result[:12] 119 // The alternative to using unsafe and casting it to the above defined structures, is to manually 120 // parse the fields. Not too terrible, but not sure it'd worth the trouble. 121 header := *(*getMappingsResponseHeader)(unsafe.Pointer(&headerBuffer[0])) 122 123 if header.MappingCount == 0 { 124 // no mappings 125 return []BindMapping{}, nil 126 } 127 128 mappingsBuffer := result[12 : int(unsafe.Sizeof(mappingEntry{}))*int(header.MappingCount)] 129 // Get a pointer to the first mapping in the slice 130 mappingsPointer := (*mappingEntry)(unsafe.Pointer(&mappingsBuffer[0])) 131 // Get slice of mappings 132 mappings := unsafe.Slice(mappingsPointer, header.MappingCount) 133 134 mappingEntries := make([]BindMapping, header.MappingCount) 135 for i := 0; i < int(header.MappingCount); i++ { 136 bindMapping, err := getBindMappingFromBuffer(result, mappings[i]) 137 if err != nil { 138 return nil, fmt.Errorf("fetching bind mappings: %w", err) 139 } 140 mappingEntries[i] = bindMapping 141 } 142 143 return mappingEntries, nil 144 } 145 146 // mappingEntry holds information about where in the response buffer we can 147 // find information about the virtual root (the mount point) and the targets (sources) 148 // that get mounted, as well as the flags used to bind the targets to the virtual root. 149 type mappingEntry struct { 150 VirtRootLength uint32 151 VirtRootOffset uint32 152 Flags uint32 153 NumberOfTargets uint32 154 TargetEntriesOffset uint32 155 } 156 157 type mappingTargetEntry struct { 158 TargetRootLength uint32 159 TargetRootOffset uint32 160 } 161 162 // getMappingsResponseHeader represents the first 12 bytes of the BfGetMappings() response. 163 // It gives us the size of the buffer, the status of the call and the number of mappings. 164 // A response 165 type getMappingsResponseHeader struct { 166 Size uint32 167 Status uint32 168 MappingCount uint32 169 } 170 171 type BindMapping struct { 172 MountPoint string 173 Flags uint32 174 Targets []string 175 } 176 177 func decodeEntry(buffer []byte) (string, error) { 178 name := make([]uint16, len(buffer)/2) 179 err := binary.Read(bytes.NewReader(buffer), binary.LittleEndian, &name) 180 if err != nil { 181 return "", fmt.Errorf("decoding name: %w", err) 182 } 183 return windows.UTF16ToString(name), nil 184 } 185 186 func getTargetsFromBuffer(buffer []byte, offset, count int) ([]string, error) { 187 if len(buffer) < offset+count*6 { 188 return nil, fmt.Errorf("invalid buffer") 189 } 190 191 targets := make([]string, count) 192 for i := 0; i < count; i++ { 193 entryBuf := buffer[offset+i*8 : offset+i*8+8] 194 tgt := *(*mappingTargetEntry)(unsafe.Pointer(&entryBuf[0])) 195 if len(buffer) < int(tgt.TargetRootOffset)+int(tgt.TargetRootLength) { 196 return nil, fmt.Errorf("invalid buffer") 197 } 198 decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : tgt.TargetRootOffset+tgt.TargetRootLength]) 199 if err != nil { 200 return nil, fmt.Errorf("decoding name: %w", err) 201 } 202 decoded, err = getFinalPath(decoded) 203 if err != nil { 204 return nil, fmt.Errorf("fetching final path: %w", err) 205 } 206 207 targets[i] = decoded 208 } 209 return targets, nil 210 } 211 212 func getFinalPath(pth string) (string, error) { 213 // BfGetMappings returns VOLUME_NAME_NT paths like \Device\HarddiskVolume2\ProgramData. 214 // These can be accessed by prepending \\.\GLOBALROOT to the path. We use this to get the 215 // DOS paths for these files. 216 if strings.HasPrefix(pth, `\Device`) { 217 pth = `\\.\GLOBALROOT` + pth 218 } 219 220 han, err := openPath(pth) 221 if err != nil { 222 return "", fmt.Errorf("fetching file handle: %w", err) 223 } 224 defer func() { 225 _ = windows.CloseHandle(han) 226 }() 227 228 buf := make([]uint16, 100) 229 var flags uint32 = 0x0 230 for { 231 n, err := windows.GetFinalPathNameByHandle(han, &buf[0], uint32(len(buf)), flags) 232 if err != nil { 233 // if we mounted a volume that does not also have a drive letter assigned, attempting to 234 // fetch the VOLUME_NAME_DOS will fail with os.ErrNotExist. Attempt to get the VOLUME_NAME_GUID. 235 if errors.Is(err, os.ErrNotExist) && flags != 0x1 { 236 flags = 0x1 237 continue 238 } 239 return "", fmt.Errorf("getting final path name: %w", err) 240 } 241 if n < uint32(len(buf)) { 242 break 243 } 244 buf = make([]uint16, n) 245 } 246 finalPath := windows.UTF16ToString(buf) 247 // We got VOLUME_NAME_DOS, we need to strip away some leading slashes. 248 // Leave unchanged if we ended up requesting VOLUME_NAME_GUID 249 if len(finalPath) > 4 && finalPath[:4] == `\\?\` && flags == 0x0 { 250 finalPath = finalPath[4:] 251 if len(finalPath) > 3 && finalPath[:3] == `UNC` { 252 // return path like \\server\share\... 253 finalPath = `\` + finalPath[3:] 254 } 255 } 256 257 return finalPath, nil 258 } 259 260 func getBindMappingFromBuffer(buffer []byte, entry mappingEntry) (BindMapping, error) { 261 if len(buffer) < int(entry.VirtRootOffset)+int(entry.VirtRootLength) { 262 return BindMapping{}, fmt.Errorf("invalid buffer") 263 } 264 265 src, err := decodeEntry(buffer[entry.VirtRootOffset : entry.VirtRootOffset+entry.VirtRootLength]) 266 if err != nil { 267 return BindMapping{}, fmt.Errorf("decoding entry: %w", err) 268 } 269 targets, err := getTargetsFromBuffer(buffer, int(entry.TargetEntriesOffset), int(entry.NumberOfTargets)) 270 if err != nil { 271 return BindMapping{}, fmt.Errorf("fetching targets: %w", err) 272 } 273 274 src, err = getFinalPath(src) 275 if err != nil { 276 return BindMapping{}, fmt.Errorf("fetching final path: %w", err) 277 } 278 279 return BindMapping{ 280 Flags: entry.Flags, 281 Targets: targets, 282 MountPoint: src, 283 }, nil 284 } 285 286 func openPath(path string) (windows.Handle, error) { 287 u16, err := windows.UTF16PtrFromString(path) 288 if err != nil { 289 return 0, err 290 } 291 h, err := windows.CreateFile( 292 u16, 293 0, 294 windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, 295 nil, 296 windows.OPEN_EXISTING, 297 windows.FILE_FLAG_BACKUP_SEMANTICS, // Needed to open a directory handle. 298 0) 299 if err != nil { 300 return 0, &os.PathError{ 301 Op: "CreateFile", 302 Path: path, 303 Err: err, 304 } 305 } 306 return h, nil 307 }