github.com/anchore/syft@v1.38.2/internal/file/getter_test.go (about) 1 package file 2 3 import ( 4 "archive/tar" 5 "bytes" 6 "context" 7 "crypto/x509" 8 "fmt" 9 "net" 10 "net/http" 11 "net/http/httptest" 12 "net/url" 13 "path" 14 "testing" 15 16 "github.com/stretchr/testify/assert" 17 18 "github.com/anchore/clio" 19 ) 20 21 func TestGetter_GetFile(t *testing.T) { 22 testCases := []struct { 23 name string 24 prepareClient func(*http.Client) 25 assert assert.ErrorAssertionFunc 26 }{ 27 { 28 name: "client trusts server's CA", 29 assert: assert.NoError, 30 }, 31 { 32 name: "client doesn't trust server's CA", 33 prepareClient: removeTrustedCAs, 34 assert: assertUnknownAuthorityError, 35 }, 36 } 37 38 for _, tc := range testCases { 39 t.Run(tc.name, func(t *testing.T) { 40 requestPath := "/foo" 41 42 server := newTestServer(t, withResponseForPath(t, requestPath, testFileContent)) 43 t.Cleanup(server.Close) 44 45 httpClient := getClient(t, server) 46 if tc.prepareClient != nil { 47 tc.prepareClient(httpClient) 48 } 49 50 getter := NewGetter(testID, httpClient) 51 requestURL := createRequestURL(t, server, requestPath) 52 53 tempDir := t.TempDir() 54 tempFile := path.Join(tempDir, "some-destination-file") 55 56 err := getter.GetFile(tempFile, requestURL) 57 tc.assert(t, err) 58 }) 59 } 60 } 61 62 func TestGetter_GetToDir_FilterNonArchivesWired(t *testing.T) { 63 testCases := []struct { 64 name string 65 source string 66 assert assert.ErrorAssertionFunc 67 }{ 68 { 69 name: "error out on non-archive sources", 70 source: "http://localhost/something.txt", 71 assert: assertErrNonArchiveSource, 72 }, 73 } 74 75 for _, test := range testCases { 76 t.Run(test.name, func(t *testing.T) { 77 test.assert(t, NewGetter(testID, nil).GetToDir(t.TempDir(), test.source)) 78 }) 79 } 80 } 81 82 func TestGetter_validateHttpSource(t *testing.T) { 83 testCases := []struct { 84 name string 85 source string 86 assert assert.ErrorAssertionFunc 87 }{ 88 { 89 name: "error out on non-archive sources", 
90 source: "http://localhost/something.txt", 91 assert: assertErrNonArchiveSource, 92 }, 93 { 94 name: "filter out non-archive sources with get param", 95 source: "https://localhost/vulnerability-db_v3_2021-11-21T08:15:44Z.txt?checksum=sha256%3Ac402d01fa909a3fa85a5c6733ef27a3a51a9105b6c62b9152adbd24c08358911", 96 assert: assertErrNonArchiveSource, 97 }, 98 { 99 name: "ignore non http-https input", 100 source: "s3://bucket/something.txt", 101 assert: assert.NoError, 102 }, 103 } 104 105 for _, test := range testCases { 106 t.Run(test.name, func(t *testing.T) { 107 test.assert(t, validateHTTPSource(test.source)) 108 }) 109 } 110 } 111 112 func TestGetter_GetToDir_CertConcerns(t *testing.T) { 113 testCases := []struct { 114 name string 115 prepareClient func(*http.Client) 116 assert assert.ErrorAssertionFunc 117 }{ 118 119 { 120 name: "client trusts server's CA", 121 assert: assert.NoError, 122 }, 123 { 124 name: "client doesn't trust server's CA", 125 prepareClient: removeTrustedCAs, 126 assert: assertUnknownAuthorityError, 127 }, 128 } 129 130 for _, tc := range testCases { 131 t.Run(tc.name, func(t *testing.T) { 132 requestPath := "/foo.tar" 133 tarball := createTarball("foo", testFileContent) 134 135 server := newTestServer(t, withResponseForPath(t, requestPath, tarball)) 136 t.Cleanup(server.Close) 137 138 httpClient := getClient(t, server) 139 if tc.prepareClient != nil { 140 tc.prepareClient(httpClient) 141 } 142 143 getter := NewGetter(testID, httpClient) 144 requestURL := createRequestURL(t, server, requestPath) 145 146 tempDir := t.TempDir() 147 148 err := getter.GetToDir(tempDir, requestURL) 149 tc.assert(t, err) 150 }) 151 } 152 } 153 154 func assertUnknownAuthorityError(t assert.TestingT, err error, _ ...interface{}) bool { 155 return assert.ErrorAs(t, err, &x509.UnknownAuthorityError{}) 156 } 157 158 func assertErrNonArchiveSource(t assert.TestingT, err error, _ ...interface{}) bool { 159 return assert.ErrorIs(t, err, ErrNonArchiveSource) 160 } 161 162 
func removeTrustedCAs(client *http.Client) { 163 client.Transport.(*http.Transport).TLSClientConfig.RootCAs = x509.NewCertPool() 164 } 165 166 // createTarball makes a single-file tarball and returns it as a byte slice. 167 func createTarball(filename string, content []byte) []byte { 168 tarBuffer := new(bytes.Buffer) 169 tarWriter := tar.NewWriter(tarBuffer) 170 tarWriter.WriteHeader(&tar.Header{ 171 Name: filename, 172 Size: int64(len(content)), 173 Mode: 0600, 174 }) 175 tarWriter.Write(content) 176 tarWriter.Close() 177 178 return tarBuffer.Bytes() 179 } 180 181 type muxOption func(mux *http.ServeMux) 182 183 func withResponseForPath(t *testing.T, path string, response []byte) muxOption { 184 t.Helper() 185 186 return func(mux *http.ServeMux) { 187 mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) { 188 t.Logf("server handling request: %s %s", req.Method, req.URL) 189 190 _, err := w.Write(response) 191 if err != nil { 192 t.Fatal(err) 193 } 194 }) 195 } 196 } 197 198 var testID = clio.Identification{ 199 Name: "test-app", 200 Version: "v0.5.3", 201 } 202 203 func newTestServer(t *testing.T, muxOptions ...muxOption) *httptest.Server { 204 t.Helper() 205 206 mux := http.NewServeMux() 207 for _, option := range muxOptions { 208 option(mux) 209 } 210 211 server := httptest.NewTLSServer(mux) 212 t.Logf("new TLS server listening at %s", getHost(t, server)) 213 214 return server 215 } 216 217 func createRequestURL(t *testing.T, server *httptest.Server, path string) string { 218 t.Helper() 219 220 // TODO: Figure out how to get this value from the server without hardcoding it here 221 const testServerCertificateName = "example.com" 222 223 serverURL, err := url.Parse(server.URL) 224 if err != nil { 225 t.Fatal(err) 226 } 227 228 // Set URL hostname to value from TLS certificate 229 serverURL.Host = fmt.Sprintf("%s:%s", testServerCertificateName, serverURL.Port()) 230 231 serverURL.Path = path 232 233 return serverURL.String() 234 } 235 236 // 
// getClient returns an http.Client that can be used to contact the test TLS server.
//
// The client comes from server.Client() and its transport is modified so that
// every connection is dialed to the test server's own address, regardless of
// the host named in the request URL. It panics if the client's transport is
// not an *http.Transport.
func getClient(t *testing.T, server *httptest.Server) *http.Client {
	t.Helper()

	httpClient := server.Client()
	transport := httpClient.Transport.(*http.Transport)

	serverHost := getHost(t, server)

	transport.DialContext = func(ctx context.Context, _, addr string) (net.Conn, error) {
		t.Logf("client dialing %q for host %q", serverHost, addr)

		// Ensure the client dials our test server, while still honouring the
		// request's context for cancellation and deadlines.
		var dialer net.Dialer
		return dialer.DialContext(ctx, "tcp", serverHost)
	}

	return httpClient
}

// getHost extracts the host value from a server URL string.
// e.g. given a server with URL "http://1.2.3.4:5000/foo", getHost returns "1.2.3.4:5000"
//
// IPv6 literals keep their brackets ("https://[::1]:5000" yields "[::1]:5000")
// so the result can be passed directly to a dialer.
func getHost(t *testing.T, server *httptest.Server) string {
	t.Helper()

	u, err := url.Parse(server.URL)
	if err != nil {
		t.Fatal(err)
	}

	// JoinHostPort re-adds the brackets that Hostname strips from IPv6 literals.
	return net.JoinHostPort(u.Hostname(), u.Port())
}

// testFileContent is the payload served by the test server in these tests.
var testFileContent = []byte("This is the content of a test file!\n")