github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/cmds/boot/fbnetboot/main.go (about)

     1  // Copyright 2017-2019 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"errors"
    11  	"flag"
    12  	"fmt"
    13  	"io"
    14  	"log"
    15  	"net"
    16  	"net/http"
    17  	"net/url"
    18  	"os"
    19  	"os/exec"
    20  	"path/filepath"
    21  	"strings"
    22  	"time"
    23  
    24  	"github.com/insomniacslk/dhcp/dhcpv4"
    25  	"github.com/insomniacslk/dhcp/dhcpv6"
    26  	"github.com/insomniacslk/dhcp/iana"
    27  	"github.com/insomniacslk/dhcp/interfaces"
    28  	"github.com/insomniacslk/dhcp/netboot"
    29  	"github.com/mvdan/u-root-coreutils/pkg/boot/kexec"
    30  	"github.com/mvdan/u-root-coreutils/pkg/crypto"
    31  	"github.com/mvdan/u-root-coreutils/pkg/ntpdate"
    32  )
    33  
// Command-line flags. Each flag's usage string describes it; the notes below
// only record how the rest of this file consumes them.
//
// -6 defaults to true and -4 to false, so only DHCPv6 is attempted unless
// -4 is given. -skip-dhcp is rejected by main unless -netboot-url is also
// set. -timeout is in seconds (it is multiplied by time.Second before being
// passed to the DHCP requests).
var (
	useV4              = flag.Bool("4", false, "Get a DHCPv4 lease")
	useV6              = flag.Bool("6", true, "Get a DHCPv6 lease")
	ifname             = flag.String("i", "", "Interface to send packets through")
	dryRun             = flag.Bool("dryrun", false, "Do everything except assigning IP addresses, changing DNS, and kexec")
	doDebug            = flag.Bool("d", false, "Print debug output")
	skipDHCP           = flag.Bool("skip-dhcp", false, "Skip DHCP and rely on SLAAC for network configuration. This requires -netboot-url")
	overrideNetbootURL = flag.String("netboot-url", "", "Override the netboot URL normally obtained via DHCP")
	overrideCmdline    = flag.String("cmdline", "", "Override the extra kernel command line normally obtained via DHCP")
	readTimeout        = flag.Int("timeout", 3, "Read timeout in seconds")
	dhcpRetries        = flag.Int("retries", 3, "Number of times a DHCP request is retried")
	userClass          = flag.String("userclass", "", "Override DHCP User Class option")
	caCertFile         = flag.String("cacerts", "/etc/cacerts.pem", "CA cert file")
	ntpEnable          = flag.Bool("ntp", true, "Set system time via NTP")
	ntpConfig          = flag.String("ntp-config", ntpdate.DefaultNTPConfig, "NTP config to use when NTP is enabled")
	ntpServers         = flag.String("ntp-servers", ntpServerDHCP, fmt.Sprintf("Comma-separated list of NTP servers to query for time. %q expands to list of NTP servers received in the DHCP lease, if any.", ntpServerDHCP))
	skipCertVerify     = flag.Bool("skip-cert-verify", false, "Don't authenticate https certs")
	doFix              = flag.Bool("fix", false, "Try to run fixmynetboot if netboot fails")
)

const (
	// ntpServerDHCP is the placeholder entry in -ntp-servers that boot
	// expands to the NTP servers from the DHCP lease; when it appears in the
	// flag value, dhcp6/dhcp4 also request the NTP server option.
	ntpServerDHCP = "DHCP"
	// interfaceUpTimeout is passed to netboot.IfUp when main brings each
	// interface up.
	interfaceUpTimeout = 10 * time.Second
	// maxHTTPAttempts and retryInterval drive the retry loop boot uses when
	// downloading the boot file.
	maxHTTPAttempts = 3
	retryInterval   = time.Second
)

