github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/cmds/exp/netbootxyz/network.go (about)

     1  // Copyright 2021 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"context"
     9  	"crypto/tls"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"net"
    15  	"net/http"
    16  	"os"
    17  	"strconv"
    18  	"strings"
    19  	"time"
    20  
    21  	"github.com/mvdan/u-root-coreutils/pkg/dhclient"
    22  	"github.com/vishvananda/netlink"
    23  )
    24  
    25  // Counts number of bytes written
    26  // Conforms to io.Writer interface
    27  type ProgressCounter struct {
    28  	Downloaded    uint64
    29  	Total         uint64
    30  	PreviousRatio int
    31  	Writer        io.Writer
    32  }
    33  
    34  func (counter *ProgressCounter) Write(p []byte) (int, error) {
    35  	n := len(p)
    36  	counter.Downloaded += uint64(n)
    37  	counter.PrintProgress()
    38  	return n, nil
    39  }
    40  
    41  func (counter *ProgressCounter) PrintProgress() {
    42  	ratio := int(float64(counter.Downloaded) / float64(counter.Total) * 100)
    43  
    44  	// Only print every 5% to avoid spamming the serial port and making it look weird
    45  	if ratio%5 == 0 && ratio != counter.PreviousRatio {
    46  		// Clear the line by using a character return to go back to the start and
    47  		// remove the remaining characters by filling it with spaces
    48  		fmt.Fprintf(counter.Writer, "\r%s", strings.Repeat(" ", 50))
    49  
    50  		fmt.Fprintf(counter.Writer, "\rDownloading... %s out of %s (%d%%)", bytesToHuman(counter.Downloaded), bytesToHuman(counter.Total), ratio)
    51  		counter.PreviousRatio = ratio
    52  	}
    53  
    54  	if counter.Downloaded == counter.Total {
    55  		fmt.Fprintf(counter.Writer, "\n")
    56  	}
    57  }
    58  
    59  func bytesToHuman(bytes uint64) string {
    60  	const unit = 1000 // Instead of 1024 so we'll get MB instead of MiB
    61  	if bytes < unit {
    62  		return fmt.Sprintf("%d B", bytes)
    63  	}
    64  
    65  	div := int64(unit)
    66  	exponent := 0
    67  	for n := bytes / unit; n >= unit; n /= unit {
    68  		div *= unit
    69  		exponent++
    70  	}
    71  	return fmt.Sprintf("%4.1f %cB", float64(bytes)/float64(div), "kMGTPE"[exponent])
    72  }
    73  
    74  func downloadFile(filepath string, url string) error {
    75  	http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: false}
    76  
    77  	log.Printf("Downloading file %s from %s\n", filepath, url)
    78  
    79  	headResp, err := http.Head(url)
    80  	if err != nil {
    81  		return err
    82  	}
    83  
    84  	defer headResp.Body.Close()
    85  
    86  	// Get size for progress indicator
    87  	size, err := strconv.ParseUint(headResp.Header.Get("Content-Length"), 10, 64)
    88  	if err != nil {
    89  		return err
    90  	}
    91  
    92  	// Get the data
    93  	resp, err := http.Get(url)
    94  	if err != nil {
    95  		return err
    96  	}
    97  	defer closeIO(resp.Body, &err)
    98  
    99  	// Create the file
   100  	out, err := os.Create(filepath)
   101  	if err != nil {
   102  		return err
   103  	}
   104  	defer closeFile(out, &err)
   105  
   106  	// Write the body to file
   107  	counter := &ProgressCounter{Total: size, PreviousRatio: 0, Writer: os.Stdout}
   108  	_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
   109  
   110  	return err
   111  }
   112  
   113  // Allows error handling on deferred file.Close()
   114  func closeFile(f *os.File, err *error) {
   115  	e := f.Close()
   116  	switch *err {
   117  	case nil:
   118  		*err = e
   119  	default:
   120  		if e != nil {
   121  			log.Println("Error:", e)
   122  		}
   123  	}
   124  }
   125  
   126  // Allows error handling on deferred io.Close()
   127  func closeIO(c io.Closer, err *error) {
   128  	e := c.Close()
   129  	switch *err {
   130  	case nil:
   131  		*err = e
   132  	default:
   133  		if e != nil {
   134  			log.Println("Error:", e)
   135  		}
   136  	}
   137  }
   138  
   139  func configureDHCPNetwork() error {
   140  	if *verbose {
   141  		log.Printf("Trying to configure network configuration dynamically...")
   142  	}
   143  
   144  	link, err := findNetworkInterface(*ifName)
   145  	if err != nil {
   146  		return err
   147  	}
   148  
   149  	var links []netlink.Link
   150  	links = append(links, link)
   151  
   152  	var level dhclient.LogLevel
   153  
   154  	config := dhclient.Config{
   155  		Timeout:  dhcpTimeout,
   156  		Retries:  dhcpTries,
   157  		LogLevel: level,
   158  	}
   159  
   160  	r := dhclient.SendRequests(context.TODO(), links, true, false, config, 20*time.Second)
   161  	for result := range r {
   162  		if result.Err == nil {
   163  			return result.Lease.Configure()
   164  		}
   165  		log.Printf("dhcp response error: %v", result.Err)
   166  	}
   167  	return errors.New("no valid DHCP configuration recieved")
   168  }
   169  
   170  func findNetworkInterface(ifName string) (netlink.Link, error) {
   171  	if ifName != "" {
   172  		if *verbose {
   173  			log.Printf("Try using %s", ifName)
   174  		}
   175  		link, err := netlink.LinkByName(ifName)
   176  		if err == nil {
   177  			return link, nil
   178  		}
   179  		log.Print(err)
   180  	}
   181  
   182  	ifaces, err := net.Interfaces()
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  
   187  	if len(ifaces) == 0 {
   188  		return nil, errors.New("no network interface found")
   189  	}
   190  
   191  	var ifnames []string
   192  	for _, iface := range ifaces {
   193  		ifnames = append(ifnames, iface.Name)
   194  		// skip loopback
   195  		if iface.Flags&net.FlagLoopback != 0 || iface.HardwareAddr.String() == "" {
   196  			continue
   197  		}
   198  		if *verbose {
   199  			log.Printf("Try using %s", iface.Name)
   200  		}
   201  		link, err := netlink.LinkByName(iface.Name)
   202  		if err == nil {
   203  			return link, nil
   204  		}
   205  		log.Print(err)
   206  	}
   207  
   208  	return nil, fmt.Errorf("could not find a non-loopback network interface with hardware address in any of %v", ifnames)
   209  }