github.com/hashicorp/go-getter/v2@v2.2.2/client_option_progress_test.go (about)

     1  package getter
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"os"
     9  	"path/filepath"
    10  	"sync"
    11  	"testing"
    12  
    13  	testing_helper "github.com/hashicorp/go-getter/v2/helper/testing"
    14  )
    15  
    16  type MockProgressTracking struct {
    17  	sync.Mutex
    18  	downloaded map[string]int
    19  }
    20  
    21  func (p *MockProgressTracking) TrackProgress(src string,
    22  	currentSize, totalSize int64, stream io.ReadCloser) (body io.ReadCloser) {
    23  	p.Lock()
    24  	defer p.Unlock()
    25  
    26  	if p.downloaded == nil {
    27  		p.downloaded = map[string]int{}
    28  	}
    29  
    30  	v, _ := p.downloaded[src]
    31  	p.downloaded[src] = v + 1
    32  	return stream
    33  }
    34  
    35  func TestGet_progress(t *testing.T) {
    36  	s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    37  		// all good
    38  		rw.Header().Add("X-Terraform-Get", "something")
    39  	}))
    40  	defer s.Close()
    41  	ctx := context.Background()
    42  
    43  	{ // dl without tracking
    44  		dst := testing_helper.TempTestFile(t)
    45  		defer os.RemoveAll(filepath.Dir(dst))
    46  		if _, err := GetFile(ctx, dst, s.URL+"/file?thig=this&that"); err != nil {
    47  			t.Fatalf("download failed: %v", err)
    48  		}
    49  	}
    50  
    51  	{ // tracking
    52  		p := &MockProgressTracking{}
    53  		dst := testing_helper.TempTestFile(t)
    54  		defer os.RemoveAll(filepath.Dir(dst))
    55  		req := &Request{
    56  			Dst:              dst,
    57  			Src:              s.URL + "/file?thig=this&that",
    58  			ProgressListener: p,
    59  		}
    60  		if _, err := DefaultClient.Get(ctx, req); err != nil {
    61  			t.Fatalf("download failed: %v", err)
    62  		}
    63  		req = &Request{
    64  			Dst:              dst,
    65  			Src:              s.URL + "/otherfile?thig=this&that",
    66  			ProgressListener: p,
    67  		}
    68  		if _, err := DefaultClient.Get(ctx, req); err != nil {
    69  			t.Fatalf("download failed: %v", err)
    70  		}
    71  
    72  		if p.downloaded["file"] != 1 {
    73  			t.Error("Expected a file download")
    74  		}
    75  		if p.downloaded["otherfile"] != 1 {
    76  			t.Error("Expected a otherfile download")
    77  		}
    78  	}
    79  }