github.com/system-transparency/u-root@v6.0.1-0.20190919065413-ed07a650de4c+incompatible/cmds/boot/fbnetboot/main.go (about)

     1  // Copyright 2017-2019 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"errors"
    11  	"flag"
    12  	"fmt"
    13  	"io/ioutil"
    14  	"log"
    15  	"net"
    16  	"net/http"
    17  	"net/url"
    18  	"os"
    19  	"os/exec"
    20  	"path/filepath"
    21  	"strings"
    22  	"time"
    23  
    24  	"github.com/insomniacslk/dhcp/dhcpv4"
    25  	"github.com/insomniacslk/dhcp/dhcpv6"
    26  	"github.com/insomniacslk/dhcp/iana"
    27  	"github.com/insomniacslk/dhcp/interfaces"
    28  	"github.com/insomniacslk/dhcp/netboot"
    29  	"github.com/u-root/u-root/pkg/crypto"
    30  	"github.com/u-root/u-root/pkg/kexec"
    31  )
    32  
    33  var (
    34  	useV4              = flag.Bool("4", false, "Get a DHCPv4 lease")
    35  	useV6              = flag.Bool("6", true, "Get a DHCPv6 lease")
    36  	ifname             = flag.String("i", "", "Interface to send packets through")
    37  	dryRun             = flag.Bool("dryrun", false, "Do everything except assigning IP addresses, changing DNS, and kexec")
    38  	doDebug            = flag.Bool("d", false, "Print debug output")
    39  	skipDHCP           = flag.Bool("skip-dhcp", false, "Skip DHCP and rely on SLAAC for network configuration. This requires -netboot-url")
    40  	overrideNetbootURL = flag.String("netboot-url", "", "Override the netboot URL normally obtained via DHCP")
    41  	readTimeout        = flag.Int("timeout", 3, "Read timeout in seconds")
    42  	dhcpRetries        = flag.Int("retries", 3, "Number of times a DHCP request is retried")
    43  	userClass          = flag.String("userclass", "", "Override DHCP User Class option")
    44  	caCertFile         = flag.String("cacerts", "/etc/cacerts.pem", "CA cert file")
    45  	skipCertVerify     = flag.Bool("skip-cert-verify", false, "Don't authenticate https certs")
    46  	doFix              = flag.Bool("fix", false, "Try to run fixmynetboot if netboot fails")
    47  )
    48  
    49  const (
    50  	interfaceUpTimeout = 10 * time.Second
    51  	maxHTTPAttempts    = 3
    52  	retryInterval      = time.Second
    53  )
    54  
// banner is printed by main at startup.
var banner = `

 _________________________________
< Net booting is so hot right now >
 ---------------------------------
        \   ^__^
         \  (oo)\_______
            (__)\       )\/\
                ||----w |
                ||     ||

`

