github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/securelaunch/helpers.go (about)

     1  // Copyright 2019 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package securelaunch takes integrity measurements before launching the target system.
     6  package securelaunch
     7  
     8  import (
     9  	"fmt"
    10  	"log"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"sync"
    15  
    16  	"github.com/mvdan/u-root-coreutils/pkg/mount"
    17  	"github.com/mvdan/u-root-coreutils/pkg/mount/block"
    18  )
    19  
    20  type persistDataItem struct {
    21  	desc        string // Description
    22  	data        []byte
    23  	location    string // of form sda:/path/to/file
    24  	defaultFile string // if location turns out to be dir only
    25  }
    26  
    27  var persistData []persistDataItem
    28  
    29  type mountCacheData struct {
    30  	flags     uintptr
    31  	mountPath string
    32  }
    33  
    34  type mountCacheType struct {
    35  	m  map[string]mountCacheData
    36  	mu sync.RWMutex
    37  }
    38  
    39  // mountCache is used by sluinit to reduce number of mount/unmount operations
    40  var mountCache = mountCacheType{m: make(map[string]mountCacheData)}
    41  
    42  // StorageBlkDevices helps securelaunch pkg mount devices.
    43  var StorageBlkDevices block.BlockDevices
    44  
    45  // Debug enables verbose logs if kernel cmd line has uroot.uinitargs=-d flag set.
    46  // kernel cmdline is checked in sluinit.
    47  var Debug = func(string, ...interface{}) {}
    48  
    49  // WriteToFile writes a byte slice to a file on an already mounted disk and
    50  // returns the file path written to.
    51  func WriteToFile(data []byte, dst, defFileName string) (string, error) {
    52  	// make sure dst is an absolute file path
    53  	if !filepath.IsAbs(dst) {
    54  		return "", fmt.Errorf("dst =%s Not an absolute path ", dst)
    55  	}
    56  
    57  	// target is the full absolute path where []byte will be written to
    58  	target := dst
    59  	dstInfo, err := os.Stat(dst)
    60  	if err == nil && dstInfo.IsDir() {
    61  		Debug("No file name provided. Adding it now. old target=%s", target)
    62  		target = filepath.Join(dst, defFileName)
    63  		Debug("New target=%s", target)
    64  	}
    65  
    66  	Debug("WriteToFile: target=%s", target)
    67  	err = os.WriteFile(target, data, 0o644)
    68  	if err != nil {
    69  		return "", fmt.Errorf("failed to write date to file =%s, err=%v", target, err)
    70  	}
    71  	Debug("WriteToFile: exit w success data written to target=%s", target)
    72  	return target, nil
    73  }
    74  
    75  // persist writes data to targetPath.
    76  // targetPath is of form sda:/boot/cpuid.txt
    77  func persist(data []byte, targetPath string, defaultFile string) error {
    78  	filePath, r := GetMountedFilePath(targetPath, 0) // 0 is flag for rw mount option
    79  	if r != nil {
    80  		return fmt.Errorf("persist: err: input %s could NOT be located, err=%v", targetPath, r)
    81  	}
    82  
    83  	dst := filePath // /tmp/boot-733276578/cpuid
    84  
    85  	target, err := WriteToFile(data, dst, defaultFile)
    86  	if err != nil {
    87  		log.Printf("persist: err=%s", err)
    88  		return err
    89  	}
    90  
    91  	Debug("persist: Target File%s", target)
    92  	return nil
    93  }
    94  
    95  // AddToPersistQueue enqueues an action item to persistData slice
    96  // so that it can be deferred to the last step of sluinit.
    97  func AddToPersistQueue(desc string, data []byte, location string, defFile string) error {
    98  	persistData = append(persistData, persistDataItem{desc, data, location, defFile})
    99  	return nil
   100  }
   101  
   102  // ClearPersistQueue persists any pending data/logs to disk
   103  func ClearPersistQueue() error {
   104  	for _, entry := range persistData {
   105  		if err := persist(entry.data, entry.location, entry.defaultFile); err != nil {
   106  			return fmt.Errorf("%s: persist failed for location %s", entry.desc, entry.location)
   107  		}
   108  	}
   109  	return nil
   110  }
   111  
   112  func getDeviceFromUUID(uuid string) (*block.BlockDev, error) {
   113  	if e := GetBlkInfo(); e != nil {
   114  		return nil, fmt.Errorf("fn GetBlkInfo err=%s", e)
   115  	}
   116  	devices := StorageBlkDevices.FilterFSUUID(uuid)
   117  	Debug("%d device(s) matched with UUID=%s", len(devices), uuid)
   118  	for i, d := range devices {
   119  		Debug("No#%d ,device=%s with fsUUID=%s", i, d.Name, d.FsUUID)
   120  		return d, nil // return first device found
   121  	}
   122  	return nil, fmt.Errorf("no block device exists with UUID=%s", uuid)
   123  }
   124  
   125  func getDeviceFromName(name string) (*block.BlockDev, error) {
   126  	if e := GetBlkInfo(); e != nil {
   127  		return nil, fmt.Errorf("fn GetBlkInfo err=%s", e)
   128  	}
   129  	devices := StorageBlkDevices.FilterName(name)
   130  	Debug("%d device(s) matched with Name=%s", len(devices), name)
   131  	for i, d := range devices {
   132  		Debug("No#%d ,device=%s with fsUUID=%s", i, d.Name, d.FsUUID)
   133  		return d, nil // return first device found
   134  	}
   135  	return nil, fmt.Errorf("no block device exists with name=%s", name)
   136  }
   137  
   138  // GetStorageDevice parses input of type UUID:/tmp/foo or sda2:/tmp/foo,
   139  // and returns any matching devices.
   140  func GetStorageDevice(input string) (*block.BlockDev, error) {
   141  	device, e := getDeviceFromUUID(input)
   142  	if e != nil {
   143  		d2, e2 := getDeviceFromName(input)
   144  		if e2 != nil {
   145  			return nil, fmt.Errorf("getDeviceFromUUID: err=%v, getDeviceFromName: err=%v", e, e2)
   146  		}
   147  		device = d2
   148  	}
   149  	return device, nil
   150  }
   151  
   152  func deleteEntryMountCache(key string) {
   153  	mountCache.mu.Lock()
   154  	delete(mountCache.m, key)
   155  	mountCache.mu.Unlock()
   156  
   157  	Debug("mountCache: Deleted key %s", key)
   158  }
   159  
   160  func setMountCache(key string, val mountCacheData) {
   161  	mountCache.mu.Lock()
   162  	mountCache.m[key] = val
   163  	mountCache.mu.Unlock()
   164  
   165  	Debug("mountCache: Updated key %s, value %v", key, val)
   166  }
   167  
   168  // getMountCacheData looks up mountCache using devName as key
   169  // and clears an entry in cache if result is found with different
   170  // flags, otherwise returns the cached entry or nil.
   171  func getMountCacheData(key string, flags uintptr) (string, error) {
   172  	Debug("mountCache: Lookup with key %s", key)
   173  	cachedData, ok := mountCache.m[key]
   174  	if ok {
   175  		cachedMountPath := cachedData.mountPath
   176  		cachedFlags := cachedData.flags
   177  		Debug("mountCache: Lookup succeeded: cachedMountPath %s, cachedFlags %d found for key %s", cachedMountPath, cachedFlags, key)
   178  		if cachedFlags == flags {
   179  			return cachedMountPath, nil
   180  		}
   181  		Debug("mountCache: need to mount the same device with different flags")
   182  		Debug("mountCache: Unmounting %s first", cachedMountPath)
   183  		if err := mount.Unmount(cachedMountPath, true, false); err != nil {
   184  			return "", fmt.Errorf("failed to unmount '%s': %w", cachedMountPath, err)
   185  		}
   186  		Debug("mountCache: unmount successfull. lets delete entry in map")
   187  		deleteEntryMountCache(key)
   188  		return "", fmt.Errorf("device was already mounted: mount again")
   189  	}
   190  
   191  	return "", fmt.Errorf("mountCache: lookup failed, no key exists that matches %s", key)
   192  }
   193  
   194  // MountDevice looks up mountCache map. if no entry is found, it
   195  // mounts a device and updates cache, otherwise returns mountPath.
   196  func MountDevice(device *block.BlockDev, flags uintptr) (string, error) {
   197  	devName := device.Name
   198  
   199  	Debug("MountDevice: Checking cache first for %s", devName)
   200  	cachedMountPath, err := getMountCacheData(devName, flags)
   201  	if err == nil {
   202  		log.Printf("getMountCacheData succeeded for %s", devName)
   203  		return cachedMountPath, nil
   204  	}
   205  	Debug("MountDevice: cache lookup failed for %s", devName)
   206  
   207  	Debug("MountDevice: Attempting to mount %s with flags %d", devName, flags)
   208  	mountPath, err := os.MkdirTemp("/tmp", "slaunch-")
   209  	if err != nil {
   210  		return "", fmt.Errorf("failed to create tmp mount directory: %v", err)
   211  	}
   212  
   213  	if _, err := device.Mount(mountPath, flags); err != nil {
   214  		return "", fmt.Errorf("failed to mount %s, flags %d, err=%v", devName, flags, err)
   215  	}
   216  
   217  	Debug("MountDevice: Mounted %s with flags %d", devName, flags)
   218  	setMountCache(devName, mountCacheData{flags: flags, mountPath: mountPath}) // update cache
   219  	return mountPath, nil
   220  }
   221  
   222  // GetMountedFilePath returns the file path corresponding to the given
   223  // <device_identifier>:<path>.
   224  // <device_identifier> is a Linux block device identifier (e.g, sda or UUID).
   225  func GetMountedFilePath(inputVal string, flags uintptr) (string, error) {
   226  	s := strings.Split(inputVal, ":")
   227  	if len(s) != 2 {
   228  		return "", fmt.Errorf("%s: Usage: <block device identifier>:<path>", inputVal)
   229  	}
   230  
   231  	// s[0] can be sda or UUID.
   232  	device, err := GetStorageDevice(s[0])
   233  	if err != nil {
   234  		return "", fmt.Errorf("fn GetStorageDevice: err = %v", err)
   235  	}
   236  
   237  	devName := device.Name
   238  	mountPath, err := MountDevice(device, flags)
   239  	if err != nil {
   240  		return "", fmt.Errorf("failed to mount %s , flags=%v, err=%v", devName, flags, err)
   241  	}
   242  
   243  	fPath := filepath.Join(mountPath, s[1])
   244  	return fPath, nil
   245  }
   246  
   247  // UnmountAll unmounts all mounted devices from the file heirarchy.
   248  func UnmountAll() error {
   249  	Debug("UnmountAll: %d devices need to be unmounted", len(mountCache.m))
   250  	for key, mountCacheData := range mountCache.m {
   251  		cachedMountPath := mountCacheData.mountPath
   252  		Debug("UnmountAll: Unmounting %s", cachedMountPath)
   253  		if err := mount.Unmount(cachedMountPath, true, false); err != nil {
   254  			return fmt.Errorf("failed to unmount '%s': %w", cachedMountPath, err)
   255  		}
   256  		Debug("UnmountAll: Unmounted %s", cachedMountPath)
   257  		deleteEntryMountCache(key)
   258  		Debug("UnmountAll: Deleted key %s from cache", key)
   259  	}
   260  
   261  	return nil
   262  }
   263  
   264  // GetBlkInfo gets information on all block devices and stores it in the
   265  // global variable 'StorageBlkDevices'. If it is called more than once, the
   266  // subsequent calls just return.
   267  //
   268  // In debug mode, it also prints names and UUIDs for all devices.
   269  func GetBlkInfo() error {
   270  	if len(StorageBlkDevices) == 0 {
   271  		var err error
   272  		Debug("getBlkInfo: expensive function call to get block stats from storage pkg")
   273  		StorageBlkDevices, err = block.GetBlockDevices()
   274  		if err != nil {
   275  			return fmt.Errorf("getBlkInfo: storage.GetBlockDevices err=%v. Exiting", err)
   276  		}
   277  		// no block devices exist on the system.
   278  		if len(StorageBlkDevices) == 0 {
   279  			return fmt.Errorf("getBlkInfo: no block devices found")
   280  		}
   281  		// print the debug info only when expensive call to storage is made
   282  		for k, d := range StorageBlkDevices {
   283  			Debug("block device #%d: %s", k, d)
   284  		}
   285  		return nil
   286  	}
   287  	Debug("getBlkInfo: noop")
   288  	return nil
   289  }