github.com/tonistiigi/docker@v0.10.1-0.20240229224939-974013b0dc6a/distribution/xfer/download_test.go (about)

     1  package xfer // import "github.com/docker/docker/distribution/xfer"
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"runtime"
    10  	"sync/atomic"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/docker/distribution"
    15  	"github.com/docker/docker/image"
    16  	"github.com/docker/docker/layer"
    17  	"github.com/docker/docker/pkg/progress"
    18  	"github.com/opencontainers/go-digest"
    19  	"gotest.tools/v3/assert"
    20  )
    21  
    22  const maxDownloadConcurrency = 3
    23  
    24  type mockLayer struct {
    25  	layerData bytes.Buffer
    26  	diffID    layer.DiffID
    27  	chainID   layer.ChainID
    28  	parent    layer.Layer
    29  }
    30  
    31  func (ml *mockLayer) TarStream() (io.ReadCloser, error) {
    32  	return io.NopCloser(bytes.NewBuffer(ml.layerData.Bytes())), nil
    33  }
    34  
    35  func (ml *mockLayer) TarStreamFrom(layer.ChainID) (io.ReadCloser, error) {
    36  	return nil, fmt.Errorf("not implemented")
    37  }
    38  
    39  func (ml *mockLayer) ChainID() layer.ChainID {
    40  	return ml.chainID
    41  }
    42  
    43  func (ml *mockLayer) DiffID() layer.DiffID {
    44  	return ml.diffID
    45  }
    46  
    47  func (ml *mockLayer) Parent() layer.Layer {
    48  	return ml.parent
    49  }
    50  
    51  func (ml *mockLayer) Size() int64 {
    52  	return 0
    53  }
    54  
    55  func (ml *mockLayer) DiffSize() int64 {
    56  	return 0
    57  }
    58  
    59  func (ml *mockLayer) Metadata() (map[string]string, error) {
    60  	return make(map[string]string), nil
    61  }
    62  
    63  type mockLayerStore struct {
    64  	layers map[layer.ChainID]*mockLayer
    65  }
    66  
    67  func createChainIDFromParent(parent layer.ChainID, dgsts ...layer.DiffID) layer.ChainID {
    68  	if len(dgsts) == 0 {
    69  		return parent
    70  	}
    71  	if parent == "" {
    72  		return createChainIDFromParent(layer.ChainID(dgsts[0]), dgsts[1:]...)
    73  	}
    74  	// H = "H(n-1) SHA256(n)"
    75  	dgst := digest.FromBytes([]byte(string(parent) + " " + string(dgsts[0])))
    76  	return createChainIDFromParent(layer.ChainID(dgst), dgsts[1:]...)
    77  }
    78  
    79  func (ls *mockLayerStore) Map() map[layer.ChainID]layer.Layer {
    80  	layers := map[layer.ChainID]layer.Layer{}
    81  
    82  	for k, v := range ls.layers {
    83  		layers[k] = v
    84  	}
    85  
    86  	return layers
    87  }
    88  
    89  func (ls *mockLayerStore) Register(reader io.Reader, parentID layer.ChainID) (layer.Layer, error) {
    90  	return ls.RegisterWithDescriptor(reader, parentID, distribution.Descriptor{})
    91  }
    92  
    93  func (ls *mockLayerStore) RegisterWithDescriptor(reader io.Reader, parentID layer.ChainID, _ distribution.Descriptor) (layer.Layer, error) {
    94  	var (
    95  		parent layer.Layer
    96  		err    error
    97  	)
    98  
    99  	if parentID != "" {
   100  		parent, err = ls.Get(parentID)
   101  		if err != nil {
   102  			return nil, err
   103  		}
   104  	}
   105  
   106  	l := &mockLayer{parent: parent}
   107  	_, err = l.layerData.ReadFrom(reader)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	l.diffID = layer.DiffID(digest.FromBytes(l.layerData.Bytes()))
   112  	l.chainID = createChainIDFromParent(parentID, l.diffID)
   113  
   114  	ls.layers[l.chainID] = l
   115  	return l, nil
   116  }
   117  
   118  func (ls *mockLayerStore) Get(chainID layer.ChainID) (layer.Layer, error) {
   119  	l, ok := ls.layers[chainID]
   120  	if !ok {
   121  		return nil, layer.ErrLayerDoesNotExist
   122  	}
   123  	return l, nil
   124  }
   125  
   126  func (ls *mockLayerStore) Release(l layer.Layer) ([]layer.Metadata, error) {
   127  	return []layer.Metadata{}, nil
   128  }
   129  
   130  func (ls *mockLayerStore) CreateRWLayer(string, layer.ChainID, *layer.CreateRWLayerOpts) (layer.RWLayer, error) {
   131  	return nil, errors.New("not implemented")
   132  }
   133  
   134  func (ls *mockLayerStore) GetRWLayer(string) (layer.RWLayer, error) {
   135  	return nil, errors.New("not implemented")
   136  }
   137  
   138  func (ls *mockLayerStore) ReleaseRWLayer(layer.RWLayer) ([]layer.Metadata, error) {
   139  	return nil, errors.New("not implemented")
   140  }
   141  
   142  func (ls *mockLayerStore) GetMountID(string) (string, error) {
   143  	return "", errors.New("not implemented")
   144  }
   145  
   146  func (ls *mockLayerStore) Cleanup() error {
   147  	return nil
   148  }
   149  
   150  func (ls *mockLayerStore) DriverStatus() [][2]string {
   151  	return [][2]string{}
   152  }
   153  
   154  func (ls *mockLayerStore) DriverName() string {
   155  	return "mock"
   156  }
   157  
   158  type mockDownloadDescriptor struct {
   159  	currentDownloads *int32
   160  	id               string
   161  	diffID           layer.DiffID
   162  	registeredDiffID layer.DiffID
   163  	expectedDiffID   layer.DiffID
   164  	simulateRetries  int
   165  	retries          int
   166  }
   167  
   168  // Key returns the key used to deduplicate downloads.
   169  func (d *mockDownloadDescriptor) Key() string {
   170  	return d.id
   171  }
   172  
   173  // ID returns the ID for display purposes.
   174  func (d *mockDownloadDescriptor) ID() string {
   175  	return d.id
   176  }
   177  
   178  // DiffID should return the DiffID for this layer, or an error
   179  // if it is unknown (for example, if it has not been downloaded
   180  // before).
   181  func (d *mockDownloadDescriptor) DiffID() (layer.DiffID, error) {
   182  	if d.diffID != "" {
   183  		return d.diffID, nil
   184  	}
   185  	return "", errors.New("no diffID available")
   186  }
   187  
   188  func (d *mockDownloadDescriptor) Registered(diffID layer.DiffID) {
   189  	d.registeredDiffID = diffID
   190  }
   191  
   192  func (d *mockDownloadDescriptor) mockTarStream() io.ReadCloser {
   193  	// The mock implementation returns the ID repeated 5 times as a tar
   194  	// stream instead of actual tar data. The data is ignored except for
   195  	// computing IDs.
   196  	return io.NopCloser(bytes.NewBuffer([]byte(d.id + d.id + d.id + d.id + d.id)))
   197  }
   198  
   199  // Download is called to perform the download.
   200  func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
   201  	if d.currentDownloads != nil {
   202  		defer atomic.AddInt32(d.currentDownloads, -1)
   203  
   204  		if atomic.AddInt32(d.currentDownloads, 1) > maxDownloadConcurrency {
   205  			return nil, 0, errors.New("concurrency limit exceeded")
   206  		}
   207  	}
   208  
   209  	// Sleep a bit to simulate a time-consuming download.
   210  	for i := int64(0); i <= 10; i++ {
   211  		select {
   212  		case <-ctx.Done():
   213  			return nil, 0, ctx.Err()
   214  		case <-time.After(10 * time.Millisecond):
   215  			progressOutput.WriteProgress(progress.Progress{ID: d.ID(), Action: "Downloading", Current: i, Total: 10})
   216  		}
   217  	}
   218  
   219  	if d.retries < d.simulateRetries {
   220  		d.retries++
   221  		return nil, 0, fmt.Errorf("simulating download attempt failure %d/%d", d.retries, d.simulateRetries)
   222  	}
   223  
   224  	return d.mockTarStream(), 0, nil
   225  }
   226  
   227  func (d *mockDownloadDescriptor) Close() {
   228  }
   229  
   230  func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor {
   231  	return []DownloadDescriptor{
   232  		&mockDownloadDescriptor{
   233  			currentDownloads: currentDownloads,
   234  			id:               "id1",
   235  			expectedDiffID:   layer.DiffID("sha256:68e2c75dc5c78ea9240689c60d7599766c213ae210434c53af18470ae8c53ec1"),
   236  		},
   237  		&mockDownloadDescriptor{
   238  			currentDownloads: currentDownloads,
   239  			id:               "id2",
   240  			expectedDiffID:   layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
   241  		},
   242  		&mockDownloadDescriptor{
   243  			currentDownloads: currentDownloads,
   244  			id:               "id3",
   245  			expectedDiffID:   layer.DiffID("sha256:58745a8bbd669c25213e9de578c4da5c8ee1c836b3581432c2b50e38a6753300"),
   246  		},
   247  		&mockDownloadDescriptor{
   248  			currentDownloads: currentDownloads,
   249  			id:               "id2",
   250  			expectedDiffID:   layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
   251  		},
   252  		&mockDownloadDescriptor{
   253  			currentDownloads: currentDownloads,
   254  			id:               "id4",
   255  			expectedDiffID:   layer.DiffID("sha256:0dfb5b9577716cc173e95af7c10289322c29a6453a1718addc00c0c5b1330936"),
   256  			simulateRetries:  1,
   257  		},
   258  		&mockDownloadDescriptor{
   259  			currentDownloads: currentDownloads,
   260  			id:               "id5",
   261  			expectedDiffID:   layer.DiffID("sha256:0a5f25fa1acbc647f6112a6276735d0fa01e4ee2aa7ec33015e337350e1ea23d"),
   262  		},
   263  	}
   264  }
   265  
   266  func TestSuccessfulDownload(t *testing.T) {
   267  	// TODO Windows: Fix this unit text
   268  	if runtime.GOOS == "windows" {
   269  		t.Skip("Needs fixing on Windows")
   270  	}
   271  
   272  	layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)}
   273  	ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency, func(m *LayerDownloadManager) { m.waitDuration = time.Millisecond })
   274  
   275  	progressChan := make(chan progress.Progress)
   276  	progressDone := make(chan struct{})
   277  	receivedProgress := make(map[string]progress.Progress)
   278  
   279  	go func() {
   280  		for p := range progressChan {
   281  			receivedProgress[p.ID] = p
   282  		}
   283  		close(progressDone)
   284  	}()
   285  
   286  	var currentDownloads int32
   287  	descriptors := downloadDescriptors(&currentDownloads)
   288  
   289  	firstDescriptor := descriptors[0].(*mockDownloadDescriptor)
   290  
   291  	// Pre-register the first layer to simulate an already-existing layer
   292  	l, err := layerStore.Register(firstDescriptor.mockTarStream(), "")
   293  	if err != nil {
   294  		t.Fatal(err)
   295  	}
   296  	firstDescriptor.diffID = l.DiffID()
   297  
   298  	rootFS, releaseFunc, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
   299  	if err != nil {
   300  		t.Fatalf("download error: %v", err)
   301  	}
   302  
   303  	releaseFunc()
   304  
   305  	close(progressChan)
   306  	<-progressDone
   307  
   308  	if len(rootFS.DiffIDs) != len(descriptors) {
   309  		t.Fatal("got wrong number of diffIDs in rootfs")
   310  	}
   311  
   312  	for i, d := range descriptors {
   313  		descriptor := d.(*mockDownloadDescriptor)
   314  
   315  		if descriptor.diffID != "" {
   316  			if receivedProgress[d.ID()].Action != "Already exists" {
   317  				t.Fatalf("did not get 'Already exists' message for %v", d.ID())
   318  			}
   319  		} else if receivedProgress[d.ID()].Action != "Pull complete" {
   320  			t.Fatalf("did not get 'Pull complete' message for %v", d.ID())
   321  		}
   322  
   323  		if rootFS.DiffIDs[i] != descriptor.expectedDiffID {
   324  			t.Fatalf("rootFS item %d has the wrong diffID (expected: %v got: %v)", i, descriptor.expectedDiffID, rootFS.DiffIDs[i])
   325  		}
   326  
   327  		if descriptor.diffID == "" && descriptor.registeredDiffID != rootFS.DiffIDs[i] {
   328  			t.Fatal("diffID mismatch between rootFS and Registered callback")
   329  		}
   330  	}
   331  }
   332  
   333  func TestCancelledDownload(t *testing.T) {
   334  	layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)}
   335  	ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency, func(m *LayerDownloadManager) { m.waitDuration = time.Millisecond })
   336  	progressChan := make(chan progress.Progress)
   337  	progressDone := make(chan struct{})
   338  
   339  	go func() {
   340  		for range progressChan {
   341  		}
   342  		close(progressDone)
   343  	}()
   344  
   345  	ctx, cancel := context.WithCancel(context.Background())
   346  
   347  	go func() {
   348  		<-time.After(time.Millisecond)
   349  		cancel()
   350  	}()
   351  
   352  	descriptors := downloadDescriptors(nil)
   353  	_, _, err := ldm.Download(ctx, *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
   354  	if err != context.Canceled {
   355  		close(progressChan)
   356  		t.Fatal("expected download to be cancelled")
   357  	}
   358  
   359  	close(progressChan)
   360  	<-progressDone
   361  }
   362  
   363  func TestMaxDownloadAttempts(t *testing.T) {
   364  	tests := []struct {
   365  		name                string
   366  		simulateRetries     int
   367  		maxDownloadAttempts int
   368  		expectedErr         string
   369  	}{
   370  		{
   371  			name:                "max-attempts=5, succeed at 2nd attempt",
   372  			simulateRetries:     1,
   373  			maxDownloadAttempts: 5,
   374  		},
   375  		{
   376  			name:                "max-attempts=5, succeed at 5th attempt",
   377  			simulateRetries:     4,
   378  			maxDownloadAttempts: 5,
   379  		},
   380  		{
   381  			name:                "max-attempts=5, fail at 5th attempt",
   382  			simulateRetries:     5,
   383  			maxDownloadAttempts: 5,
   384  			expectedErr:         "simulating download attempt failure 5/5",
   385  		},
   386  		{
   387  			name:                "max-attempts=1, fail after 1 attempt",
   388  			simulateRetries:     1,
   389  			maxDownloadAttempts: 1,
   390  			expectedErr:         "simulating download attempt failure 1/1",
   391  		},
   392  	}
   393  	for _, tc := range tests {
   394  		tc := tc
   395  		t.Run(tc.name, func(t *testing.T) {
   396  			t.Parallel()
   397  			layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)}
   398  			ldm := NewLayerDownloadManager(
   399  				layerStore,
   400  				maxDownloadConcurrency,
   401  				func(m *LayerDownloadManager) {
   402  					m.waitDuration = time.Millisecond
   403  					m.maxDownloadAttempts = tc.maxDownloadAttempts
   404  				})
   405  
   406  			progressChan := make(chan progress.Progress)
   407  			progressDone := make(chan struct{})
   408  
   409  			go func() {
   410  				for range progressChan {
   411  				}
   412  				close(progressDone)
   413  			}()
   414  
   415  			var currentDownloads int32
   416  			descriptors := downloadDescriptors(&currentDownloads)
   417  			descriptors[4].(*mockDownloadDescriptor).simulateRetries = tc.simulateRetries
   418  
   419  			_, _, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
   420  			if tc.expectedErr == "" {
   421  				assert.NilError(t, err)
   422  			} else {
   423  				assert.Error(t, err, tc.expectedErr)
   424  			}
   425  		})
   426  	}
   427  }