github.com/Racer159/jackal@v0.32.7-0.20240401174413-0bd2339e4f2e/src/pkg/utils/network.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // SPDX-FileCopyrightText: 2021-Present The Jackal Authors
     3  
     4  // Package utils provides generic helper functions.
     5  package utils
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"net/url"
    13  	"os"
    14  	"path/filepath"
    15  	"strings"
    16  
    17  	"github.com/Racer159/jackal/src/config/lang"
    18  	"github.com/Racer159/jackal/src/pkg/message"
    19  	"github.com/defenseunicorns/pkg/helpers"
    20  )
    21  
    22  func parseChecksum(src string) (string, string, error) {
    23  	atSymbolCount := strings.Count(src, "@")
    24  	var checksum string
    25  	if atSymbolCount > 0 {
    26  		parsed, err := url.Parse(src)
    27  		if err != nil {
    28  			return src, checksum, fmt.Errorf("unable to parse the URL: %s", src)
    29  		}
    30  		if atSymbolCount == 1 && parsed.User != nil {
    31  			return src, checksum, nil
    32  		}
    33  
    34  		index := strings.LastIndex(src, "@")
    35  		checksum = src[index+1:]
    36  		src = src[:index]
    37  	}
    38  	return src, checksum, nil
    39  }
    40  
    41  // DownloadToFile downloads a given URL to the target filepath (including the cosign key if necessary).
    42  func DownloadToFile(src string, dst string, cosignKeyPath string) (err error) {
    43  	message.Debugf("Downloading %s to %s", src, dst)
    44  	// check if the parsed URL has a checksum
    45  	// if so, remove it and use the checksum to validate the file
    46  	src, checksum, err := parseChecksum(src)
    47  	if err != nil {
    48  		return err
    49  	}
    50  
    51  	err = helpers.CreateDirectory(filepath.Dir(dst), helpers.ReadWriteExecuteUser)
    52  	if err != nil {
    53  		return fmt.Errorf(lang.ErrCreatingDir, filepath.Dir(dst), err.Error())
    54  	}
    55  
    56  	// Create the file
    57  	file, err := os.Create(dst)
    58  	if err != nil {
    59  		return fmt.Errorf(lang.ErrWritingFile, dst, err.Error())
    60  	}
    61  	defer file.Close()
    62  
    63  	parsed, err := url.Parse(src)
    64  	if err != nil {
    65  		return fmt.Errorf("unable to parse the URL: %s", src)
    66  	}
    67  	// If the source url starts with the sget protocol use that, otherwise do a typical GET call
    68  	if parsed.Scheme == helpers.SGETURLScheme {
    69  		err = Sget(context.TODO(), src, cosignKeyPath, file)
    70  		if err != nil {
    71  			return fmt.Errorf("unable to download file with sget: %s: %w", src, err)
    72  		}
    73  		if err != nil {
    74  			return err
    75  		}
    76  	} else {
    77  		err = httpGetFile(src, file)
    78  		if err != nil {
    79  			return err
    80  		}
    81  	}
    82  
    83  	// If the file has a checksum, validate it
    84  	if len(checksum) > 0 {
    85  		received, err := helpers.GetSHA256OfFile(dst)
    86  		if err != nil {
    87  			return err
    88  		}
    89  		if received != checksum {
    90  			return fmt.Errorf("shasum mismatch for file %s: expected %s, got %s ", dst, checksum, received)
    91  		}
    92  	}
    93  
    94  	return nil
    95  }
    96  
    97  func httpGetFile(url string, destinationFile *os.File) error {
    98  	// Get the data
    99  	resp, err := http.Get(url)
   100  	if err != nil {
   101  		return fmt.Errorf("unable to download the file %s", url)
   102  	}
   103  	defer resp.Body.Close()
   104  
   105  	// Check server response
   106  	if resp.StatusCode != http.StatusOK {
   107  		return fmt.Errorf("bad HTTP status: %s", resp.Status)
   108  	}
   109  
   110  	// Writer the body to file
   111  	title := fmt.Sprintf("Downloading %s", filepath.Base(url))
   112  	progressBar := message.NewProgressBar(resp.ContentLength, title)
   113  
   114  	if _, err = io.Copy(destinationFile, io.TeeReader(resp.Body, progressBar)); err != nil {
   115  		progressBar.Errorf(err, "Unable to save the file %s", destinationFile.Name())
   116  		return err
   117  	}
   118  
   119  	title = fmt.Sprintf("Downloaded %s", url)
   120  	progressBar.Successf("%s", title)
   121  	return nil
   122  }