github.com/mirantis/virtlet@v1.5.2-0.20191204181327-1659b8a48e9b/pkg/image/download_test.go (about)

     1  /*
     2  Copyright 2018 Mirantis
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package image
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/tls"
    23  	"net/http"
    24  	"net/http/httptest"
    25  	"strings"
    26  	"testing"
    27  	"time"
    28  
    29  	testutils "github.com/Mirantis/virtlet/pkg/utils/testing"
    30  )
    31  
// Note that more downloader tests live in pkg/imagetranslation/transport_test.go.
// They examine other aspects, such as redirects and proxies, in conjunction with
// image translation handling.
    35  
    36  func downloadHandler(content string) http.HandlerFunc {
    37  	return func(w http.ResponseWriter, r *http.Request) {
    38  		switch r.URL.String() {
    39  		case "/base.qcow2":
    40  			w.Header().Set("Content-Type", "application/octet-stream")
    41  			w.Write([]byte(content))
    42  		case "/redir.qcow2":
    43  			http.Redirect(w, r, "/base.qcow2", http.StatusFound)
    44  		default:
    45  			http.NotFound(w, r)
    46  		}
    47  	}
    48  }
    49  
    50  func verifyDownload(t *testing.T, protocol string, content string, ep Endpoint) {
    51  	downloader := NewDownloader(protocol)
    52  	var buf bytes.Buffer
    53  	if err := downloader.DownloadFile(context.Background(), ep, &buf); err != nil {
    54  		t.Fatalf("DownloadFile(): %v", err)
    55  	}
    56  	if buf.String() != content {
    57  		t.Errorf("bad content: %q instead of %q", buf.String(), content)
    58  	}
    59  }
    60  
    61  func TestDownload(t *testing.T) {
    62  	ts := httptest.NewServer(downloadHandler("foobar"))
    63  	defer ts.Close()
    64  	verifyDownload(t, "http", "foobar", Endpoint{
    65  		URL: ts.Listener.Addr().String() + "/base.qcow2",
    66  	})
    67  }
    68  
    69  func TestDownloadRedirect(t *testing.T) {
    70  	ts := httptest.NewServer(downloadHandler("foobar"))
    71  	defer ts.Close()
    72  	verifyDownload(t, "http", "foobar", Endpoint{
    73  		URL:          ts.Listener.Addr().String() + "/redir.qcow2",
    74  		MaxRedirects: -1,
    75  	})
    76  }
    77  
    78  func TestTLSDownload(t *testing.T) {
    79  	ca, caKey := testutils.GenerateCert(t, true, "CA", nil, nil)
    80  	cert, key := testutils.GenerateCert(t, false, "127.0.0.1", ca, caKey)
    81  	ts := httptest.NewUnstartedServer(downloadHandler("foobar"))
    82  	ts.TLS = &tls.Config{
    83  		Certificates: []tls.Certificate{
    84  			{
    85  				Certificate: [][]byte{cert.Raw},
    86  				PrivateKey:  key,
    87  			},
    88  		},
    89  	}
    90  	ts.StartTLS()
    91  	defer ts.Close()
    92  	verifyDownload(t, "https", "foobar", Endpoint{
    93  		URL: ts.Listener.Addr().String() + "/base.qcow2",
    94  		TLS: &TLSConfig{
    95  			Certificates: []TLSCertificate{
    96  				{Certificate: ca},
    97  			},
    98  		},
    99  	})
   100  }
   101  
   102  func TestCancelDownload(t *testing.T) {
   103  	startedWriting := make(chan struct{})
   104  	done := make(chan struct{})
   105  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   106  		if r.URL.String() != "/base.qcow2" {
   107  			http.NotFound(w, r)
   108  		}
   109  		w.Header().Set("Content-Type", "application/octet-stream")
   110  		w.Write([]byte("foo"))
   111  		close(startedWriting)
   112  		select {
   113  		case <-time.After(40 * time.Second):
   114  			t.Errorf("request not cancelled within 40s")
   115  		case <-r.Context().Done():
   116  		}
   117  		close(done)
   118  	}))
   119  	defer ts.Close()
   120  	downloader := NewDownloader("http")
   121  	var buf bytes.Buffer
   122  	ctx, cancel := context.WithCancel(context.Background())
   123  	go func() {
   124  		<-startedWriting
   125  		cancel()
   126  	}()
   127  
   128  	err := downloader.DownloadFile(ctx, Endpoint{
   129  		URL: ts.Listener.Addr().String() + "/base.qcow2",
   130  	}, &buf)
   131  	switch {
   132  	case err == nil:
   133  		t.Errorf("DownloadFile() didn't return error after being cancelled")
   134  	case !strings.Contains(err.Error(), "context canceled"):
   135  		t.Errorf("DownloadFile() is expected to return Cancelled error but returned %q", err)
   136  	}
   137  	<-done
   138  }
   139  
   140  func TestNotFound(t *testing.T) {
   141  	ts := httptest.NewServer(downloadHandler("foobar"))
   142  	defer ts.Close()
   143  	downloader := NewDownloader("http")
   144  	var buf bytes.Buffer
   145  	ep := Endpoint{
   146  		URL: ts.Listener.Addr().String() + "/nosuchimage.qcow2",
   147  	}
   148  	switch err := downloader.DownloadFile(context.Background(), ep, &buf); {
   149  	case err == nil:
   150  		t.Errorf("no error returned for a nonexistent image")
   151  	case !strings.Contains(err.Error(), "Not Found"):
   152  		t.Errorf("bad error message for nonexistent image")
   153  	}
   154  }