github.com/uriddle/docker@v0.0.0-20210926094723-4072e6aeb013/distribution/xfer/download_test.go (about)

     1  package xfer
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"io/ioutil"
     8  	"sync/atomic"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/docker/distribution/digest"
    13  	"github.com/docker/docker/image"
    14  	"github.com/docker/docker/layer"
    15  	"github.com/docker/docker/pkg/progress"
    16  	"golang.org/x/net/context"
    17  )
    18  
// maxDownloadConcurrency is the concurrency limit the tests pass to
// NewLayerDownloadManager. The mock Download method fails if it observes
// more than this many downloads in flight at once.
const maxDownloadConcurrency = 3
    20  
    21  type mockLayer struct {
    22  	layerData bytes.Buffer
    23  	diffID    layer.DiffID
    24  	chainID   layer.ChainID
    25  	parent    layer.Layer
    26  }
    27  
    28  func (ml *mockLayer) TarStream() (io.ReadCloser, error) {
    29  	return ioutil.NopCloser(bytes.NewBuffer(ml.layerData.Bytes())), nil
    30  }
    31  
// ChainID returns the chain ID stored on the mock layer.
func (ml *mockLayer) ChainID() layer.ChainID {
	return ml.chainID
}
    35  
// DiffID returns the diff ID stored on the mock layer.
func (ml *mockLayer) DiffID() layer.DiffID {
	return ml.diffID
}
    39  
// Parent returns the layer this mock layer was registered on top of, or nil
// if it has no parent.
func (ml *mockLayer) Parent() layer.Layer {
	return ml.parent
}
    43  
// Size always reports 0 and a nil error; the mock does not track sizes.
func (ml *mockLayer) Size() (size int64, err error) {
	return 0, nil
}
    47  
// DiffSize always reports 0 and a nil error; the mock does not track sizes.
func (ml *mockLayer) DiffSize() (size int64, err error) {
	return 0, nil
}
    51  
    52  func (ml *mockLayer) Metadata() (map[string]string, error) {
    53  	return make(map[string]string), nil
    54  }
    55  
// mockLayerStore is an in-memory layer store used by the download tests.
type mockLayerStore struct {
	// layers maps chain IDs to registered layers; Register adds entries and
	// Get looks them up.
	layers map[layer.ChainID]*mockLayer
}
    59  
    60  func createChainIDFromParent(parent layer.ChainID, dgsts ...layer.DiffID) layer.ChainID {
    61  	if len(dgsts) == 0 {
    62  		return parent
    63  	}
    64  	if parent == "" {
    65  		return createChainIDFromParent(layer.ChainID(dgsts[0]), dgsts[1:]...)
    66  	}
    67  	// H = "H(n-1) SHA256(n)"
    68  	dgst := digest.FromBytes([]byte(string(parent) + " " + string(dgsts[0])))
    69  	return createChainIDFromParent(layer.ChainID(dgst), dgsts[1:]...)
    70  }
    71  
    72  func (ls *mockLayerStore) Register(reader io.Reader, parentID layer.ChainID) (layer.Layer, error) {
    73  	var (
    74  		parent layer.Layer
    75  		err    error
    76  	)
    77  
    78  	if parentID != "" {
    79  		parent, err = ls.Get(parentID)
    80  		if err != nil {
    81  			return nil, err
    82  		}
    83  	}
    84  
    85  	l := &mockLayer{parent: parent}
    86  	_, err = l.layerData.ReadFrom(reader)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  	l.diffID = layer.DiffID(digest.FromBytes(l.layerData.Bytes()))
    91  	l.chainID = createChainIDFromParent(parentID, l.diffID)
    92  
    93  	ls.layers[l.chainID] = l
    94  	return l, nil
    95  }
    96  
    97  func (ls *mockLayerStore) Get(chainID layer.ChainID) (layer.Layer, error) {
    98  	l, ok := ls.layers[chainID]
    99  	if !ok {
   100  		return nil, layer.ErrLayerDoesNotExist
   101  	}
   102  	return l, nil
   103  }
   104  
   105  func (ls *mockLayerStore) Release(l layer.Layer) ([]layer.Metadata, error) {
   106  	return []layer.Metadata{}, nil
   107  }
