github.com/u-root/u-root@v7.0.1-0.20200915234505-ad7babab0a8e+incompatible/pkg/securelaunch/helpers.go (about)

     1  // Copyright 2019 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package securelaunch takes integrity measurements before launching the target system.
     6  package securelaunch
     7  
     8  import (
     9  	"fmt"
    10  	"io/ioutil"
    11  	"log"
    12  	"os"
    13  	"path/filepath"
    14  	"strings"
    15  	"sync"
    16  
    17  	"github.com/u-root/u-root/pkg/mount"
    18  	"github.com/u-root/u-root/pkg/mount/block"
    19  )
    20  
    21  type persistDataItem struct {
    22  	desc        string // Description
    23  	data        []byte
    24  	location    string // of form sda:/path/to/file
    25  	defaultFile string // if location turns out to be dir only
    26  }
    27  
    28  var persistData []persistDataItem
    29  
    30  type mountCacheData struct {
    31  	flags     uintptr
    32  	mountPath string
    33  }
    34  
    35  type mountCacheType struct {
    36  	m  map[string]mountCacheData
    37  	mu sync.RWMutex
    38  }
    39  
    40  // mountCache is used by sluinit to reduce number of mount/unmount operations
    41  var mountCache = mountCacheType{m: make(map[string]mountCacheData)}
    42  
    43  // StorageBlkDevices helps securelaunch pkg mount devices.
    44  var StorageBlkDevices block.BlockDevices
    45  
    46  // Debug enables verbose logs if kernel cmd line has uroot.uinitargs=-d flag set.
    47  // kernel cmdline is checked in sluinit.
    48  var Debug = func(string, ...interface{}) {}
    49  
    50  // WriteToFile writes a byte slice to a target file on an
    51  // already mounted disk and returns the target file path.
    52  //
    53  // defFileName is default dst file name, only used if user doesn't provide one.
    54  func WriteToFile(data []byte, dst, defFileName string) (string, error) {
    55  
    56  	// make sure dst is an absolute file path
    57  	if !filepath.IsAbs(dst) {
    58  		return "", fmt.Errorf("dst =%s Not an absolute path ", dst)
    59  	}
    60  
    61  	// target is the full absolute path where []byte will be written to
    62  	target := dst
    63  	dstInfo, err := os.Stat(dst)
    64  	if err == nil && dstInfo.IsDir() {
    65  		Debug("No file name provided. Adding it now. old target=%s", target)
    66  		target = filepath.Join(dst, defFileName)
    67  		Debug("New target=%s", target)
    68  	}
    69  
    70  	Debug("WriteToFile: target=%s", target)
    71  	err = ioutil.WriteFile(target, data, 0644)
    72  	if err != nil {
    73  		return "", fmt.Errorf("failed to write date to file =%s, err=%v", target, err)
    74  	}
    75  	Debug("WriteToFile: exit w success data written to target=%s", target)
    76  	return target, nil
    77  }
    78  
    79  // persist writes data to targetPath.
    80  // targetPath is of form sda:/boot/cpuid.txt
    81  func persist(data []byte, targetPath string, defaultFile string) error {
    82  
    83  	filePath, r := GetMountedFilePath(targetPath, 0) // 0 is flag for rw mount option
    84  	if r != nil {
    85  		return fmt.Errorf("persist: err: input %s could NOT be located, err=%v", targetPath, r)
    86  	}
    87  
    88  	dst := filePath // /tmp/boot-733276578/cpuid
    89  
    90  	target, err := WriteToFile(data, dst, defaultFile)
    91  	if err != nil {
    92  		log.Printf("persist: err=%s", err)
    93  		return err
    94  	}
    95  
    96  	Debug("persist: Target File%s", target)
    97  	return nil
    98  }
    99  
   100  // AddToPersistQueue enqueues an action item to persistData slice
   101  // so that it can be deferred to the last step of sluinit.
   102  func AddToPersistQueue(desc string, data []byte, location string, defFile string) error {
   103  	persistData = append(persistData, persistDataItem{desc, data, location, defFile})
   104  	return nil
   105  }
   106  
   107  // ClearPersistQueue persists any pending data/logs to disk
   108  func ClearPersistQueue() error {
   109  	for _, entry := range persistData {
   110  		if err := persist(entry.data, entry.location, entry.defaultFile); err != nil {
   111  			return fmt.Errorf("%s: persist failed for location %s", entry.desc, entry.location)
   112  		}
   113  	}
   114  	return nil
   115  }
   116  
   117  func getDeviceFromUUID(uuid string) (*block.BlockDev, error) {
   118  	if e := GetBlkInfo(); e != nil {
   119  		return nil, fmt.Errorf("fn GetBlkInfo err=%s", e)
   120  	}
   121  	devices := StorageBlkDevices.FilterFSUUID(uuid)
   122  	Debug("%d device(s) matched with UUID=%s", len(devices), uuid)
   123  	for i, d := range devices {
   124  		Debug("No#%d ,device=%s with fsUUID=%s", i, d.Name, d.FsUUID)
   125  		return d, nil // return first device found
   126  	}
   127  	return nil, fmt.Errorf("no block device exists with UUID=%s", uuid)
   128  }
   129  
   130  func getDeviceFromName(name string) (*block.BlockDev, error) {
   131  	if e := GetBlkInfo(); e != nil {
   132  		return nil, fmt.Errorf("fn GetBlkInfo err=%s", e)
   133  	}
   134  	devices := StorageBlkDevices.FilterName(name)
   135  	Debug("%d device(s) matched with Name=%s", len(devices), name)
   136  	for i, d := range devices {
   137  		Debug("No#%d ,device=%s with fsUUID=%s", i, d.Name, d.FsUUID)
   138  		return d, nil // return first device found
   139  	}
   140  	return nil, fmt.Errorf("no block device exists with name=%s", name)
   141  }
   142  
   143  // GetStorageDevice parses input of type UUID:/tmp/foo or sda2:/tmp/foo,
   144  // and returns any matching devices.
   145  func GetStorageDevice(input string) (*block.BlockDev, error) {
   146  	device, e := getDeviceFromUUID(input)
   147  	if e != nil {
   148  		d2, e2 := getDeviceFromName(input)
   149  		if e2 != nil {
   150  			return nil, fmt.Errorf("getDeviceFromUUID: err=%v, getDeviceFromName: err=%v", e, e2)
   151  		}
   152  		device = d2
   153  	}
   154  	return device, nil
   155  }
   156  
   157  func deleteEntryMountCache(key string) {
   158  	mountCache.mu.Lock()
   159  	delete(mountCache.m, key)
   160  	mountCache.mu.Unlock()
   161  
   162  	Debug("mountCache: Deleted key %s", key)
   163  }
   164  
   165  func setMountCache(key string, val mountCacheData) {
   166  
   167  	mountCache.mu.Lock()
   168  	mountCache.m[key] = val
   169  	mountCache.mu.Unlock()
   170  
   171  	Debug("mountCache: Updated key %s, value %v", key, val)
   172  }
   173  
   174  // getMountCacheData looks up mountCache using devName as key
   175  // and clears an entry in cache if result is found with different
   176  // flags, otherwise returns the cached entry or nil.
   177  func getMountCacheData(key string, flags uintptr) (string, error) {
   178  
   179  	Debug("mountCache: Lookup with key %s", key)
   180  	cachedData, ok := mountCache.m[key]
   181  	if ok {
   182  		cachedMountPath := cachedData.mountPath
   183  		cachedFlags := cachedData.flags
   184  		Debug("mountCache: Lookup succeeded: cachedMountPath %s, cachedFlags %d found for key %s", cachedMountPath, cachedFlags, key)
   185  		if cachedFlags == flags {
   186  			return cachedMountPath, nil
   187  		}
   188  		Debug("mountCache: need to mount the same device with different flags")
   189  		Debug("mountCache: Unmounting %s first", cachedMountPath)
   190  		if e := mount.Unmount(cachedMountPath, true, false); e != nil {
   191  			log.Printf("Unmount failed for %s. PANIC", cachedMountPath)
   192  			panic(e)
   193  		}
   194  		Debug("mountCache: unmount successfull. lets delete entry in map")
   195  		deleteEntryMountCache(key)
   196  		return "", fmt.Errorf("device was already mounted: mount again")
   197  	}
   198  
   199  	return "", fmt.Errorf("mountCache: lookup failed, no key exists that matches %s", key)
   200  }
   201  
   202  // MountDevice looks up mountCache map. if no entry is found, it
   203  // mounts a device and updates cache, otherwise returns mountPath.
   204  func MountDevice(device *block.BlockDev, flags uintptr) (string, error) {
   205  
   206  	devName := device.Name
   207  
   208  	Debug("MountDevice: Checking cache first for %s", devName)
   209  	cachedMountPath, err := getMountCacheData(devName, flags)
   210  	if err == nil {
   211  		log.Printf("getMountCacheData succeeded for %s", devName)
   212  		return cachedMountPath, nil
   213  	}
   214  	Debug("MountDevice: cache lookup failed for %s", devName)
   215  
   216  	Debug("MountDevice: Attempting to mount %s with flags %d", devName, flags)
   217  	mountPath, err := ioutil.TempDir("/tmp", "slaunch-")
   218  	if err != nil {
   219  		return "", fmt.Errorf("failed to create tmp mount directory: %v", err)
   220  	}
   221  
   222  	if _, err := device.Mount(mountPath, flags); err != nil {
   223  		return "", fmt.Errorf("failed to mount %s, flags %d, err=%v", devName, flags, err)
   224  	}
   225  
   226  	Debug("MountDevice: Mounted %s with flags %d", devName, flags)
   227  	setMountCache(devName, mountCacheData{flags: flags, mountPath: mountPath}) // update cache
   228  	return mountPath, nil
   229  }
   230  
   231  // GetMountedFilePath returns a file path corresponding to a <device_identifier>:<path> user input format.
   232  // <device_identifier> may be a Linux block device identifier like sda or a FS UUID.
   233  func GetMountedFilePath(inputVal string, flags uintptr) (string, error) {
   234  	s := strings.Split(inputVal, ":")
   235  	if len(s) != 2 {
   236  		return "", fmt.Errorf("%s: Usage: <block device identifier>:<path>", inputVal)
   237  	}
   238  
   239  	// s[0] can be sda or UUID.
   240  	device, err := GetStorageDevice(s[0])
   241  	if err != nil {
   242  		return "", fmt.Errorf("fn GetStorageDevice: err = %v", err)
   243  	}
   244  
   245  	devName := device.Name
   246  	mountPath, err := MountDevice(device, flags)
   247  	if err != nil {
   248  		return "", fmt.Errorf("failed to mount %s , flags=%v, err=%v", devName, flags, err)
   249  	}
   250  
   251  	fPath := filepath.Join(mountPath, s[1]) // mountPath=/tmp/path/to/target/file if /dev/sda mounted on /tmp
   252  	return fPath, nil
   253  }
   254  
   255  // UnmountAll loops detaches any mounted device from the file heirarchy.
   256  func UnmountAll() {
   257  	Debug("UnmountAll: %d devices need to be unmounted", len(mountCache.m))
   258  	for key, mountCacheData := range mountCache.m {
   259  		cachedMountPath := mountCacheData.mountPath
   260  		Debug("UnmountAll: Unmounting %s", cachedMountPath)
   261  		if e := mount.Unmount(cachedMountPath, true, false); e != nil {
   262  			log.Printf("Unmount failed for %s. PANIC", cachedMountPath)
   263  			panic(e)
   264  		}
   265  		Debug("UnmountAll: Unmounted %s", cachedMountPath)
   266  		deleteEntryMountCache(key)
   267  		Debug("UnmountAll: Deleted key %s from cache", key)
   268  	}
   269  }
   270  
   271  // GetBlkInfo calls storage package to get information on all block devices.
   272  // The information is stored in a global variable 'StorageBlkDevices'
   273  // If the global variable is already non-zero, we skip the call to storage package.
   274  //
   275  // In debug mode, it also prints names and UUIDs for all devices.
   276  func GetBlkInfo() error {
   277  	if len(StorageBlkDevices) == 0 {
   278  		var err error
   279  		Debug("getBlkInfo: expensive function call to get block stats from storage pkg")
   280  		StorageBlkDevices, err = block.GetBlockDevices()
   281  		if err != nil {
   282  			return fmt.Errorf("getBlkInfo: storage.GetBlockDevices err=%v. Exiting", err)
   283  		}
   284  		// no block devices exist on the system.
   285  		if len(StorageBlkDevices) == 0 {
   286  			return fmt.Errorf("getBlkInfo: no block devices found")
   287  		}
   288  		// print the debug info only when expensive call to storage is made
   289  		for k, d := range StorageBlkDevices {
   290  			Debug("block device #%d: %s", k, d)
   291  		}
   292  		return nil
   293  	}
   294  	Debug("getBlkInfo: noop")
   295  	return nil
   296  }