github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/updater/util/http_test.go (about) 1 // Copyright 2015 Keybase, Inc. All rights reserved. Use of 2 // this source code is governed by the included BSD license. 3 4 package util 5 6 import ( 7 "bytes" 8 "fmt" 9 "net/http" 10 "net/http/httptest" 11 "os" 12 "path/filepath" 13 "runtime" 14 "testing" 15 "time" 16 17 "github.com/stretchr/testify/assert" 18 "github.com/stretchr/testify/require" 19 ) 20 21 func TestDiscardAndCloseBodyNil(t *testing.T) { 22 err := DiscardAndCloseBody(nil) 23 if err == nil { 24 t.Fatal("Should have errored") 25 } 26 } 27 28 func testServer(t *testing.T, data string, delay time.Duration) *httptest.Server { 29 return testServerWithETag(t, data, delay, "") 30 } 31 32 func testServerWithETag(t *testing.T, data string, delay time.Duration, etag string) *httptest.Server { 33 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 34 if delay > 0 { 35 time.Sleep(delay) 36 } 37 38 etagMatch := r.Header.Get("If-None-Match") 39 if etagMatch != "" { 40 t.Logf("Checking etag match: %s == %s", etag, etagMatch) 41 if etag == etagMatch { 42 w.WriteHeader(http.StatusNotModified) 43 return 44 } 45 } 46 47 w.Header().Set("Content-Type", "application/json") 48 fmt.Fprintln(w, data) 49 })) 50 } 51 52 func testServerForError(err error) *httptest.Server { 53 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 54 http.Error(w, err.Error(), 500) 55 })) 56 } 57 58 func TestSaveHTTPResponse(t *testing.T) { 59 data := `{"test": true}` 60 server := testServer(t, data, 0) 61 defer server.Close() 62 resp, err := http.Get(server.URL) 63 assert.NoError(t, err) 64 65 savePath := TempPath("", "TestSaveHTTPResponse.") 66 defer RemoveFileAtPath(savePath) 67 68 err = SaveHTTPResponse(resp, savePath, 0600, testLog) 69 assert.NoError(t, err) 70 71 saved, err := os.ReadFile(savePath) 72 assert.NoError(t, err) 73 74 assert.Equal(t, string(saved), data+"\n") 75 } 76 77 func TestSaveHTTPResponseInvalidPath(t *testing.T) { 78 data := `{"test": true}` 79 server := testServer(t, data, 0) 80 defer server.Close() 81 resp, err := http.Get(server.URL) 82 assert.NoError(t, err) 83 84 savePath := TempPath("", "TestSaveHTTPResponse.") 85 defer RemoveFileAtPath(savePath) 86 87 badPath := "/badpath" 88 if runtime.GOOS == "windows" { 89 badPath = `x:\` // Shouldn't be writable 90 } 91 92 err = SaveHTTPResponse(resp, badPath, 0600, testLog) 93 assert.Error(t, err) 94 err = SaveHTTPResponse(nil, savePath, 0600, testLog) 95 assert.Error(t, err) 96 } 97 98 func TestURLExistsValid(t *testing.T) { 99 server := testServer(t, "ok", 0) 100 defer server.Close() 101 exists, err := URLExists(server.URL, time.Second, testLog) 102 assert.True(t, exists) 103 assert.NoError(t, err) 104 } 105 106 func TestURLExistsInvalid(t *testing.T) { 107 exists, err := URLExists("", time.Second, testLog) 108 assert.Error(t, err) 109 assert.False(t, exists) 110 111 exists, err = URLExists("badurl", time.Second, testLog) 112 assert.Error(t, err) 113 assert.False(t, exists) 114 115 exists, err = URLExists("http://n", time.Second, testLog) 116 assert.Error(t, err) 117 assert.False(t, exists) 118 } 119 120 func TestURLExistsTimeout(t *testing.T) { 121 server := testServer(t, "timeout", time.Second) 122 defer server.Close() 123 exists, err := URLExists(server.URL, time.Millisecond, testLog) 124 t.Logf("Timeout error: %s", err) 125 assert.Error(t, err) 126 assert.False(t, exists) 127 } 128 129 func TestURLExistsFile(t *testing.T) { 130 path, err := WriteTempFile("TestURLExistsFile", []byte(""), 0600) 131 assert.NoError(t, err) 132 exists, err := URLExists(URLStringForPath(path), 0, testLog) 133 assert.NoError(t, err) 134 assert.True(t, exists) 135 136 exists, err = URLExists(URLStringForPath("/invalid"), 0, testLog) 137 assert.NoError(t, err) 138 assert.False(t, exists) 139 } 140 141 func TestDownloadURLValid(t *testing.T) { 142 server := testServer(t, "ok", 0) 143 defer server.Close() 144 destinationPath := TempPath("", "TestDownloadURLValid.") 145 digest, err := Digest(bytes.NewReader([]byte("ok\n"))) 146 assert.NoError(t, err) 147 err = DownloadURL(server.URL, destinationPath, DownloadURLOptions{Digest: digest, RequireDigest: true, Log: testLog}) 148 if assert.NoError(t, err) { 149 // Check file saved and correct data 150 fileExists, fileErr := FileExists(destinationPath) 151 assert.NoError(t, fileErr) 152 assert.True(t, fileExists) 153 data, readErr := os.ReadFile(destinationPath) 154 assert.NoError(t, readErr) 155 assert.Equal(t, []byte("ok\n"), data) 156 } 157 158 // Repeat test, download again, overwriting destination 159 server2 := testServer(t, "ok2", 0) 160 defer server2.Close() 161 digest2, err := Digest(bytes.NewReader([]byte("ok2\n"))) 162 assert.NoError(t, err) 163 err = DownloadURL(server2.URL, destinationPath, DownloadURLOptions{Digest: digest2, RequireDigest: true, Log: testLog}) 164 if assert.NoError(t, err) { 165 fileExists2, err := FileExists(destinationPath) 166 assert.NoError(t, err) 167 assert.True(t, fileExists2) 168 data2, err := os.ReadFile(destinationPath) 169 assert.NoError(t, err) 170 assert.Equal(t, []byte("ok2\n"), data2) 171 } 172 } 173 174 func TestDownloadURLInvalid(t *testing.T) { 175 destinationPath := TempPath("", "TestDownloadURLInvalid.") 176 177 err := DownloadURL("", destinationPath, DownloadURLOptions{Log: testLog}) 178 assert.Error(t, err) 179 180 err = DownloadURL("badurl", destinationPath, DownloadURLOptions{Log: testLog}) 181 assert.Error(t, err) 182 183 err = DownloadURL("http://", destinationPath, DownloadURLOptions{Log: testLog}) 184 assert.Error(t, err) 185 } 186 187 func TestDownloadURLTimeout(t *testing.T) { 188 server := testServer(t, "timeout", time.Second) 189 defer server.Close() 190 destinationPath := TempPath("", "TestDownloadURLInvalid.") 191 err := DownloadURL(server.URL, destinationPath, DownloadURLOptions{Timeout: time.Millisecond, Log: testLog}) 192 t.Logf("Timeout error: %s", err) 193 assert.Error(t, err) 194 } 195 196 func TestDownloadURLParseError(t *testing.T) { 197 err := DownloadURL("invalid", "", DownloadURLOptions{Log: testLog}) 198 assert.Error(t, err) 199 } 200 201 func TestDownloadURLError(t *testing.T) { 202 server := testServerForError(fmt.Errorf("Test error")) 203 defer server.Close() 204 205 err := DownloadURL(server.URL, "", DownloadURLOptions{Log: testLog}) 206 assert.EqualError(t, err, "Responded with 500 Internal Server Error") 207 } 208 209 func TestDownloadURLLocal(t *testing.T) { 210 _, filename, _, _ := runtime.Caller(0) 211 testZipPath := filepath.Join(filepath.Dir(filename), "../test/test.zip") 212 destinationPath := TempPath("", "TestDownloadURLLocal.") 213 defer RemoveFileAtPath(destinationPath) 214 err := DownloadURL(URLStringForPath(testZipPath), destinationPath, DownloadURLOptions{Log: testLog}) 215 assert.NoError(t, err) 216 217 exists, err := FileExists(destinationPath) 218 assert.NoError(t, err) 219 assert.True(t, exists) 220 } 221 222 func TestDownloadURLETag(t *testing.T) { 223 data := []byte("ok\n") 224 etag := "eff5bc1ef8ec9d03e640fc4370f5eacd" 225 server := testServerWithETag(t, "ok", 0, etag) 226 defer server.Close() 227 destinationPath := TempPath("", "TestDownloadURLETag.") 228 err := os.WriteFile(destinationPath, data, 0600) 229 require.NoError(t, err) 230 digest, err := Digest(bytes.NewReader(data)) 231 assert.NoError(t, err) 232 cached, err := downloadURL(server.URL, destinationPath, DownloadURLOptions{Digest: digest, RequireDigest: true, UseETag: true, Log: testLog}) 233 require.NoError(t, err) 234 assert.True(t, cached) 235 } 236 237 func TestURLExistsParseError(t *testing.T) { 238 exists, err := URLExists("invalid", time.Millisecond, testLog) 239 assert.False(t, exists) 240 assert.Error(t, err) 241 } 242 243 func TestURLExistsError(t *testing.T) { 244 server := testServerForError(fmt.Errorf("Test error")) 245 defer server.Close() 246 247 exists, err := URLExists(server.URL, time.Second, testLog) 248 assert.False(t, exists) 249 assert.EqualError(t, err, "Invalid status code (500)") 250 } 251 252 func TestURLValueForBool(t *testing.T) { 253 assert.Equal(t, "0", URLValueForBool(false)) 254 assert.Equal(t, "1", URLValueForBool(true)) 255 }