github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/cmds/core/sluinit/uinit_linux.go (about)

     1  // Copyright 2019 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"flag"
     9  	"fmt"
    10  	"log"
    11  	"os"
    12  	"os/signal"
    13  	"syscall"
    14  	"time"
    15  
    16  	"github.com/u-root/iscsinl"
    17  	"github.com/mvdan/u-root-coreutils/pkg/cmdline"
    18  	"github.com/mvdan/u-root-coreutils/pkg/dhclient"
    19  	slaunch "github.com/mvdan/u-root-coreutils/pkg/securelaunch"
    20  	"github.com/mvdan/u-root-coreutils/pkg/securelaunch/eventlog"
    21  	"github.com/mvdan/u-root-coreutils/pkg/securelaunch/policy"
    22  	"github.com/mvdan/u-root-coreutils/pkg/securelaunch/tpm"
    23  )
    24  
    25  var slDebug = flag.Bool("d", false, "enable debug logs")
    26  
    27  // step keeps track of the current step (e.g., parse policy, measure).
    28  var step = 1
    29  
    30  // printStep prints a message for the next step.
    31  func printStep(msg string) {
    32  	slaunch.Debug("******** Step %d: %s ********", step, msg)
    33  	step++
    34  }
    35  
    36  // checkDebugFlag checks if `uroot.uinitargs=-d` is set on the kernel cmdline.
    37  // If it is set, slaunch.Debug is set to log.Printf.
    38  func checkDebugFlag() {
    39  	// By default, CommandLine exits on error, but this makes it trivial to get
    40  	// a shell in u-root. Instead, continue on error and let the error handling
    41  	// code here handle it.
    42  	flag.CommandLine.Init(flag.CommandLine.Name(), flag.ContinueOnError)
    43  
    44  	flag.Parse()
    45  
    46  	if flag.NArg() > 1 {
    47  		log.Fatal("Incorrect number of arguments")
    48  	}
    49  
    50  	if *slDebug {
    51  		slaunch.Debug = log.Printf
    52  		slaunch.Debug("debug flag is set. Logging Enabled.")
    53  	}
    54  }
    55  
    56  // iscsiSpecified checks if iscsi has been set on the kernel command line.
    57  func iscsiSpecified() bool {
    58  	return cmdline.ContainsFlag("netroot") && cmdline.ContainsFlag("rd.iscsi.initator")
    59  }
    60  
    61  // scanIscsiDrives calls dhcleint to parse cmdline and iscsinl to mount iscsi
    62  // drives.
    63  func scanIscsiDrives() error {
    64  	uri, ok := cmdline.Flag("netroot")
    65  	if !ok {
    66  		return fmt.Errorf("could not get `netroot` argument")
    67  	}
    68  	slaunch.Debug("scanIscsiDrives: netroot flag is set: '%s'", uri)
    69  
    70  	initiator, ok := cmdline.Flag("rd.iscsi.initiator")
    71  	if !ok {
    72  		return fmt.Errorf("could not get `rd.iscsi.initiator` argument")
    73  	}
    74  	slaunch.Debug("scanIscsiDrives: rd.iscsi.initiator flag is set: '%s'", initiator)
    75  
    76  	target, volume, err := dhclient.ParseISCSIURI(uri)
    77  	if err != nil {
    78  		return fmt.Errorf("dhclient iSCSI parser failed: %w", err)
    79  	}
    80  
    81  	slaunch.Debug("scanIscsiDrives: resolved target: '%s'", target)
    82  	slaunch.Debug("scanIscsiDrives: resolved volume: '%s'", volume)
    83  
    84  	devices, err := iscsinl.MountIscsi(
    85  		iscsinl.WithInitiator(initiator),
    86  		iscsinl.WithTarget(target.String(), volume),
    87  		iscsinl.WithCmdsMax(128),
    88  		iscsinl.WithQueueDepth(16),
    89  		iscsinl.WithScheduler("noop"),
    90  	)
    91  	if err != nil {
    92  		return fmt.Errorf("could not mount iSCSI drive: %w", err)
    93  	}
    94  
    95  	for i := range devices {
    96  		slaunch.Debug("scanIscsiDrives: iSCSI drive mounted at '%s'", devices[i])
    97  	}
    98  
    99  	return nil
   100  }
   101  
   102  // initialize sets up the environment.
   103  func initialize() error {
   104  	printStep("Initialization")
   105  
   106  	// Check if an iSCSI drive was specified and if so, mount it.
   107  	if iscsiSpecified() {
   108  		if err := scanIscsiDrives(); err != nil {
   109  			return fmt.Errorf("failed to mount iSCSI drive: %w", err)
   110  		}
   111  	}
   112  
   113  	if err := tpm.New(); err != nil {
   114  		return fmt.Errorf("failed to get TPM device: %w", err)
   115  	}
   116  
   117  	slaunch.Debug("Initialization successfully completed")
   118  
   119  	return nil
   120  }
   121  
   122  // parsePolicy parses and gets the policy file.
   123  func parsePolicy() (*policy.Policy, error) {
   124  	printStep("Locate and parse SL policy")
   125  
   126  	p, err := policy.Get()
   127  	if err != nil {
   128  		return nil, fmt.Errorf("failed to parse policy file: %w", err)
   129  	}
   130  
   131  	slaunch.Debug("Policy file successfully parsed")
   132  
   133  	return p, nil
   134  }
   135  
   136  // collectMeasurements runs any measurements specified in the policy file.
   137  func collectMeasurements(p *policy.Policy) error {
   138  	printStep("Collect evidence")
   139  
   140  	for _, collector := range p.Collectors {
   141  		slaunch.Debug("Input Collector: %v", collector)
   142  		if err := collector.Collect(); err != nil {
   143  			log.Printf("Collector %v failed: %v", collector, err)
   144  		}
   145  	}
   146  
   147  	slaunch.Debug("Collectors completed")
   148  
   149  	return nil
   150  }
   151  
   152  // measureFiles measures relevant files (e.g., policy, kernel, initrd).
   153  func measureFiles(p *policy.Policy) error {
   154  	printStep("Measure files")
   155  
   156  	if err := policy.Measure(); err != nil {
   157  		return fmt.Errorf("failed to measure policy file: %w", err)
   158  	}
   159  
   160  	if p.Launcher.Params["kernel"] != "" {
   161  		if err := p.Launcher.MeasureKernel(); err != nil {
   162  			return fmt.Errorf("failed to measure target kernel: %w", err)
   163  		}
   164  	}
   165  
   166  	if p.Launcher.Params["initrd"] != "" {
   167  		if err := p.Launcher.MeasureInitrd(); err != nil {
   168  			return fmt.Errorf("failed to measure target initrd: %w", err)
   169  		}
   170  	}
   171  
   172  	slaunch.Debug("Files successfully measured")
   173  
   174  	return nil
   175  }
   176  
   177  // parseEventLog parses the TPM event log.
   178  func parseEventLog(p *policy.Policy) error {
   179  	printStep("Parse event log")
   180  
   181  	if err := p.EventLog.Parse(); err != nil {
   182  		return fmt.Errorf("failed to parse event log: %w", err)
   183  	}
   184  
   185  	slaunch.Debug("Event log successfully parsed")
   186  
   187  	return nil
   188  }
   189  
   190  // dumpLogs writes out any pending logs to a file on disk.
   191  func dumpLogs() error {
   192  	printStep("Dump logs to disk")
   193  
   194  	if err := eventlog.ParseEventLog(); err != nil {
   195  		return fmt.Errorf("failed to parse event log: %w", err)
   196  	}
   197  
   198  	if err := slaunch.ClearPersistQueue(); err != nil {
   199  		return fmt.Errorf("failed to clear persist queue: %w", err)
   200  	}
   201  
   202  	slaunch.Debug("Logs successfully dumped to disk")
   203  
   204  	return nil
   205  }
   206  
   207  // unmountAll unmounts all mount points.
   208  func unmountAll() error {
   209  	printStep("Unmount all")
   210  
   211  	if err := slaunch.UnmountAll(); err != nil {
   212  		return fmt.Errorf("failed to unmount all devices: %w", err)
   213  	}
   214  
   215  	slaunch.Debug("Devices successfully unmounted")
   216  
   217  	return nil
   218  }
   219  
   220  // bootTarget boots the target kernel/initrd.
   221  func bootTarget(p *policy.Policy) error {
   222  	printStep("Boot target")
   223  
   224  	if err := p.Launcher.Boot(); err != nil {
   225  		return fmt.Errorf("failed to boot target: %w", err)
   226  	}
   227  
   228  	return nil
   229  }
   230  
   231  // exit loops forever trying to reboot the system.
   232  func exit(mainErr error) {
   233  	// Print the error.
   234  	fmt.Fprintf(os.Stderr, "ERROR: Failed to boot: %v\n", mainErr)
   235  
   236  	// Dump any logs, if possible. This can help figure out what went wrong.
   237  	if err := dumpLogs(); err != nil {
   238  		fmt.Fprintf(os.Stderr, "ERROR: Could not dump logs: %v\n", err)
   239  	}
   240  
   241  	// Umount anything that might be mounted.
   242  	slaunch.UnmountAll()
   243  
   244  	// Close the connection to the TPM if it was opened.
   245  	tpm.Close()
   246  
   247  	// Loop trying to reboot the system.
   248  	for {
   249  		// Wait 5 seconds.
   250  		time.Sleep(5 * time.Second)
   251  
   252  		// Try to reboot the system.
   253  		if err := syscall.Reboot(syscall.LINUX_REBOOT_CMD_RESTART); err != nil {
   254  			fmt.Fprintf(os.Stderr, "ERROR: Failed to reboot: %v\n", err)
   255  		}
   256  	}
   257  }
   258  
   259  // main parses platform policy file, and based on the inputs performs
   260  // measurements and then launches a target kernel.
   261  //
   262  // Steps followed by uinit:
   263  // 1. if debug flag is set, enable logging.
   264  // 2. gets the TPM handle
   265  // 3. Gets secure launch policy file entered by user.
   266  // 4. calls collectors to collect measurements(hashes) a.k.a evidence.
   267  func main() {
   268  	// Ignore ctrl+c
   269  	signal.Ignore(syscall.SIGINT)
   270  
   271  	checkDebugFlag()
   272  
   273  	if err := initialize(); err != nil {
   274  		exit(err)
   275  	}
   276  
   277  	p, err := parsePolicy()
   278  	if err != nil {
   279  		exit(err)
   280  	}
   281  
   282  	if err := parseEventLog(p); err != nil {
   283  		exit(err)
   284  	}
   285  
   286  	if err := collectMeasurements(p); err != nil {
   287  		exit(err)
   288  	}
   289  
   290  	if err := measureFiles(p); err != nil {
   291  		exit(err)
   292  	}
   293  
   294  	if err := dumpLogs(); err != nil {
   295  		exit(err)
   296  	}
   297  
   298  	if err := unmountAll(); err != nil {
   299  		exit(err)
   300  	}
   301  
   302  	if err := bootTarget(p); err != nil {
   303  		exit(err)
   304  	}
   305  }