github.com/ActiveState/cli@v0.0.0-20240508170324-6801f60cd051/internal/httputil/get.go (about)

     1  package httputil
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"os"
     8  	"path/filepath"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/ActiveState/cli/internal/condition"
    13  	"github.com/ActiveState/cli/internal/constants"
    14  	"github.com/ActiveState/cli/internal/environment"
    15  	"github.com/ActiveState/cli/internal/errs"
    16  	"github.com/ActiveState/cli/internal/locale"
    17  	"github.com/ActiveState/cli/internal/logging"
    18  	"github.com/ActiveState/cli/internal/proxyreader"
    19  	"github.com/ActiveState/cli/internal/retryhttp"
    20  	"github.com/ActiveState/cli/pkg/platform/runtime/setup/events/progress"
    21  )
    22  
    23  // Get takes a URL and returns the contents as bytes
    24  var Get func(url string) ([]byte, error)
    25  
    26  var GetDirect = httpGet
    27  
    28  // GetWithProgress takes a URL and returns the contents as bytes, it takes an optional second arg which will spawn a progressbar
    29  var GetWithProgress func(url string, progress progress.Reporter) ([]byte, error)
    30  
    31  func init() {
    32  	SetMocking(condition.InUnitTest())
    33  }
    34  
    35  // SetMocking sets the correct Get methods for testing
    36  func SetMocking(useMocking bool) {
    37  	if useMocking {
    38  		Get = _testHTTPGet
    39  		GetWithProgress = _testHTTPGetWithProgress
    40  	} else {
    41  		Get = httpGet
    42  		GetWithProgress = httpGetWithProgress
    43  	}
    44  }
    45  
    46  func httpGet(url string) ([]byte, error) {
    47  	return httpGetWithProgress(url, nil)
    48  }
    49  
    50  func httpGetWithProgress(url string, progress progress.Reporter) ([]byte, error) {
    51  	return httpGetWithProgressRetry(url, progress, 1, 3)
    52  }
    53  
    54  func httpGetWithProgressRetry(url string, prg progress.Reporter, attempt int, retries int) ([]byte, error) {
    55  	client := retryhttp.NewClient(0 /* 0 = no timeout */, retries)
    56  	resp, err := client.Get(url)
    57  	if err != nil {
    58  		code := -1
    59  		if resp != nil {
    60  			code = resp.StatusCode
    61  		}
    62  		return nil, locale.WrapError(err, "err_network_get", "", "Status code: {{.V0}}", strconv.Itoa(code))
    63  	}
    64  	defer resp.Body.Close()
    65  
    66  	if resp.StatusCode != 200 {
    67  		return nil, locale.NewError("err_invalid_status_code", "", strconv.Itoa(resp.StatusCode))
    68  	}
    69  
    70  	var total int
    71  	length := resp.Header.Get("Content-Length")
    72  	if length == "" {
    73  		total = 1
    74  	} else {
    75  		total, err = strconv.Atoi(length)
    76  		if err != nil {
    77  			return nil, errs.Wrap(err, "Could not convert header length to int, value: %s", length)
    78  		}
    79  	}
    80  
    81  	var src io.Reader = resp.Body
    82  	defer resp.Body.Close()
    83  
    84  	if prg != nil {
    85  		if err := prg.ReportSize(total); err != nil {
    86  			return nil, errs.Wrap(err, "Could not report size")
    87  		}
    88  		src = proxyreader.NewProxyReader(prg, resp.Body)
    89  	}
    90  
    91  	var dst bytes.Buffer
    92  	_, err = io.Copy(&dst, src)
    93  	if err != nil && !errors.Is(err, io.EOF) {
    94  		logging.Debug("Reading body failed: %s", errs.JoinMessage(err))
    95  		if attempt <= retries {
    96  			return httpGetWithProgressRetry(url, prg, attempt+1, retries)
    97  		}
    98  		return nil, errs.Wrap(err, "Could not copy network stream")
    99  	}
   100  
   101  	return dst.Bytes(), nil
   102  }
   103  
   104  func _testHTTPGetWithProgress(url string, progress progress.Reporter) ([]byte, error) {
   105  	return _testHTTPGet(url)
   106  }
   107  
   108  // _testHTTPGet is used when in tests, this cannot be in the test itself as that would limit it to only that one test
   109  func _testHTTPGet(url string) ([]byte, error) {
   110  	path := strings.Replace(url, constants.APIArtifactURL, "", 1)
   111  	path = filepath.Join(environment.GetRootPathUnsafe(), "test", path)
   112  
   113  	body, err := os.ReadFile(path)
   114  	if err != nil {
   115  		return nil, errs.Wrap(err, "Could not read file contents: %s", path)
   116  	}
   117  
   118  	return body, nil
   119  }