// CreateRWLayer is not supported by the mock store; it always returns a nil
// layer and a "not implemented" error.
func (ls *mockLayerStore) CreateRWLayer(string, layer.ChainID, string, layer.MountInit) (layer.RWLayer, error) {
	return nil, errors.New("not implemented")
}
   111  
   112  func (ls *mockLayerStore) GetRWLayer(string) (layer.RWLayer, error) {
   113  	return nil, errors.New("not implemented")
   114  
   115  }
   116  
   117  func (ls *mockLayerStore) ReleaseRWLayer(layer.RWLayer) ([]layer.Metadata, error) {
   118  	return nil, errors.New("not implemented")
   119  
   120  }
   121  
// Cleanup is a no-op for the mock store and always returns nil.
func (ls *mockLayerStore) Cleanup() error {
	return nil
}
   125  
   126  func (ls *mockLayerStore) DriverStatus() [][2]string {
   127  	return [][2]string{}
   128  }
   129  
// DriverName returns the fixed driver name "mock".
func (ls *mockLayerStore) DriverName() string {
	return "mock"
}
   133  
   134  type mockDownloadDescriptor struct {
   135  	currentDownloads *int32
   136  	id               string
   137  	diffID           layer.DiffID
   138  	registeredDiffID layer.DiffID
   139  	expectedDiffID   layer.DiffID
   140  	simulateRetries  int
   141  }
   142  
// Key returns the key used to deduplicate downloads. The mock returns its
// id, so descriptors sharing an id also share a key.
func (d *mockDownloadDescriptor) Key() string {
	return d.id
}
   147  
// ID returns the ID for display purposes. The mock returns the same id
// value that Key does.
func (d *mockDownloadDescriptor) ID() string {
	return d.id
}
   152  
   153  // DiffID should return the DiffID for this layer, or an error
   154  // if it is unknown (for example, if it has not been downloaded
   155  // before).
   156  func (d *mockDownloadDescriptor) DiffID() (layer.DiffID, error) {
   157  	if d.diffID != "" {
   158  		return d.diffID, nil
   159  	}
   160  	return "", errors.New("no diffID available")
   161  }
   162  
