github.hscsec.cn/u-root/u-root@v7.0.0+incompatible/pkg/mount/block/blockdev.go (about)

     1  // Copyright 2017-2019 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package block finds, mounts, and modifies block devices on Linux systems.
     6  package block
     7  
     8  import (
     9  	"bufio"
    10  	"encoding/binary"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"io/ioutil"
    15  	"log"
    16  	"os"
    17  	"path/filepath"
    18  	"strconv"
    19  	"strings"
    20  	"unsafe"
    21  
    22  	"github.com/rekby/gpt"
    23  	"github.com/u-root/u-root/pkg/mount"
    24  	"github.com/u-root/u-root/pkg/pci"
    25  	"golang.org/x/sys/unix"
    26  )
    27  
var (
	// LinuxMountsPath is the standard mountpoint list path. It is the file
	// GetMountpointByDevice opens and scans; it is a variable so it can be
	// pointed at a different file.
	LinuxMountsPath = "/proc/mounts"

	// Debug is the package's debug logging hook. It is called with a
	// printf-style format string and arguments. The default implementation
	// discards everything; assign another function to see the output.
	Debug = func(string, ...interface{}) {}
)
    34  
// BlockDev describes a single block device, identified by its kernel name.
type BlockDev struct {
	// Name is the bare device name (for example sda1), without a directory.
	Name   string
	// FSType is the filesystem type handed to mount.Mount. When it is empty,
	// Mount falls back to mount.TryMount instead.
	FSType string
	// FsUUID is the filesystem UUID that Device read from the device, or
	// the empty string when no known filesystem signature was recognized.
	FsUUID string
}
    41  
    42  // Device makes sure the block device exists and returns a handle to it.
    43  //
    44  // maybeDevpath can be path like /dev/sda1, /sys/class/block/sda1 or just sda1.
    45  // We will just use the last component.
    46  func Device(maybeDevpath string) (*BlockDev, error) {
    47  	devname := filepath.Base(maybeDevpath)
    48  	if _, err := os.Stat(filepath.Join("/sys/class/block", devname)); err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	devpath := filepath.Join("/dev/", devname)
    53  	if uuid, err := getFSUUID(devpath); err == nil {
    54  		return &BlockDev{Name: devname, FsUUID: uuid}, nil
    55  	}
    56  	return &BlockDev{Name: devname}, nil
    57  }
    58  
    59  // String implements fmt.Stringer.
    60  func (b *BlockDev) String() string {
    61  	if len(b.FSType) > 0 {
    62  		return fmt.Sprintf("BlockDevice(name=%s, fs_type=%s, fs_uuid=%s)", b.Name, b.FSType, b.FsUUID)
    63  	}
    64  	return fmt.Sprintf("BlockDevice(name=%s, fs_uuid=%s)", b.Name, b.FsUUID)
    65  }
    66  
    67  // DevicePath is the path to the actual device.
    68  func (b BlockDev) DevicePath() string {
    69  	return filepath.Join("/dev/", b.Name)
    70  }
    71  
    72  // Mount implements mount.Mounter.
    73  func (b *BlockDev) Mount(path string, flags uintptr) (*mount.MountPoint, error) {
    74  	devpath := filepath.Join("/dev", b.Name)
    75  	if len(b.FSType) > 0 {
    76  		return mount.Mount(devpath, path, b.FSType, "", flags)
    77  	}
    78  
    79  	return mount.TryMount(devpath, path, "", flags)
    80  }
    81  
    82  // GPTTable tries to read a GPT table from the block device described by the
    83  // passed BlockDev object, and returns a gpt.Table object, or an error if any
    84  func (b *BlockDev) GPTTable() (*gpt.Table, error) {
    85  	fd, err := os.Open(b.DevicePath())
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	defer fd.Close()
    90  
    91  	blkSize, err := b.BlockSize()
    92  	if err != nil {
    93  		blkSize = 512
    94  	}
    95  
    96  	if _, err := fd.Seek(int64(blkSize), io.SeekStart); err != nil {
    97  		return nil, err
    98  	}
    99  	table, err := gpt.ReadTable(fd, uint64(blkSize))
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	return &table, nil
   104  }
   105  
   106  // PhysicalBlockSize returns the physical block size.
   107  func (b *BlockDev) PhysicalBlockSize() (int, error) {
   108  	f, err := os.Open(b.DevicePath())
   109  	if err != nil {
   110  		return 0, err
   111  	}
   112  	defer f.Close()
   113  	return unix.IoctlGetInt(int(f.Fd()), unix.BLKPBSZGET)
   114  }
   115  
   116  // BlockSize returns the logical block size (BLKSSZGET).
   117  func (b *BlockDev) BlockSize() (int, error) {
   118  	f, err := os.Open(b.DevicePath())
   119  	if err != nil {
   120  		return 0, err
   121  	}
   122  	defer f.Close()
   123  	return unix.IoctlGetInt(int(f.Fd()), unix.BLKSSZGET)
   124  }
   125  
   126  // KernelBlockSize returns the soft block size used inside the kernel (BLKBSZGET).
   127  func (b *BlockDev) KernelBlockSize() (int, error) {
   128  	f, err := os.Open(b.DevicePath())
   129  	if err != nil {
   130  		return 0, err
   131  	}
   132  	defer f.Close()
   133  	return unix.IoctlGetInt(int(f.Fd()), unix.BLKBSZGET)
   134  }
   135  
   136  func ioctlGetUint64(fd int, req uint) (uint64, error) {
   137  	var value uint64
   138  	_, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(unsafe.Pointer(&value)))
   139  	if err != 0 {
   140  		return 0, err
   141  	}
   142  	return value, nil
   143  }
   144  
   145  // Size returns the size in bytes.
   146  func (b *BlockDev) Size() (uint64, error) {
   147  	f, err := os.Open(b.DevicePath())
   148  	if err != nil {
   149  		return 0, err
   150  	}
   151  	defer f.Close()
   152  
   153  	sz, err := ioctlGetUint64(int(f.Fd()), unix.BLKGETSIZE64)
   154  	if err != nil {
   155  		return 0, &os.PathError{
   156  			Op:   "get size",
   157  			Path: b.DevicePath(),
   158  			Err:  os.NewSyscallError("ioctl(BLKGETSIZE64)", err),
   159  		}
   160  	}
   161  	return sz, nil
   162  }
   163  
   164  // ReadPartitionTable prompts the kernel to re-read the partition table on this block device.
   165  func (b *BlockDev) ReadPartitionTable() error {
   166  	f, err := os.OpenFile(b.DevicePath(), os.O_RDWR, 0)
   167  	if err != nil {
   168  		return err
   169  	}
   170  	defer f.Close()
   171  	return unix.IoctlSetInt(int(f.Fd()), unix.BLKRRPART, 0)
   172  }
   173  
   174  // PCIInfo searches sysfs for the PCI vendor and device id.
   175  // We fill in the PCI struct with just those two elements.
   176  func (b *BlockDev) PCIInfo() (*pci.PCI, error) {
   177  	p, err := filepath.EvalSymlinks(filepath.Join("/sys/class/block", b.Name))
   178  	if err != nil {
   179  		return nil, err
   180  	}
   181  	// Loop through devices until we find the actual backing pci device.
   182  	// For Example:
   183  	// /sys/class/block/nvme0n1p1 usually resolves to something like
   184  	// /sys/devices/pci..../.../.../nvme/nvme0/nvme0n1/nvme0n1p1. This leads us to the
   185  	// first partition of the first namespace of the nvme0 device. In this case, the actual pci device and vendor
   186  	// is found in nvme, three levels up. We traverse back up to the parent device
   187  	// and we keep going until we find a device and vendor file.
   188  	dp := filepath.Join(p, "device")
   189  	vp := filepath.Join(p, "vendor")
   190  	found := false
   191  	for p != "/sys/devices" {
   192  		// Check if there is a vendor and device file in this directory.
   193  		if d, err := os.Stat(dp); err == nil && !d.IsDir() {
   194  			if v, err := os.Stat(vp); err == nil && !v.IsDir() {
   195  				found = true
   196  				break
   197  			}
   198  		}
   199  		p = filepath.Dir(p)
   200  		dp = filepath.Join(p, "device")
   201  		vp = filepath.Join(p, "vendor")
   202  	}
   203  	if !found {
   204  		return nil, fmt.Errorf("Unable to find backing pci device with device and vendor files for %v", b.Name)
   205  	}
   206  
   207  	// Read both files into the pci struct and return
   208  	device, err := ioutil.ReadFile(dp)
   209  	if err != nil {
   210  		return nil, fmt.Errorf("Error reading device file: %v", err)
   211  	}
   212  	vendor, err := ioutil.ReadFile(vp)
   213  	if err != nil {
   214  		return nil, fmt.Errorf("Error reading vendor file: %v", err)
   215  	}
   216  	return &pci.PCI{
   217  		Vendor:   strings.TrimSpace(string(vendor)),
   218  		Device:   strings.TrimSpace(string(device)),
   219  		FullPath: p,
   220  	}, nil
   221  }
   222  