// banner is logged once by main at startup.
var banner = `

 _________________________________
< Net booting is so hot right now >
 ---------------------------------
        \   ^__^
         \  (oo)\_______
            (__)\       )\/\
                ||----w |
                ||     ||

`

// debug is a no-op by default; main replaces it with log.Printf when -d is
// given.
var debug = func(string, ...interface{}) {}
    74  
    75  func main() {
    76  	flag.Parse()
    77  	if *skipDHCP && *overrideNetbootURL == "" {
    78  		log.Fatal("-skip-dhcp requires -netboot-url")
    79  	}
    80  	if *doDebug {
    81  		debug = log.Printf
    82  	}
    83  	log.Print(banner)
    84  
    85  	if !*useV6 && !*useV4 {
    86  		log.Fatal("At least one of DHCPv6 and DHCPv4 is required")
    87  	}
    88  
    89  	iflist := []net.Interface{}
    90  	if *ifname != "" {
    91  		var iface *net.Interface
    92  		var err error
    93  		if iface, err = net.InterfaceByName(*ifname); err != nil {
    94  			log.Fatalf("Could not find interface %s: %v", *ifname, err)
    95  		}
    96  		iflist = append(iflist, *iface)
    97  	} else {
    98  		var err error
    99  		if iflist, err = interfaces.GetNonLoopbackInterfaces(); err != nil {
   100  			log.Fatalf("Could not obtain the list of network interfaces: %v", err)
   101  		}
   102  	}
   103  
   104  	for _, iface := range iflist {
   105  		log.Printf("Waiting for network interface %s to come up", iface.Name)
   106  		start := time.Now()
   107  		_, err := netboot.IfUp(iface.Name, interfaceUpTimeout)
   108  		if err != nil {
   109  			log.Printf("IfUp failed: %v", err)
   110  			continue
   111  		}
   112  		debug("Interface %s is up after %v", iface.Name, time.Since(start))
   113  
   114  		var dhcp []dhcpFunc
   115  		if *useV6 {
   116  			dhcp = append(dhcp, dhcp6)
   117  		}
   118  		if *useV4 {
   119  			dhcp = append(dhcp, dhcp4)
   120  		}
   121  		for _, d := range dhcp {
   122  			if err := boot(iface.Name, d); err != nil {
   123  				if *doFix {
   124  					cmd := exec.Command("fixmynetboot", iface.Name)
   125  					log.Printf("Running %s", strings.Join(cmd.Args, " "))
   126  					cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
   127  					if err := cmd.Run(); err != nil {
   128  						log.Printf("Error calling fixmynetboot: %v", err)
   129  						log.Print("fixmynetboot failed. Check the above output to manually debug the issue.")
   130  						os.Exit(1)
   131  					}
   132  				}
   133  				log.Printf("Could not boot from %s: %v", iface.Name, err)
   134  			}
   135  		}
   136  	}
   137  
   138  	log.Fatalln("Could not boot from any interfaces")
   139  }
   140  
   141  func retryableNetError(err error) bool {
   142  	if err == nil {
   143  		return false
   144  	}
   145  	var netError net.Error
   146  	if errors.As(err, &netError) && netError.Timeout() {
   147  		return true
   148  	}
   149  	return false
   150  }
   151  
   152  func retryableHTTPError(resp *http.Response) bool {
   153  	if resp == nil {
   154  		return false
   155  	}
   156  	if resp.StatusCode == 500 || resp.StatusCode == 502 {
   157  		return true
   158  	}
   159  	return false
   160  }
   161  
   162  func boot(ifname string, dhcp dhcpFunc) error {
   163  	var (
   164  		bootconf *netboot.BootConf
   165  		err      error
   166  	)
   167  	if *skipDHCP {
   168  		log.Print("Skipping DHCP")
   169  		bootconf = &netboot.BootConf{}
   170  	} else {
   171  		// send a netboot request via DHCP
   172  		bootconf, err = dhcp(ifname)
   173  		if err != nil {
   174  			return fmt.Errorf("DHCP: netboot request for interface %s failed: %v", ifname, err)
   175  		}
   176  		debug("DHCP: network configuration: %+v", bootconf.NetConf)
   177  		if !*dryRun {
   178  			log.Printf("DHCP: configuring network interface %s with %v", ifname, bootconf.NetConf)
   179  			if err = netboot.ConfigureInterface(ifname, &bootconf.NetConf); err != nil {
   180  				return fmt.Errorf("DHCP: cannot configure interface %s: %v", ifname, err)
   181  			}
   182  		}
   183  		if *overrideNetbootURL != "" {
   184  			bootconf.BootfileURL = *overrideNetbootURL
   185  		}
   186  		log.Printf("DHCP: boot file for interface %s is %s", ifname, bootconf.BootfileURL)
   187  	}
   188  	if *overrideNetbootURL != "" {
   189  		bootconf.BootfileURL = *overrideNetbootURL
   190  	}
   191  	if *overrideCmdline != "" {
   192  		bootconf.BootfileParam = []string{*overrideCmdline}
   193  	}
   194  	debug("DHCP: boot file URL is %s", bootconf.BootfileURL)
   195  	// check for supported schemes
   196  	scheme, err := getScheme(bootconf.BootfileURL)
   197  	if err != nil {
   198  		return fmt.Errorf("DHCP: cannot get scheme from URL: %v", err)
   199  	}
   200  	if scheme == "" {
   201  		return errors.New("DHCP: no valid scheme found in URL")
   202  	}
   203  
   204  	if *ntpEnable {
   205  		var servers []string
   206  		for _, s := range strings.Split(*ntpServers, ",") {
   207  			if len(s) == 0 {
   208  				continue
   209  			}
   210  			if s == ntpServerDHCP {
   211  				for _, ip := range bootconf.NTPServers {
   212  					servers = append(servers, ip.String())
   213  				}
   214  			} else {
   215  				servers = append(servers, s)
   216  			}
   217  		}
   218  		log.Printf("NTP: Servers: %v, config: %s", servers, *ntpConfig)
   219  		if server, offset, err := ntpdate.SetTime(servers, *ntpConfig, "" /* fallback */, false /* setRTC */); err == nil {
   220  			plus := ""
   221  			if offset > 0 {
   222  				plus = "+"
   223  			}
   224  			log.Printf("NTP: adjust time server %s offset %s%f sec", server, plus, offset)
   225  		} else {
   226  			log.Printf("NTP: error setting time: %v", err)
   227  		}
   228  	}
   229  
   230  	client, err := getClientForBootfile(bootconf.BootfileURL)
   231  	if err != nil {
   232  		return fmt.Errorf("DHCP: cannot get client for %s: %v", bootconf.BootfileURL, err)
   233  	}
   234  	log.Printf("DHCP: fetching boot file URL: %s", bootconf.BootfileURL)
   235  
   236  	fetch := func(url string) (*http.Response, error) {
   237  		for attempt := 0; attempt < maxHTTPAttempts; attempt++ {
   238  			if attempt > 1 {
   239  				time.Sleep(retryInterval)
   240  			}
   241  			log.Printf("netboot: attempt %d for http.Get", attempt+1)
   242  			req, err := http.NewRequest(http.MethodGet, url, nil)
   243  			if err != nil {
   244  				return nil, fmt.Errorf("could not build request for %q: %v", url, err)
   245  			}
   246  			resp, err := client.Do(req)
   247  			if err == nil {
   248  				return resp, nil
   249  			}
   250  			log.Printf("attempt failed: %v", err)
   251  			if !retryableNetError(err) && !retryableHTTPError(resp) {
   252  				break
   253  			}
   254  		}
   255  		return nil, fmt.Errorf("fetch of %q failed", url)
   256  	}
   257  	resp, err := fetch(bootconf.BootfileURL)
   258  	if err != nil {
   259  		return fmt.Errorf("failed to fetch %q: %v", bootconf.BootfileURL, err)
   260  	}
   261  	defer resp.Body.Close()
   262  	if resp.StatusCode != 200 {
   263  		return fmt.Errorf("status code is not 200 OK: %d", resp.StatusCode)
   264  	}
   265  	body, err := io.ReadAll(resp.Body)
   266  	if err != nil {
   267  		return fmt.Errorf("DHCP: cannot read boot file from the network: %v", err)
   268  	}
   269  	crypto.TryMeasureData(crypto.BootConfigPCR, body, bootconf.BootfileURL)
   270  	u, err := url.Parse(bootconf.BootfileURL)
   271  	if err != nil {
   272  		return fmt.Errorf("DHCP: cannot parse URL %s: %v", bootconf.BootfileURL, err)
   273  	}
   274  	// extract file name component
   275  	if strings.HasSuffix(u.Path, "/") {
   276  		return fmt.Errorf("invalid file path, cannot end with '/': %s", u.Path)
   277  	}
   278  	filename := filepath.Base(u.Path)
   279  	if filename == "." || filename == "" {
   280  		return fmt.Errorf("invalid empty file name extracted from file path %s", u.Path)
   281  	}
   282  	if err = os.WriteFile(filename, body, 0o400); err != nil {
   283  		return fmt.Errorf("DHCP: cannot write to file %s: %v", filename, err)
   284  	}
   285  	debug("DHCP: saved boot file to %s", filename)
   286  
   287  	cmdline := strings.Join(bootconf.BootfileParam, " ")
   288  	if !*dryRun {
   289  		log.Printf("DHCP: kexec'ing into %s (with arguments: \"%s\")", filename, cmdline)
   290  		kernel, err := os.OpenFile(filename, os.O_RDONLY, 0)
   291  		if err != nil {
   292  			return fmt.Errorf("DHCP: cannot open file %s: %v", filename, err)
   293  		}
   294  		if err = kexec.FileLoad(kernel, nil /* ramfs */, cmdline); err != nil {
   295  			return fmt.Errorf("DHCP: kexec.FileLoad failed: %v", err)
   296  		}
   297  		if err = kexec.Reboot(); err != nil {
   298  			return fmt.Errorf("DHCP: kexec.Reboot failed: %v", err)
   299  		}
   300  	} else {
   301  		log.Printf("DHCP: I would've kexec into %s (with arguments: \"%s\") now unless the dry mode", filename, cmdline)
   302  	}
   303  	return nil
   304  }
   305  
   306  func getScheme(urlstring string) (string, error) {
   307  	u, err := url.Parse(urlstring)
   308  	if err != nil {
   309  		return "", err
   310  	}
   311  	scheme := strings.ToLower(u.Scheme)
   312  	if scheme != "http" && scheme != "https" {
   313  		return "", fmt.Errorf("URL scheme '%s' must be http or https", scheme)
   314  	}
   315  	return scheme, nil
   316  }
   317  
   318  func loadCaCerts() (*x509.CertPool, error) {
   319  	rootCAs, err := x509.SystemCertPool()
   320  	if err != nil {
   321  		return nil, err
   322  	}
   323  	if rootCAs == nil {
   324  		debug("certs: rootCAs == nil")
   325  		rootCAs = x509.NewCertPool()
   326  	}
   327  	caCerts, err := os.ReadFile(*caCertFile)
   328  	if err != nil {
   329  		return nil, fmt.Errorf("could not find cert file '%v' - %v", *caCertFile, err)
   330  	}
   331  	// TODO: Decide if this should also support compressed certs
   332  	// Might be better to have a generic compressed config API
   333  	if ok := rootCAs.AppendCertsFromPEM(caCerts); !ok {
   334  		debug("Failed to append CA Certs from %s, using system certs only", *caCertFile)
   335  	} else {
   336  		debug("CA certs appended from PEM")
   337  	}
   338  	return rootCAs, nil
   339  }
   340  
   341  func getClientForBootfile(bootfile string) (*http.Client, error) {
   342  	var client *http.Client
   343  	scheme, err := getScheme(bootfile)
   344  	if err != nil {
   345  		return nil, err
   346  	}
   347  
   348  	switch scheme {
   349  	case "https":
   350  		var config *tls.Config
   351  		if *skipCertVerify {
   352  			config = &tls.Config{
   353  				InsecureSkipVerify: true,
   354  			}
   355  		} else if *caCertFile != "" {
   356  			rootCAs, err := loadCaCerts()
   357  			if err != nil {
   358  				return nil, err
   359  			}
   360  			config = &tls.Config{
   361  				RootCAs: rootCAs,
   362  			}
   363  		}
   364  		tr := &http.Transport{TLSClientConfig: config}
   365  		client = &http.Client{Transport: tr}
   366  		debug("https client setup (use certs from VPD: %t, skipCertVerify %t)",
   367  			*caCertFile != "", *skipCertVerify)
   368  	case "http":
   369  		client = &http.Client{}
   370  		debug("http client setup")
   371  	default:
   372  		return nil, fmt.Errorf("Scheme %s is unsupported", scheme)
   373  	}
   374  	return client, nil
   375  }
   376  