// Registered records the given diff ID in registeredDiffID so the tests can
// compare it with the diff ID that ended up in the rootfs.
func (d *mockDownloadDescriptor) Registered(diffID layer.DiffID) {
	d.registeredDiffID = diffID
}
   166  
   167  func (d *mockDownloadDescriptor) mockTarStream() io.ReadCloser {
   168  	// The mock implementation returns the ID repeated 5 times as a tar
   169  	// stream instead of actual tar data. The data is ignored except for
   170  	// computing IDs.
   171  	return ioutil.NopCloser(bytes.NewBuffer([]byte(d.id + d.id + d.id + d.id + d.id)))
   172  }
   173  
   174  // Download is called to perform the download.
   175  func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
   176  	if d.currentDownloads != nil {
   177  		defer atomic.AddInt32(d.currentDownloads, -1)
   178  
   179  		if atomic.AddInt32(d.currentDownloads, 1) > maxDownloadConcurrency {
   180  			return nil, 0, errors.New("concurrency limit exceeded")
   181  		}
   182  	}
   183  
   184  	// Sleep a bit to simulate a time-consuming download.
   185  	for i := int64(0); i <= 10; i++ {
   186  		select {
   187  		case <-ctx.Done():
   188  			return nil, 0, ctx.Err()
   189  		case <-time.After(10 * time.Millisecond):
   190  			progressOutput.WriteProgress(progress.Progress{ID: d.ID(), Action: "Downloading", Current: i, Total: 10})
   191  		}
   192  	}
   193  
   194  	if d.simulateRetries != 0 {
   195  		d.simulateRetries--
   196  		return nil, 0, errors.New("simulating retry")
   197  	}
   198  
   199  	return d.mockTarStream(), 0, nil
   200  }
   201  
   202  func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor {
   203  	return []DownloadDescriptor{
   204  		&mockDownloadDescriptor{
   205  			currentDownloads: currentDownloads,
   206  			id:               "id1",
   207  			expectedDiffID:   layer.DiffID("sha256:68e2c75dc5c78ea9240689c60d7599766c213ae210434c53af18470ae8c53ec1"),
   208  		},
   209  		&mockDownloadDescriptor{
   210  			currentDownloads: currentDownloads,
   211  			id:               "id2",
   212  			expectedDiffID:   layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
   213  		},
   214  		&mockDownloadDescriptor{
   215  			currentDownloads: currentDownloads,
   216  			id:               "id3",
   217  			expectedDiffID:   layer.DiffID("sha256:58745a8bbd669c25213e9de578c4da5c8ee1c836b3581432c2b50e38a6753300"),
   218  		},
   219  		&mockDownloadDescriptor{
   220  			currentDownloads: currentDownloads,
   221  			id:               "id2",
   222  			expectedDiffID:   layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
   223  		},
   224  		&mockDownloadDescriptor{
   225  			currentDownloads: currentDownloads,
   226  			id:               "id4",
   227  			expectedDiffID:   layer.DiffID("sha256:0dfb5b9577716cc173e95af7c10289322c29a6453a1718addc00c0c5b1330936"),
   228  			simulateRetries:  1,
   229  		},
   230  		&mockDownloadDescriptor{
   231  			currentDownloads: currentDownloads,
   232  			id:               "id5",
   233  			expectedDiffID:   layer.DiffID("sha256:0a5f25fa1acbc647f6112a6276735d0fa01e4ee2aa7ec33015e337350e1ea23d"),
   234  		},
   235  	}
   236  }
   237  
   238  func TestSuccessfulDownload(t *testing.T) {
   239  	layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)}
   240  	ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency)
   241  
   242  	progressChan := make(chan progress.Progress)
   243  	progressDone := make(chan struct{})
   244  	receivedProgress := make(map[string]progress.Progress)
   245  
   246  	go func() {
   247  		for p := range progressChan {
   248  			receivedProgress[p.ID] = p
   249  		}
   250  		close(progressDone)
   251  	}()
   252  
   253  	var currentDownloads int32
   254  	descriptors := downloadDescriptors(&currentDownloads)
   255  
   256  	firstDescriptor := descriptors[0].(*mockDownloadDescriptor)
   257  
   258  	// Pre-register the first layer to simulate an already-existing layer
   259  	l, err := layerStore.Register(firstDescriptor.mockTarStream(), "")
   260  	if err != nil {
   261  		t.Fatal(err)
   262  	}
   263  	firstDescriptor.diffID = l.DiffID()
   264  
   265  	rootFS, releaseFunc, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
   266  	if err != nil {
   267  		t.Fatalf("download error: %v", err)
   268  	}
   269  
   270  	releaseFunc()
   271  
   272  	close(progressChan)
   273  	<-progressDone
   274  
   275  	if len(rootFS.DiffIDs) != len(descriptors) {
   276  		t.Fatal("got wrong number of diffIDs in rootfs")
   277  	}
   278  
   279  	for i, d := range descriptors {
   280  		descriptor := d.(*mockDownloadDescriptor)
   281  
   282  		if descriptor.diffID != "" {
   283  			if receivedProgress[d.ID()].Action != "Already exists" {
   284  				t.Fatalf("did not get 'Already exists' message for %v", d.ID())
   285  			}
   286  		} else if receivedProgress[d.ID()].Action != "Pull complete" {
   287  			t.Fatalf("did not get 'Pull complete' message for %v", d.ID())
   288  		}
   289  
   290  		if rootFS.DiffIDs[i] != descriptor.expectedDiffID {
   291  			t.Fatalf("rootFS item %d has the wrong diffID (expected: %v got: %v)", i, descriptor.expectedDiffID, rootFS.DiffIDs[i])
   292  		}
   293  
   294  		if descriptor.diffID == "" && descriptor.registeredDiffID != rootFS.DiffIDs[i] {
   295  			t.Fatal("diffID mismatch between rootFS and Registered callback")
   296  		}
   297  	}
   298  }
   299  
   300  func TestCancelledDownload(t *testing.T) {
   301  	ldm := NewLayerDownloadManager(&mockLayerStore{make(map[layer.ChainID]*mockLayer)}, maxDownloadConcurrency)
   302  
   303  	progressChan := make(chan progress.Progress)
   304  	progressDone := make(chan struct{})
   305  
   306  	go func() {
   307  		for range progressChan {
   308  		}
   309  		close(progressDone)
   310  	}()
   311  
   312  	ctx, cancel := context.WithCancel(context.Background())
   313  
   314  	go func() {
   315  		<-time.After(time.Millisecond)
   316  		cancel()
   317  	}()
   318  
   319  	descriptors := downloadDescriptors(nil)
   320  	_, _, err := ldm.Download(ctx, *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
   321  	if err != context.Canceled {
   322  		t.Fatal("expected download to be cancelled")
   323  	}
   324  
   325  	close(progressChan)
   326  	<-progressDone
   327  }