// SystemPartitionGUID is the GUID of EFI system partitions
// EFI System partitions have GUID C12A7328-F81F-11D2-BA4B-00A0C93EC93B
//
// The bytes below are that GUID in its stored form: the first three groups
// appear byte-reversed relative to the text form, while the last two groups
// appear in the order they are written.
var SystemPartitionGUID = gpt.Guid([...]byte{
	0x28, 0x73, 0x2a, 0xc1,
	0x1f, 0xf8,
	0xd2, 0x11,
	0xba, 0x4b,
	0x00, 0xa0, 0xc9, 0x3e, 0xc9, 0x3b,
})
   232  
   233  // GetBlockDevices iterates over /sys/class/block entries and returns a list of
   234  // BlockDev objects, or an error if any
   235  func GetBlockDevices() (BlockDevices, error) {
   236  	var blockdevs []*BlockDev
   237  	var devnames []string
   238  
   239  	root := "/sys/class/block"
   240  	err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
   241  		if err != nil {
   242  			return err
   243  		}
   244  		rel, err := filepath.Rel(root, path)
   245  		if err != nil {
   246  			return err
   247  		}
   248  		if rel == "." {
   249  			return nil
   250  		}
   251  		devnames = append(devnames, rel)
   252  		dev, err := Device(rel)
   253  		if err != nil {
   254  			return err
   255  		}
   256  		blockdevs = append(blockdevs, dev)
   257  		return nil
   258  	})
   259  	if err != nil {
   260  		return nil, err
   261  	}
   262  	return blockdevs, nil
   263  }
   264  
   265  func getFSUUID(devpath string) (string, error) {
   266  	file, err := os.Open(devpath)
   267  	if err != nil {
   268  		return "", err
   269  	}
   270  	defer file.Close()
   271  
   272  	fsuuid, err := tryFAT32(file)
   273  	if err == nil {
   274  		return fsuuid, nil
   275  	}
   276  	fsuuid, err = tryFAT16(file)
   277  	if err == nil {
   278  		return fsuuid, nil
   279  	}
   280  	fsuuid, err = tryEXT4(file)
   281  	if err == nil {
   282  		return fsuuid, nil
   283  	}
   284  	fsuuid, err = tryXFS(file)
   285  	if err == nil {
   286  		return fsuuid, nil
   287  	}
   288  	return "", fmt.Errorf("unknown UUID (not vfat, ext4, nor xfs)")
   289  }
   290  
   291  // See https://www.nongnu.org/ext2-doc/ext2.html#DISK-ORGANISATION.
   292  const (
   293  	// Offset of superblock in partition.
   294  	ext2SprblkOff = 1024
   295  
   296  	// Offset of magic number in suberblock.
   297  	ext2SprblkMagicOff  = 56
   298  	ext2SprblkMagicSize = 2
   299  
   300  	ext2SprblkMagic = 0xEF53
   301  
   302  	// Offset of UUID in superblock.
   303  	ext2SprblkUUIDOff  = 104
   304  	ext2SprblkUUIDSize = 16
   305  )
   306  
   307  func tryEXT4(file io.ReaderAt) (string, error) {
   308  	var off int64
   309  
   310  	// Read magic number.
   311  	b := make([]byte, ext2SprblkMagicSize)
   312  	off = ext2SprblkOff + ext2SprblkMagicOff
   313  	if _, err := file.ReadAt(b, off); err != nil {
   314  		return "", err
   315  	}
   316  	magic := binary.LittleEndian.Uint16(b[:2])
   317  	if magic != ext2SprblkMagic {
   318  		return "", fmt.Errorf("ext4 magic not found")
   319  	}
   320  
   321  	// Filesystem UUID.
   322  	b = make([]byte, ext2SprblkUUIDSize)
   323  	off = ext2SprblkOff + ext2SprblkUUIDOff
   324  	if _, err := file.ReadAt(b, off); err != nil {
   325  		return "", err
   326  	}
   327  
   328  	return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]), nil
   329  }
   330  
   331  // See https://de.wikipedia.org/wiki/File_Allocation_Table#Aufbau.
   332  const (
   333  	fat12Magic = "FAT12   "
   334  	fat16Magic = "FAT16   "
   335  
   336  	// Offset of magic number.
   337  	fat16MagicOff  = 0x36
   338  	fat16MagicSize = 8
   339  
   340  	// Offset of filesystem ID / serial number. Treated as short filesystem UUID.
   341  	fat16IDOff  = 0x27
   342  	fat16IDSize = 4
   343  )
   344  
   345  func tryFAT16(file io.ReaderAt) (string, error) {
   346  	// Read magic number.
   347  	b := make([]byte, fat16MagicSize)
   348  	if _, err := file.ReadAt(b, fat16MagicOff); err != nil {
   349  		return "", err
   350  	}
   351  	magic := string(b)
   352  	if magic != fat16Magic && magic != fat12Magic {
   353  		return "", fmt.Errorf("fat16 magic not found")
   354  	}
   355  
   356  	// Filesystem UUID.
   357  	b = make([]byte, fat16IDSize)
   358  	if _, err := file.ReadAt(b, fat16IDOff); err != nil {
   359  		return "", err
   360  	}
   361  
   362  	return fmt.Sprintf("%02x%02x-%02x%02x", b[3], b[2], b[1], b[0]), nil
   363  }
   364  
   365  // See https://de.wikipedia.org/wiki/File_Allocation_Table#Aufbau.
   366  const (
   367  	fat32Magic = "FAT32   "
   368  
   369  	// Offset of magic number.
   370  	fat32MagicOff  = 0x52
   371  	fat32MagicSize = 8
   372  
   373  	// Offset of filesystem ID / serial number. Treated as short filesystem UUID.
   374  	fat32IDOff  = 67
   375  	fat32IDSize = 4
   376  )
   377  
   378  func tryFAT32(file io.ReaderAt) (string, error) {
   379  	// Read magic number.
   380  	b := make([]byte, fat32MagicSize)
   381  	if _, err := file.ReadAt(b, fat32MagicOff); err != nil {
   382  		return "", err
   383  	}
   384  	magic := string(b)
   385  	if magic != fat32Magic {
   386  		return "", fmt.Errorf("fat32 magic not found")
   387  	}
   388  
   389  	// Filesystem UUID.
   390  	b = make([]byte, fat32IDSize)
   391  	if _, err := file.ReadAt(b, fat32IDOff); err != nil {
   392  		return "", err
   393  	}
   394  
   395  	return fmt.Sprintf("%02x%02x-%02x%02x", b[3], b[2], b[1], b[0]), nil
   396  }
   397  
   398  const (
   399  	xfsMagic     = "XFSB"
   400  	xfsMagicSize = 4
   401  	xfsUUIDOff   = 32
   402  	xfsUUIDSize  = 16
   403  )
   404  
   405  func tryXFS(file io.ReaderAt) (string, error) {
   406  	// Read magic number.
   407  	b := make([]byte, xfsMagicSize)
   408  	if _, err := file.ReadAt(b, 0); err != nil {
   409  		return "", err
   410  	}
   411  	magic := string(b)
   412  	if magic != xfsMagic {
   413  		return "", fmt.Errorf("xfs magic not found")
   414  	}
   415  
   416  	// Filesystem UUID.
   417  	b = make([]byte, xfsUUIDSize)
   418  	if _, err := file.ReadAt(b, xfsUUIDOff); err != nil {
   419  		return "", err
   420  	}
   421  
   422  	return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]), nil
   423  }
   424  
