github.com/mirantis/virtlet@v1.5.2-0.20191204181327-1659b8a48e9b/pkg/image/download_test.go (about) 1 /* 2 Copyright 2018 Mirantis 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package image 18 19 import ( 20 "bytes" 21 "context" 22 "crypto/tls" 23 "net/http" 24 "net/http/httptest" 25 "strings" 26 "testing" 27 "time" 28 29 testutils "github.com/Mirantis/virtlet/pkg/utils/testing" 30 ) 31 32 // Note that more of downloader tests are in pkg/imagetranslation/transport_test.go 33 // They examine other aspects like redirects and proxies in conjunction with 34 // image translation handling. 35 36 func downloadHandler(content string) http.HandlerFunc { 37 return func(w http.ResponseWriter, r *http.Request) { 38 switch r.URL.String() { 39 case "/base.qcow2": 40 w.Header().Set("Content-Type", "application/octet-stream") 41 w.Write([]byte(content)) 42 case "/redir.qcow2": 43 http.Redirect(w, r, "/base.qcow2", http.StatusFound) 44 default: 45 http.NotFound(w, r) 46 } 47 } 48 } 49 50 func verifyDownload(t *testing.T, protocol string, content string, ep Endpoint) { 51 downloader := NewDownloader(protocol) 52 var buf bytes.Buffer 53 if err := downloader.DownloadFile(context.Background(), ep, &buf); err != nil { 54 t.Fatalf("DownloadFile(): %v", err) 55 } 56 if buf.String() != content { 57 t.Errorf("bad content: %q instead of %q", buf.String(), content) 58 } 59 } 60 61 func TestDownload(t *testing.T) { 62 ts := httptest.NewServer(downloadHandler("foobar")) 63 defer ts.Close() 64 verifyDownload(t, "http", "foobar", Endpoint{ 65 URL: ts.Listener.Addr().String() + "/base.qcow2", 66 }) 67 } 68 69 func TestDownloadRedirect(t *testing.T) { 70 ts := httptest.NewServer(downloadHandler("foobar")) 71 defer ts.Close() 72 verifyDownload(t, "http", "foobar", Endpoint{ 73 URL: ts.Listener.Addr().String() + "/redir.qcow2", 74 MaxRedirects: -1, 75 }) 76 } 77 78 func TestTLSDownload(t *testing.T) { 79 ca, caKey := testutils.GenerateCert(t, true, "CA", nil, nil) 80 cert, key := testutils.GenerateCert(t, false, "127.0.0.1", ca, caKey) 81 ts := httptest.NewUnstartedServer(downloadHandler("foobar")) 82 ts.TLS = &tls.Config{ 83 Certificates: []tls.Certificate{ 84 { 85 Certificate: [][]byte{cert.Raw}, 86 PrivateKey: key, 87 }, 88 }, 89 } 90 ts.StartTLS() 91 defer ts.Close() 92 verifyDownload(t, "https", "foobar", Endpoint{ 93 URL: ts.Listener.Addr().String() + "/base.qcow2", 94 TLS: &TLSConfig{ 95 Certificates: []TLSCertificate{ 96 {Certificate: ca}, 97 }, 98 }, 99 }) 100 } 101 102 func TestCancelDownload(t *testing.T) { 103 startedWriting := make(chan struct{}) 104 done := make(chan struct{}) 105 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 106 if r.URL.String() != "/base.qcow2" { 107 http.NotFound(w, r) 108 } 109 w.Header().Set("Content-Type", "application/octet-stream") 110 w.Write([]byte("foo")) 111 close(startedWriting) 112 select { 113 case <-time.After(40 * time.Second): 114 t.Errorf("request not cancelled within 40s") 115 case <-r.Context().Done(): 116 } 117 close(done) 118 })) 119 defer ts.Close() 120 downloader := NewDownloader("http") 121 var buf bytes.Buffer 122 ctx, cancel := context.WithCancel(context.Background()) 123 go func() { 124 <-startedWriting 125 cancel() 126 }() 127 128 err := downloader.DownloadFile(ctx, Endpoint{ 129 URL: ts.Listener.Addr().String() + "/base.qcow2", 130 }, &buf) 131 switch { 132 case err == nil: 133 t.Errorf("DownloadFile() didn't return error after being cancelled") 134 case !strings.Contains(err.Error(), "context canceled"): 135 t.Errorf("DownloadFile() is expected to return Cancelled error but returned %q", err) 136 } 137 <-done 138 } 139 140 func TestNotFound(t *testing.T) { 141 ts := httptest.NewServer(downloadHandler("foobar")) 142 defer ts.Close() 143 downloader := NewDownloader("http") 144 var buf bytes.Buffer 145 ep := Endpoint{ 146 URL: ts.Listener.Addr().String() + "/nosuchimage.qcow2", 147 } 148 switch err := downloader.DownloadFile(context.Background(), ep, &buf); { 149 case err == nil: 150 t.Errorf("no error returned for a nonexistent image") 151 case !strings.Contains(err.Error(), "Not Found"): 152 t.Errorf("bad error message for nonexistent image") 153 } 154 }