// dhcpFunc performs a DHCP exchange on the named interface and returns the
// resulting boot configuration. dhcp6 and dhcp4 are the implementations
// main chooses between, according to -6 and -4.
type dhcpFunc func(string) (bootconf *netboot.BootConf, err error)
   378  
   379  func dhcp6(ifname string) (*netboot.BootConf, error) {
   380  	log.Printf("Trying to obtain a DHCPv6 lease on %s", ifname)
   381  	modifiers := []dhcpv6.Modifier{
   382  		dhcpv6.WithArchType(iana.EFI_X86_64),
   383  	}
   384  	if *userClass != "" {
   385  		modifiers = append(modifiers, dhcpv6.WithUserClass([]byte(*userClass)))
   386  	}
   387  	if *ntpEnable && strings.Contains(*ntpServers, ntpServerDHCP) {
   388  		modifiers = append(modifiers, dhcpv6.WithRequestedOptions(dhcpv6.OptionNTPServer))
   389  	}
   390  	conversation, err := netboot.RequestNetbootv6(ifname, time.Duration(*readTimeout)*time.Second, *dhcpRetries, modifiers...)
   391  	if err != nil {
   392  		return nil, fmt.Errorf("DHCPv6: netboot request for interface %s failed: %v", ifname, err)
   393  	}
   394  	for _, m := range conversation {
   395  		debug(m.Summary())
   396  	}
   397  	return netboot.ConversationToNetconf(conversation)
   398  }
   399  
   400  func dhcp4(ifname string) (*netboot.BootConf, error) {
   401  	log.Printf("Trying to obtain a DHCPv4 lease on %s", ifname)
   402  	var modifiers []dhcpv4.Modifier
   403  	if *userClass != "" {
   404  		modifiers = append(modifiers, dhcpv4.WithUserClass(*userClass, false))
   405  	}
   406  	if *ntpEnable && strings.Contains(*ntpServers, ntpServerDHCP) {
   407  		modifiers = append(modifiers, dhcpv4.WithRequestedOptions(dhcpv4.OptionNTPServers))
   408  	}
   409  	conversation, err := netboot.RequestNetbootv4(ifname, time.Duration(*readTimeout)*time.Second, *dhcpRetries, modifiers...)
   410  	if err != nil {
   411  		return nil, fmt.Errorf("DHCPv4: netboot request for interface %s failed: %v", ifname, err)
   412  	}
   413  	for _, m := range conversation {
   414  		debug(m.Summary())
   415  	}
   416  	return netboot.ConversationToNetconfv4(conversation)
   417  }