github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/cmds/exp/netbootxyz/network.go (about) 1 // Copyright 2021 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package main 6 7 import ( 8 "context" 9 "crypto/tls" 10 "errors" 11 "fmt" 12 "io" 13 "log" 14 "net" 15 "net/http" 16 "os" 17 "strconv" 18 "strings" 19 "time" 20 21 "github.com/mvdan/u-root-coreutils/pkg/dhclient" 22 "github.com/vishvananda/netlink" 23 ) 24 25 // Counts number of bytes written 26 // Conforms to io.Writer interface 27 type ProgressCounter struct { 28 Downloaded uint64 29 Total uint64 30 PreviousRatio int 31 Writer io.Writer 32 } 33 34 func (counter *ProgressCounter) Write(p []byte) (int, error) { 35 n := len(p) 36 counter.Downloaded += uint64(n) 37 counter.PrintProgress() 38 return n, nil 39 } 40 41 func (counter *ProgressCounter) PrintProgress() { 42 ratio := int(float64(counter.Downloaded) / float64(counter.Total) * 100) 43 44 // Only print every 5% to avoid spamming the serial port and making it look weird 45 if ratio%5 == 0 && ratio != counter.PreviousRatio { 46 // Clear the line by using a character return to go back to the start and 47 // remove the remaining characters by filling it with spaces 48 fmt.Fprintf(counter.Writer, "\r%s", strings.Repeat(" ", 50)) 49 50 fmt.Fprintf(counter.Writer, "\rDownloading... %s out of %s (%d%%)", bytesToHuman(counter.Downloaded), bytesToHuman(counter.Total), ratio) 51 counter.PreviousRatio = ratio 52 } 53 54 if counter.Downloaded == counter.Total { 55 fmt.Fprintf(counter.Writer, "\n") 56 } 57 } 58 59 func bytesToHuman(bytes uint64) string { 60 const unit = 1000 // Instead of 1024 so we'll get MB instead of MiB 61 if bytes < unit { 62 return fmt.Sprintf("%d B", bytes) 63 } 64 65 div := int64(unit) 66 exponent := 0 67 for n := bytes / unit; n >= unit; n /= unit { 68 div *= unit 69 exponent++ 70 } 71 return fmt.Sprintf("%4.1f %cB", float64(bytes)/float64(div), "kMGTPE"[exponent]) 72 } 73 74 func downloadFile(filepath string, url string) error { 75 http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: false} 76 77 log.Printf("Downloading file %s from %s\n", filepath, url) 78 79 headResp, err := http.Head(url) 80 if err != nil { 81 return err 82 } 83 84 defer headResp.Body.Close() 85 86 // Get size for progress indicator 87 size, err := strconv.ParseUint(headResp.Header.Get("Content-Length"), 10, 64) 88 if err != nil { 89 return err 90 } 91 92 // Get the data 93 resp, err := http.Get(url) 94 if err != nil { 95 return err 96 } 97 defer closeIO(resp.Body, &err) 98 99 // Create the file 100 out, err := os.Create(filepath) 101 if err != nil { 102 return err 103 } 104 defer closeFile(out, &err) 105 106 // Write the body to file 107 counter := &ProgressCounter{Total: size, PreviousRatio: 0, Writer: os.Stdout} 108 _, err = io.Copy(out, io.TeeReader(resp.Body, counter)) 109 110 return err 111 } 112 113 // Allows error handling on deferred file.Close() 114 func closeFile(f *os.File, err *error) { 115 e := f.Close() 116 switch *err { 117 case nil: 118 *err = e 119 default: 120 if e != nil { 121 log.Println("Error:", e) 122 } 123 } 124 } 125 126 // Allows error handling on deferred io.Close() 127 func closeIO(c io.Closer, err *error) { 128 e := c.Close() 129 switch *err { 130 case nil: 131 *err = e 132 default: 133 if e != nil { 134 log.Println("Error:", e) 135 } 136 } 137 } 138 139 func configureDHCPNetwork() error { 140 if *verbose { 141 log.Printf("Trying to configure network configuration dynamically...") 142 } 143 144 link, err := findNetworkInterface(*ifName) 145 if err != nil { 146 return err 147 } 148 149 var links []netlink.Link 150 links = append(links, link) 151 152 var level dhclient.LogLevel 153 154 config := dhclient.Config{ 155 Timeout: dhcpTimeout, 156 Retries: dhcpTries, 157 LogLevel: level, 158 } 159 160 r := dhclient.SendRequests(context.TODO(), links, true, false, config, 20*time.Second) 161 for result := range r { 162 if result.Err == nil { 163 return result.Lease.Configure() 164 } 165 log.Printf("dhcp response error: %v", result.Err) 166 } 167 return errors.New("no valid DHCP configuration recieved") 168 } 169 170 func findNetworkInterface(ifName string) (netlink.Link, error) { 171 if ifName != "" { 172 if *verbose { 173 log.Printf("Try using %s", ifName) 174 } 175 link, err := netlink.LinkByName(ifName) 176 if err == nil { 177 return link, nil 178 } 179 log.Print(err) 180 } 181 182 ifaces, err := net.Interfaces() 183 if err != nil { 184 return nil, err 185 } 186 187 if len(ifaces) == 0 { 188 return nil, errors.New("no network interface found") 189 } 190 191 var ifnames []string 192 for _, iface := range ifaces { 193 ifnames = append(ifnames, iface.Name) 194 // skip loopback 195 if iface.Flags&net.FlagLoopback != 0 || iface.HardwareAddr.String() == "" { 196 continue 197 } 198 if *verbose { 199 log.Printf("Try using %s", iface.Name) 200 } 201 link, err := netlink.LinkByName(iface.Name) 202 if err == nil { 203 return link, nil 204 } 205 log.Print(err) 206 } 207 208 return nil, fmt.Errorf("could not find a non-loopback network interface with hardware address in any of %v", ifnames) 209 }