github.com/tonistiigi/docker@v0.10.1-0.20240229224939-974013b0dc6a/distribution/xfer/download_test.go (about) 1 package xfer // import "github.com/docker/docker/distribution/xfer" 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "fmt" 8 "io" 9 "runtime" 10 "sync/atomic" 11 "testing" 12 "time" 13 14 "github.com/docker/distribution" 15 "github.com/docker/docker/image" 16 "github.com/docker/docker/layer" 17 "github.com/docker/docker/pkg/progress" 18 "github.com/opencontainers/go-digest" 19 "gotest.tools/v3/assert" 20 ) 21 22 const maxDownloadConcurrency = 3 23 24 type mockLayer struct { 25 layerData bytes.Buffer 26 diffID layer.DiffID 27 chainID layer.ChainID 28 parent layer.Layer 29 } 30 31 func (ml *mockLayer) TarStream() (io.ReadCloser, error) { 32 return io.NopCloser(bytes.NewBuffer(ml.layerData.Bytes())), nil 33 } 34 35 func (ml *mockLayer) TarStreamFrom(layer.ChainID) (io.ReadCloser, error) { 36 return nil, fmt.Errorf("not implemented") 37 } 38 39 func (ml *mockLayer) ChainID() layer.ChainID { 40 return ml.chainID 41 } 42 43 func (ml *mockLayer) DiffID() layer.DiffID { 44 return ml.diffID 45 } 46 47 func (ml *mockLayer) Parent() layer.Layer { 48 return ml.parent 49 } 50 51 func (ml *mockLayer) Size() int64 { 52 return 0 53 } 54 55 func (ml *mockLayer) DiffSize() int64 { 56 return 0 57 } 58 59 func (ml *mockLayer) Metadata() (map[string]string, error) { 60 return make(map[string]string), nil 61 } 62 63 type mockLayerStore struct { 64 layers map[layer.ChainID]*mockLayer 65 } 66 67 func createChainIDFromParent(parent layer.ChainID, dgsts ...layer.DiffID) layer.ChainID { 68 if len(dgsts) == 0 { 69 return parent 70 } 71 if parent == "" { 72 return createChainIDFromParent(layer.ChainID(dgsts[0]), dgsts[1:]...) 73 } 74 // H = "H(n-1) SHA256(n)" 75 dgst := digest.FromBytes([]byte(string(parent) + " " + string(dgsts[0]))) 76 return createChainIDFromParent(layer.ChainID(dgst), dgsts[1:]...) 77 } 78 79 func (ls *mockLayerStore) Map() map[layer.ChainID]layer.Layer { 80 layers := map[layer.ChainID]layer.Layer{} 81 82 for k, v := range ls.layers { 83 layers[k] = v 84 } 85 86 return layers 87 } 88 89 func (ls *mockLayerStore) Register(reader io.Reader, parentID layer.ChainID) (layer.Layer, error) { 90 return ls.RegisterWithDescriptor(reader, parentID, distribution.Descriptor{}) 91 } 92 93 func (ls *mockLayerStore) RegisterWithDescriptor(reader io.Reader, parentID layer.ChainID, _ distribution.Descriptor) (layer.Layer, error) { 94 var ( 95 parent layer.Layer 96 err error 97 ) 98 99 if parentID != "" { 100 parent, err = ls.Get(parentID) 101 if err != nil { 102 return nil, err 103 } 104 } 105 106 l := &mockLayer{parent: parent} 107 _, err = l.layerData.ReadFrom(reader) 108 if err != nil { 109 return nil, err 110 } 111 l.diffID = layer.DiffID(digest.FromBytes(l.layerData.Bytes())) 112 l.chainID = createChainIDFromParent(parentID, l.diffID) 113 114 ls.layers[l.chainID] = l 115 return l, nil 116 } 117 118 func (ls *mockLayerStore) Get(chainID layer.ChainID) (layer.Layer, error) { 119 l, ok := ls.layers[chainID] 120 if !ok { 121 return nil, layer.ErrLayerDoesNotExist 122 } 123 return l, nil 124 } 125 126 func (ls *mockLayerStore) Release(l layer.Layer) ([]layer.Metadata, error) { 127 return []layer.Metadata{}, nil 128 } 129 130 func (ls *mockLayerStore) CreateRWLayer(string, layer.ChainID, *layer.CreateRWLayerOpts) (layer.RWLayer, error) { 131 return nil, errors.New("not implemented") 132 } 133 134 func (ls *mockLayerStore) GetRWLayer(string) (layer.RWLayer, error) { 135 return nil, errors.New("not implemented") 136 } 137 138 func (ls *mockLayerStore) ReleaseRWLayer(layer.RWLayer) ([]layer.Metadata, error) { 139 return nil, errors.New("not implemented") 140 } 141 142 func (ls *mockLayerStore) GetMountID(string) (string, error) { 143 return "", errors.New("not implemented") 144 } 145 146 func (ls *mockLayerStore) Cleanup() error { 147 return nil 148 } 149 150 func (ls *mockLayerStore) DriverStatus() [][2]string { 151 return [][2]string{} 152 } 153 154 func (ls *mockLayerStore) DriverName() string { 155 return "mock" 156 } 157 158 type mockDownloadDescriptor struct { 159 currentDownloads *int32 160 id string 161 diffID layer.DiffID 162 registeredDiffID layer.DiffID 163 expectedDiffID layer.DiffID 164 simulateRetries int 165 retries int 166 } 167 168 // Key returns the key used to deduplicate downloads. 169 func (d *mockDownloadDescriptor) Key() string { 170 return d.id 171 } 172 173 // ID returns the ID for display purposes. 174 func (d *mockDownloadDescriptor) ID() string { 175 return d.id 176 } 177 178 // DiffID should return the DiffID for this layer, or an error 179 // if it is unknown (for example, if it has not been downloaded 180 // before). 181 func (d *mockDownloadDescriptor) DiffID() (layer.DiffID, error) { 182 if d.diffID != "" { 183 return d.diffID, nil 184 } 185 return "", errors.New("no diffID available") 186 } 187 188 func (d *mockDownloadDescriptor) Registered(diffID layer.DiffID) { 189 d.registeredDiffID = diffID 190 } 191 192 func (d *mockDownloadDescriptor) mockTarStream() io.ReadCloser { 193 // The mock implementation returns the ID repeated 5 times as a tar 194 // stream instead of actual tar data. The data is ignored except for 195 // computing IDs. 196 return io.NopCloser(bytes.NewBuffer([]byte(d.id + d.id + d.id + d.id + d.id))) 197 } 198 199 // Download is called to perform the download. 200 func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) { 201 if d.currentDownloads != nil { 202 defer atomic.AddInt32(d.currentDownloads, -1) 203 204 if atomic.AddInt32(d.currentDownloads, 1) > maxDownloadConcurrency { 205 return nil, 0, errors.New("concurrency limit exceeded") 206 } 207 } 208 209 // Sleep a bit to simulate a time-consuming download. 210 for i := int64(0); i <= 10; i++ { 211 select { 212 case <-ctx.Done(): 213 return nil, 0, ctx.Err() 214 case <-time.After(10 * time.Millisecond): 215 progressOutput.WriteProgress(progress.Progress{ID: d.ID(), Action: "Downloading", Current: i, Total: 10}) 216 } 217 } 218 219 if d.retries < d.simulateRetries { 220 d.retries++ 221 return nil, 0, fmt.Errorf("simulating download attempt failure %d/%d", d.retries, d.simulateRetries) 222 } 223 224 return d.mockTarStream(), 0, nil 225 } 226 227 func (d *mockDownloadDescriptor) Close() { 228 } 229 230 func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor { 231 return []DownloadDescriptor{ 232 &mockDownloadDescriptor{ 233 currentDownloads: currentDownloads, 234 id: "id1", 235 expectedDiffID: layer.DiffID("sha256:68e2c75dc5c78ea9240689c60d7599766c213ae210434c53af18470ae8c53ec1"), 236 }, 237 &mockDownloadDescriptor{ 238 currentDownloads: currentDownloads, 239 id: "id2", 240 expectedDiffID: layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"), 241 }, 242 &mockDownloadDescriptor{ 243 currentDownloads: currentDownloads, 244 id: "id3", 245 expectedDiffID: layer.DiffID("sha256:58745a8bbd669c25213e9de578c4da5c8ee1c836b3581432c2b50e38a6753300"), 246 }, 247 &mockDownloadDescriptor{ 248 currentDownloads: currentDownloads, 249 id: "id2", 250 expectedDiffID: layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"), 251 }, 252 &mockDownloadDescriptor{ 253 currentDownloads: currentDownloads, 254 id: "id4", 255 expectedDiffID: layer.DiffID("sha256:0dfb5b9577716cc173e95af7c10289322c29a6453a1718addc00c0c5b1330936"), 256 simulateRetries: 1, 257 }, 258 &mockDownloadDescriptor{ 259 currentDownloads: currentDownloads, 260 id: "id5", 261 expectedDiffID: layer.DiffID("sha256:0a5f25fa1acbc647f6112a6276735d0fa01e4ee2aa7ec33015e337350e1ea23d"), 262 }, 263 } 264 } 265 266 func TestSuccessfulDownload(t *testing.T) { 267 // TODO Windows: Fix this unit text 268 if runtime.GOOS == "windows" { 269 t.Skip("Needs fixing on Windows") 270 } 271 272 layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)} 273 ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency, func(m *LayerDownloadManager) { m.waitDuration = time.Millisecond }) 274 275 progressChan := make(chan progress.Progress) 276 progressDone := make(chan struct{}) 277 receivedProgress := make(map[string]progress.Progress) 278 279 go func() { 280 for p := range progressChan { 281 receivedProgress[p.ID] = p 282 } 283 close(progressDone) 284 }() 285 286 var currentDownloads int32 287 descriptors := downloadDescriptors(¤tDownloads) 288 289 firstDescriptor := descriptors[0].(*mockDownloadDescriptor) 290 291 // Pre-register the first layer to simulate an already-existing layer 292 l, err := layerStore.Register(firstDescriptor.mockTarStream(), "") 293 if err != nil { 294 t.Fatal(err) 295 } 296 firstDescriptor.diffID = l.DiffID() 297 298 rootFS, releaseFunc, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan)) 299 if err != nil { 300 t.Fatalf("download error: %v", err) 301 } 302 303 releaseFunc() 304 305 close(progressChan) 306 <-progressDone 307 308 if len(rootFS.DiffIDs) != len(descriptors) { 309 t.Fatal("got wrong number of diffIDs in rootfs") 310 } 311 312 for i, d := range descriptors { 313 descriptor := d.(*mockDownloadDescriptor) 314 315 if descriptor.diffID != "" { 316 if receivedProgress[d.ID()].Action != "Already exists" { 317 t.Fatalf("did not get 'Already exists' message for %v", d.ID()) 318 } 319 } else if receivedProgress[d.ID()].Action != "Pull complete" { 320 t.Fatalf("did not get 'Pull complete' message for %v", d.ID()) 321 } 322 323 if rootFS.DiffIDs[i] != descriptor.expectedDiffID { 324 t.Fatalf("rootFS item %d has the wrong diffID (expected: %v got: %v)", i, descriptor.expectedDiffID, rootFS.DiffIDs[i]) 325 } 326 327 if descriptor.diffID == "" && descriptor.registeredDiffID != rootFS.DiffIDs[i] { 328 t.Fatal("diffID mismatch between rootFS and Registered callback") 329 } 330 } 331 } 332 333 func TestCancelledDownload(t *testing.T) { 334 layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)} 335 ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency, func(m *LayerDownloadManager) { m.waitDuration = time.Millisecond }) 336 progressChan := make(chan progress.Progress) 337 progressDone := make(chan struct{}) 338 339 go func() { 340 for range progressChan { 341 } 342 close(progressDone) 343 }() 344 345 ctx, cancel := context.WithCancel(context.Background()) 346 347 go func() { 348 <-time.After(time.Millisecond) 349 cancel() 350 }() 351 352 descriptors := downloadDescriptors(nil) 353 _, _, err := ldm.Download(ctx, *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan)) 354 if err != context.Canceled { 355 close(progressChan) 356 t.Fatal("expected download to be cancelled") 357 } 358 359 close(progressChan) 360 <-progressDone 361 } 362 363 func TestMaxDownloadAttempts(t *testing.T) { 364 tests := []struct { 365 name string 366 simulateRetries int 367 maxDownloadAttempts int 368 expectedErr string 369 }{ 370 { 371 name: "max-attempts=5, succeed at 2nd attempt", 372 simulateRetries: 1, 373 maxDownloadAttempts: 5, 374 }, 375 { 376 name: "max-attempts=5, succeed at 5th attempt", 377 simulateRetries: 4, 378 maxDownloadAttempts: 5, 379 }, 380 { 381 name: "max-attempts=5, fail at 5th attempt", 382 simulateRetries: 5, 383 maxDownloadAttempts: 5, 384 expectedErr: "simulating download attempt failure 5/5", 385 }, 386 { 387 name: "max-attempts=1, fail after 1 attempt", 388 simulateRetries: 1, 389 maxDownloadAttempts: 1, 390 expectedErr: "simulating download attempt failure 1/1", 391 }, 392 } 393 for _, tc := range tests { 394 tc := tc 395 t.Run(tc.name, func(t *testing.T) { 396 t.Parallel() 397 layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)} 398 ldm := NewLayerDownloadManager( 399 layerStore, 400 maxDownloadConcurrency, 401 func(m *LayerDownloadManager) { 402 m.waitDuration = time.Millisecond 403 m.maxDownloadAttempts = tc.maxDownloadAttempts 404 }) 405 406 progressChan := make(chan progress.Progress) 407 progressDone := make(chan struct{}) 408 409 go func() { 410 for range progressChan { 411 } 412 close(progressDone) 413 }() 414 415 var currentDownloads int32 416 descriptors := downloadDescriptors(¤tDownloads) 417 descriptors[4].(*mockDownloadDescriptor).simulateRetries = tc.simulateRetries 418 419 _, _, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan)) 420 if tc.expectedErr == "" { 421 assert.NilError(t, err) 422 } else { 423 assert.Error(t, err, tc.expectedErr) 424 } 425 }) 426 } 427 }