// BlockDevices is a list of block devices.
//
// Its Filter* methods each return a new list holding the subset of devices
// that satisfies the filter; the receiver is not modified.
type BlockDevices []*BlockDev
   427  
   428  // FilterZeroSize attempts to find block devices that have at least one block
   429  // of content.
   430  //
   431  // This serves to eliminate block devices that have no backing storage, but
   432  // appear in /sys/class/block anyway (like some loop, nbd, or ram devices).
   433  func (b BlockDevices) FilterZeroSize() BlockDevices {
   434  	var nb BlockDevices
   435  	for _, device := range b {
   436  		if n, err := device.Size(); err != nil || n == 0 {
   437  			continue
   438  		}
   439  		nb = append(nb, device)
   440  	}
   441  	return nb
   442  }
   443  
   444  // FilterPartID returns partitions with the given partition ID GUID.
   445  func (b BlockDevices) FilterPartID(guid string) BlockDevices {
   446  	var names []string
   447  	for _, device := range b {
   448  		table, err := device.GPTTable()
   449  		if err != nil {
   450  			continue
   451  		}
   452  		for i, part := range table.Partitions {
   453  			if part.IsEmpty() {
   454  				continue
   455  			}
   456  			if strings.ToLower(part.Id.String()) == strings.ToLower(guid) {
   457  				names = append(names, fmt.Sprintf("%s%d", device.Name, i+1))
   458  			}
   459  		}
   460  	}
   461  	return b.FilterNames(names...)
   462  }
   463  
   464  // FilterPartType returns partitions with the given partition type GUID.
   465  func (b BlockDevices) FilterPartType(guid string) BlockDevices {
   466  	var names []string
   467  	for _, device := range b {
   468  		table, err := device.GPTTable()
   469  		if err != nil {
   470  			continue
   471  		}
   472  		for i, part := range table.Partitions {
   473  			if part.IsEmpty() {
   474  				continue
   475  			}
   476  			if strings.ToLower(part.Type.String()) == strings.ToLower(guid) {
   477  				names = append(names, fmt.Sprintf("%s%d", device.Name, i+1))
   478  			}
   479  		}
   480  	}
   481  	return b.FilterNames(names...)
   482  }
   483  
   484  // FilterNames filters block devices by the given list of device names (e.g.
   485  // /dev/sda1 sda2 /sys/class/block/sda3).
   486  func (b BlockDevices) FilterNames(names ...string) BlockDevices {
   487  	m := make(map[string]struct{})
   488  	for _, n := range names {
   489  		m[filepath.Base(n)] = struct{}{}
   490  	}
   491  
   492  	var devices BlockDevices
   493  	for _, device := range b {
   494  		if _, ok := m[device.Name]; ok {
   495  			devices = append(devices, device)
   496  		}
   497  	}
   498  	return devices
   499  }
   500  
   501  // FilterFSUUID returns a list of BlockDev objects whose underlying block
   502  // device has a filesystem with the given UUID.
   503  func (b BlockDevices) FilterFSUUID(fsuuid string) BlockDevices {
   504  	partitions := make(BlockDevices, 0)
   505  	for _, device := range b {
   506  		if device.FsUUID == fsuuid {
   507  			partitions = append(partitions, device)
   508  		}
   509  	}
   510  	return partitions
   511  }
   512  
   513  // FilterName returns a list of BlockDev objects whose underlying
   514  // block device has a Name with the given Name
   515  func (b BlockDevices) FilterName(name string) BlockDevices {
   516  	partitions := make(BlockDevices, 0)
   517  	for _, device := range b {
   518  		if device.Name == name {
   519  			partitions = append(partitions, device)
   520  		}
   521  	}
   522  	return partitions
   523  }
   524  
   525  // parsePCIBlockList parses a string in the format vendor:device,vendor:device
   526  // and returns a list of PCI devices containing the vendor and device pairs to block.
   527  func parsePCIBlockList(blockList string) (pci.Devices, error) {
   528  	pciList := pci.Devices{}
   529  	bL := strings.Split(blockList, ",")
   530  	for _, b := range bL {
   531  		p := strings.Split(b, ":")
   532  		if len(p) != 2 {
   533  			return nil, fmt.Errorf("BlockList needs to be of format vendor1:device1,vendor2:device2...! got %v", blockList)
   534  		}
   535  		// Check that values are hex and convert them to sysfs formats
   536  		// This accepts 0xABCD and turns it into 0xabcd
   537  		// abcd also turns into 0xabcd
   538  		v, err := strconv.ParseUint(strings.TrimPrefix(p[0], "0x"), 16, 16)
   539  		if err != nil {
   540  			return nil, fmt.Errorf("BlockList needs to contain a hex vendor ID, got %v, err %v", p[0], err)
   541  		}
   542  		vs := fmt.Sprintf("%#04x", v)
   543  
   544  		d, err := strconv.ParseUint(strings.TrimPrefix(p[1], "0x"), 16, 16)
   545  		if err != nil {
   546  			return nil, fmt.Errorf("BlockList needs to contain a hex device ID, got %v, err %v", p[1], err)
   547  		}
   548  		ds := fmt.Sprintf("%#04x", d)
   549  		pciList = append(pciList, &pci.PCI{Vendor: vs, Device: ds})
   550  	}
   551  	return pciList, nil
   552  }
   553  
   554  // FilterBlockPCIString parses a string in the format vendor:device,vendor:device
   555  // and returns a list of BlockDev objects whose backing pci devices do not match
   556  // the vendor:device pairs passed in. All values are treated as hex.
   557  // E.g. 0x8086:0xABCD,8086:0x1234
   558  func (b BlockDevices) FilterBlockPCIString(blocklist string) (BlockDevices, error) {
   559  	pciList, err := parsePCIBlockList(blocklist)
   560  	if err != nil {
   561  		return nil, err
   562  	}
   563  	return b.FilterBlockPCI(pciList), nil
   564  }
   565  
   566  // FilterBlockPCI returns a list of BlockDev objects whose backing
   567  // pci devices do not match the blocklist of PCI devices passed in.
   568  // FilterBlockPCI discards entries which have a matching PCI vendor
   569  // and device ID as an entry in the blocklist.
   570  func (b BlockDevices) FilterBlockPCI(blocklist pci.Devices) BlockDevices {
   571  	type mapKey struct {
   572  		vendor, device string
   573  	}
   574  	m := make(map[mapKey]bool)
   575  
   576  	for _, v := range blocklist {
   577  		m[mapKey{v.Vendor, v.Device}] = true
   578  	}
   579  	Debug("block map is %v", m)
   580  
   581  	partitions := make(BlockDevices, 0)
   582  	for _, device := range b {
   583  		p, err := device.PCIInfo()
   584  		if err != nil {
   585  			// In the case of an error, we err on the safe side and choose not to block it.
   586  			// Not all block devices are backed by a pci device, for example SATA drives.
   587  			Debug("Failed to find PCI info; %v", err)
   588  			partitions = append(partitions, device)
   589  			continue
   590  		}
   591  		if _, ok := m[mapKey{p.Vendor, p.Device}]; !ok {
   592  			// Not in blocklist, we're good to go
   593  			Debug("Not blocking device %v, with pci %v, not in map", device, p)
   594  			partitions = append(partitions, device)
   595  		} else {
   596  			log.Printf("Blocking device %v since it appears in blocklist", device.Name)
   597  		}
   598  	}
   599  	return partitions
   600  }
   601  
   602  // GetMountpointByDevice gets the mountpoint by given
   603  // device name. Returns on first match
   604  func GetMountpointByDevice(devicePath string) (*string, error) {
   605  	file, err := os.Open(LinuxMountsPath)
   606  	if err != nil {
   607  		return nil, err
   608  	}
   609  
   610  	defer file.Close()
   611  	scanner := bufio.NewScanner(file)
   612  
   613  	for scanner.Scan() {
   614  		deviceInfo := strings.Fields(scanner.Text())
   615  		if deviceInfo[0] == devicePath {
   616  			return &deviceInfo[1], nil
   617  		}
   618  	}
   619  
   620  	return nil, errors.New("Mountpoint not found")
   621  }