// debug is a no-op by default; main replaces it with log.Printf when -d is
// given, so debug calls cost nothing unless debugging is enabled.
var debug = func(string, ...interface{}) {}
    68  
    69  func main() {
    70  	flag.Parse()
    71  	if *skipDHCP && *overrideNetbootURL == "" {
    72  		log.Fatal("-skip-dhcp requires -netboot-url")
    73  	}
    74  	if *doDebug {
    75  		debug = log.Printf
    76  	}
    77  	log.Print(banner)
    78  
    79  	if !*useV6 && !*useV4 {
    80  		log.Fatal("At least one of DHCPv6 and DHCPv4 is required")
    81  	}
    82  
    83  	iflist := []net.Interface{}
    84  	if *ifname != "" {
    85  		var iface *net.Interface
    86  		var err error
    87  		if iface, err = net.InterfaceByName(*ifname); err != nil {
    88  			log.Fatalf("Could not find interface %s: %v", *ifname, err)
    89  		}
    90  		iflist = append(iflist, *iface)
    91  	} else {
    92  		var err error
    93  		if iflist, err = interfaces.GetNonLoopbackInterfaces(); err != nil {
    94  			log.Fatalf("Could not obtain the list of network interfaces: %v", err)
    95  		}
    96  	}
    97  
    98  	for _, iface := range iflist {
    99  		log.Printf("Waiting for network interface %s to come up", iface.Name)
   100  		start := time.Now()
   101  		_, err := netboot.IfUp(iface.Name, interfaceUpTimeout)
   102  		if err != nil {
   103  			log.Printf("IfUp failed: %v", err)
   104  			continue
   105  		}
   106  		debug("Interface %s is up after %v", iface.Name, time.Since(start))
   107  
   108  		var dhcp []dhcpFunc
   109  		if *useV6 {
   110  			dhcp = append(dhcp, dhcp6)
   111  		}
   112  		if *useV4 {
   113  			dhcp = append(dhcp, dhcp4)
   114  		}
   115  		for _, d := range dhcp {
   116  			if err := boot(iface.Name, d); err != nil {
   117  				if *doFix {
   118  					cmd := exec.Command("fixmynetboot", iface.Name)
   119  					log.Printf("Running %s", strings.Join(cmd.Args, " "))
   120  					cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
   121  					if err := cmd.Run(); err != nil {
   122  						log.Printf("Error calling fixmynetboot: %v", err)
   123  						log.Print("fixmynetboot failed. Check the above output to manually debug the issue.")
   124  						os.Exit(1)
   125  					}
   126  				}
   127  				log.Printf("Could not boot from %s: %v", iface.Name, err)
   128  			}
   129  		}
   130  	}
   131  
   132  	log.Fatalln("Could not boot from any interfaces")
   133  }
   134  
   135  func retryableNetError(err error) bool {
   136  	if err == nil {
   137  		return false
   138  	}
   139  	switch err := err.(type) {
   140  	case net.Error:
   141  		if err.Timeout() {
   142  			return true
   143  		}
   144  	}
   145  	return false
   146  }
   147  
   148  func retryableHTTPError(resp *http.Response) bool {
   149  	if resp == nil {
   150  		return false
   151  	}
   152  	if resp.StatusCode == 500 || resp.StatusCode == 502 {
   153  		return true
   154  	}
   155  	return false
   156  }
   157  
   158  func boot(ifname string, dhcp dhcpFunc) error {
   159  	var (
   160  		netconf  *netboot.NetConf
   161  		bootfile string
   162  		err      error
   163  	)
   164  	if *skipDHCP {
   165  		log.Print("Skipping DHCP")
   166  	} else {
   167  		// send a netboot request via DHCP
   168  		netconf, bootfile, err = dhcp(ifname)
   169  		if err != nil {
   170  			return fmt.Errorf("DHCPv6: netboot request for interface %s failed: %v", ifname, err)
   171  		}
   172  		debug("DHCP: network configuration: %+v", netconf)
   173  		if !*dryRun {
   174  			log.Printf("DHCP: configuring network interface %s with %v", ifname, netconf)
   175  			if err = netboot.ConfigureInterface(ifname, netconf); err != nil {
   176  				return fmt.Errorf("DHCP: cannot configure interface %s: %v", ifname, err)
   177  			}
   178  		}
   179  		if *overrideNetbootURL != "" {
   180  			bootfile = *overrideNetbootURL
   181  		}
   182  		log.Printf("DHCP: boot file for interface %s is %s", ifname, bootfile)
   183  	}
   184  	if *overrideNetbootURL != "" {
   185  		bootfile = *overrideNetbootURL
   186  	}
   187  	debug("DHCP: boot file URL is %s", bootfile)
   188  	// check for supported schemes
   189  	scheme, err := getScheme(bootfile)
   190  	if err != nil {
   191  		return fmt.Errorf("DHCP: cannot get scheme from URL: %v", err)
   192  	}
   193  	if scheme == "" {
   194  		return errors.New("DHCP: no valid scheme found in URL")
   195  	}
   196  
   197  	client, err := getClientForBootfile(bootfile)
   198  	if err != nil {
   199  		return fmt.Errorf("DHCP: cannot get client for %s: %v", bootfile, err)
   200  	}
   201  	log.Printf("DHCP: fetching boot file URL: %s", bootfile)
   202  
   203  	var resp *http.Response
   204  	for attempt := 0; attempt < maxHTTPAttempts; attempt++ {
   205  		log.Printf("netboot: attempt %d for http.Get", attempt+1)
   206  		req, err := http.NewRequest(http.MethodGet, bootfile, nil)
   207  		if err != nil {
   208  			return fmt.Errorf("could not build request for %s: %v", bootfile, err)
   209  		}
   210  		resp, err = client.Do(req)
   211  		if err != nil && retryableNetError(err) || retryableHTTPError(resp) {
   212  			time.Sleep(retryInterval)
   213  			continue
   214  		}
   215  		if err == nil {
   216  			break
   217  		}
   218  		return fmt.Errorf("DHCP: http.Get of %s failed: %v", bootfile, err)
   219  	}
   220  	// FIXME this will not be called if something fails after this point
   221  	defer resp.Body.Close()
   222  	if resp.StatusCode != 200 {
   223  		return fmt.Errorf("status code is not 200 OK: %d", resp.StatusCode)
   224  	}
   225  	body, err := ioutil.ReadAll(resp.Body)
   226  	if err != nil {
   227  		return fmt.Errorf("DHCP: cannot read boot file from the network: %v", err)
   228  	}
   229  	crypto.TryMeasureData(crypto.BootConfigPCR, body, bootfile)
   230  	u, err := url.Parse(bootfile)
   231  	if err != nil {
   232  		return fmt.Errorf("DHCP: cannot parse URL %s: %v", bootfile, err)
   233  	}
   234  	// extract file name component
   235  	if strings.HasSuffix(u.Path, "/") {
   236  		return fmt.Errorf("invalid file path, cannot end with '/': %s", u.Path)
   237  	}
   238  	filename := filepath.Base(u.Path)
   239  	if filename == "." || filename == "" {
   240  		return fmt.Errorf("invalid empty file name extracted from file path %s", u.Path)
   241  	}
   242  	if err = ioutil.WriteFile(filename, body, 0400); err != nil {
   243  		return fmt.Errorf("DHCP: cannot write to file %s: %v", filename, err)
   244  	}
   245  	debug("DHCP: saved boot file to %s", filename)
   246  	if !*dryRun {
   247  		log.Printf("DHCP: kexec'ing into %s", filename)
   248  		kernel, err := os.OpenFile(filename, os.O_RDONLY, 0)
   249  		if err != nil {
   250  			return fmt.Errorf("DHCP: cannot open file %s: %v", filename, err)
   251  		}
   252  		if err = kexec.FileLoad(kernel, nil /* ramfs */, "" /* cmdline */); err != nil {
   253  			return fmt.Errorf("DHCP: kexec.FileLoad failed: %v", err)
   254  		}
   255  		if err = kexec.Reboot(); err != nil {
   256  			return fmt.Errorf("DHCP: kexec.Reboot failed: %v", err)
   257  		}
   258  	}
   259  	return nil
   260  }
   261  
   262  func getScheme(urlstring string) (string, error) {
   263  	u, err := url.Parse(urlstring)
   264  	if err != nil {
   265  		return "", err
   266  	}
   267  	scheme := strings.ToLower(u.Scheme)
   268  	if scheme != "http" && scheme != "https" {
   269  		return "", fmt.Errorf("URL scheme '%s' must be http or https", scheme)
   270  	}
   271  	return scheme, nil
   272  }
   273  
   274  func loadCaCerts() (*x509.CertPool, error) {
   275  	rootCAs, err := x509.SystemCertPool()
   276  	if err != nil {
   277  		return nil, err
   278  	}
   279  	if rootCAs == nil {
   280  		debug("certs: rootCAs == nil")
   281  		rootCAs = x509.NewCertPool()
   282  	}
   283  	caCerts, err := ioutil.ReadFile(*caCertFile)
   284  	if err != nil {
   285  		return nil, fmt.Errorf("could not find cert file '%v' - %v", *caCertFile, err)
   286  	}
   287  	// TODO: Decide if this should also support compressed certs
   288  	// Might be better to have a generic compressed config API
   289  	if ok := rootCAs.AppendCertsFromPEM(caCerts); !ok {
   290  		debug("Failed to append CA Certs from %s, using system certs only", *caCertFile)
   291  	} else {
   292  		debug("CA certs appended from PEM")
   293  	}
   294  	return rootCAs, nil
   295  
   296  }
   297  
   298  func getClientForBootfile(bootfile string) (*http.Client, error) {
   299  	var client *http.Client
   300  	scheme, err := getScheme(bootfile)
   301  	if err != nil {
   302  		return nil, err
   303  	}
   304  
   305  	switch scheme {
   306  	case "https":
   307  		var config *tls.Config
   308  		if *skipCertVerify {
   309  			config = &tls.Config{
   310  				InsecureSkipVerify: true,
   311  			}
   312  		} else if *caCertFile != "" {
   313  			rootCAs, err := loadCaCerts()
   314  			if err != nil {
   315  				return nil, err
   316  			}
   317  			config = &tls.Config{
   318  				RootCAs: rootCAs,
   319  			}
   320  		}
   321  		tr := &http.Transport{TLSClientConfig: config}
   322  		client = &http.Client{Transport: tr}
   323  		debug("https client setup (use certs from VPD: %t, skipCertVerify %t)",
   324  			*skipCertVerify, *caCertFile != "")
   325  	case "http":
   326  		client = &http.Client{}
   327  		debug("http client setup")
   328  	default:
   329  		return nil, fmt.Errorf("Scheme %s is unsupported", scheme)
   330  	}
   331  	return client, nil
   332  }
   333  
   334  type dhcpFunc func(string) (*netboot.NetConf, string, error)
   335  
   336  func dhcp6(ifname string) (*netboot.NetConf, string, error) {
   337  	log.Printf("Trying to obtain a DHCPv6 lease on %s", ifname)
   338  	modifiers := []dhcpv6.Modifier{
   339  		dhcpv6.WithArchType(iana.EFI_X86_64),
   340  	}
   341  	if *userClass != "" {
   342  		modifiers = append(modifiers, dhcpv6.WithUserClass([]byte(*userClass)))
   343  	}
   344  	conversation, err := netboot.RequestNetbootv6(ifname, time.Duration(*readTimeout)*time.Second, *dhcpRetries, modifiers...)
   345  	for _, m := range conversation {
   346  		debug(m.Summary())
   347  	}
   348  	if err != nil {
   349  		return nil, "", fmt.Errorf("DHCPv6: netboot request for interface %s failed: %v", ifname, err)
   350  	}
   351  	return netboot.ConversationToNetconf(conversation)
   352  }
   353  
   354  func dhcp4(ifname string) (*netboot.NetConf, string, error) {
   355  	log.Printf("Trying to obtain a DHCPv4 lease on %s", ifname)
   356  	var modifiers []dhcpv4.Modifier
   357  	if *userClass != "" {
   358  		modifiers = append(modifiers, dhcpv4.WithUserClass(*userClass, false))
   359  	}
   360  	conversation, err := netboot.RequestNetbootv4(ifname, time.Duration(*readTimeout)*time.Second, *dhcpRetries, modifiers...)
   361  	for _, m := range conversation {
   362  		debug(m.Summary())
   363  	}
   364  	if err != nil {
   365  		return nil, "", fmt.Errorf("DHCPv4: netboot request for interface %s failed: %v", ifname, err)
   366  	}
   367  	return netboot.ConversationToNetconfv4(conversation)
   368  }