github.com/anchore/syft@v1.38.2/internal/file/getter_test.go (about)

     1  package file
     2  
     3  import (
     4  	"archive/tar"
     5  	"bytes"
     6  	"context"
     7  	"crypto/x509"
     8  	"fmt"
     9  	"net"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"net/url"
    13  	"path"
    14  	"testing"
    15  
    16  	"github.com/stretchr/testify/assert"
    17  
    18  	"github.com/anchore/clio"
    19  )
    20  
    21  func TestGetter_GetFile(t *testing.T) {
    22  	testCases := []struct {
    23  		name          string
    24  		prepareClient func(*http.Client)
    25  		assert        assert.ErrorAssertionFunc
    26  	}{
    27  		{
    28  			name:   "client trusts server's CA",
    29  			assert: assert.NoError,
    30  		},
    31  		{
    32  			name:          "client doesn't trust server's CA",
    33  			prepareClient: removeTrustedCAs,
    34  			assert:        assertUnknownAuthorityError,
    35  		},
    36  	}
    37  
    38  	for _, tc := range testCases {
    39  		t.Run(tc.name, func(t *testing.T) {
    40  			requestPath := "/foo"
    41  
    42  			server := newTestServer(t, withResponseForPath(t, requestPath, testFileContent))
    43  			t.Cleanup(server.Close)
    44  
    45  			httpClient := getClient(t, server)
    46  			if tc.prepareClient != nil {
    47  				tc.prepareClient(httpClient)
    48  			}
    49  
    50  			getter := NewGetter(testID, httpClient)
    51  			requestURL := createRequestURL(t, server, requestPath)
    52  
    53  			tempDir := t.TempDir()
    54  			tempFile := path.Join(tempDir, "some-destination-file")
    55  
    56  			err := getter.GetFile(tempFile, requestURL)
    57  			tc.assert(t, err)
    58  		})
    59  	}
    60  }
    61  
    62  func TestGetter_GetToDir_FilterNonArchivesWired(t *testing.T) {
    63  	testCases := []struct {
    64  		name   string
    65  		source string
    66  		assert assert.ErrorAssertionFunc
    67  	}{
    68  		{
    69  			name:   "error out on non-archive sources",
    70  			source: "http://localhost/something.txt",
    71  			assert: assertErrNonArchiveSource,
    72  		},
    73  	}
    74  
    75  	for _, test := range testCases {
    76  		t.Run(test.name, func(t *testing.T) {
    77  			test.assert(t, NewGetter(testID, nil).GetToDir(t.TempDir(), test.source))
    78  		})
    79  	}
    80  }
    81  
    82  func TestGetter_validateHttpSource(t *testing.T) {
    83  	testCases := []struct {
    84  		name   string
    85  		source string
    86  		assert assert.ErrorAssertionFunc
    87  	}{
    88  		{
    89  			name:   "error out on non-archive sources",
    90  			source: "http://localhost/something.txt",
    91  			assert: assertErrNonArchiveSource,
    92  		},
    93  		{
    94  			name:   "filter out non-archive sources with get param",
    95  			source: "https://localhost/vulnerability-db_v3_2021-11-21T08:15:44Z.txt?checksum=sha256%3Ac402d01fa909a3fa85a5c6733ef27a3a51a9105b6c62b9152adbd24c08358911",
    96  			assert: assertErrNonArchiveSource,
    97  		},
    98  		{
    99  			name:   "ignore non http-https input",
   100  			source: "s3://bucket/something.txt",
   101  			assert: assert.NoError,
   102  		},
   103  	}
   104  
   105  	for _, test := range testCases {
   106  		t.Run(test.name, func(t *testing.T) {
   107  			test.assert(t, validateHTTPSource(test.source))
   108  		})
   109  	}
   110  }
   111  
   112  func TestGetter_GetToDir_CertConcerns(t *testing.T) {
   113  	testCases := []struct {
   114  		name          string
   115  		prepareClient func(*http.Client)
   116  		assert        assert.ErrorAssertionFunc
   117  	}{
   118  
   119  		{
   120  			name:   "client trusts server's CA",
   121  			assert: assert.NoError,
   122  		},
   123  		{
   124  			name:          "client doesn't trust server's CA",
   125  			prepareClient: removeTrustedCAs,
   126  			assert:        assertUnknownAuthorityError,
   127  		},
   128  	}
   129  
   130  	for _, tc := range testCases {
   131  		t.Run(tc.name, func(t *testing.T) {
   132  			requestPath := "/foo.tar"
   133  			tarball := createTarball("foo", testFileContent)
   134  
   135  			server := newTestServer(t, withResponseForPath(t, requestPath, tarball))
   136  			t.Cleanup(server.Close)
   137  
   138  			httpClient := getClient(t, server)
   139  			if tc.prepareClient != nil {
   140  				tc.prepareClient(httpClient)
   141  			}
   142  
   143  			getter := NewGetter(testID, httpClient)
   144  			requestURL := createRequestURL(t, server, requestPath)
   145  
   146  			tempDir := t.TempDir()
   147  
   148  			err := getter.GetToDir(tempDir, requestURL)
   149  			tc.assert(t, err)
   150  		})
   151  	}
   152  }
   153  
   154  func assertUnknownAuthorityError(t assert.TestingT, err error, _ ...interface{}) bool {
   155  	return assert.ErrorAs(t, err, &x509.UnknownAuthorityError{})
   156  }
   157  
   158  func assertErrNonArchiveSource(t assert.TestingT, err error, _ ...interface{}) bool {
   159  	return assert.ErrorIs(t, err, ErrNonArchiveSource)
   160  }
   161  
   162  func removeTrustedCAs(client *http.Client) {
   163  	client.Transport.(*http.Transport).TLSClientConfig.RootCAs = x509.NewCertPool()
   164  }
   165  
   166  // createTarball makes a single-file tarball and returns it as a byte slice.
   167  func createTarball(filename string, content []byte) []byte {
   168  	tarBuffer := new(bytes.Buffer)
   169  	tarWriter := tar.NewWriter(tarBuffer)
   170  	tarWriter.WriteHeader(&tar.Header{
   171  		Name: filename,
   172  		Size: int64(len(content)),
   173  		Mode: 0600,
   174  	})
   175  	tarWriter.Write(content)
   176  	tarWriter.Close()
   177  
   178  	return tarBuffer.Bytes()
   179  }
   180  
   181  type muxOption func(mux *http.ServeMux)
   182  
   183  func withResponseForPath(t *testing.T, path string, response []byte) muxOption {
   184  	t.Helper()
   185  
   186  	return func(mux *http.ServeMux) {
   187  		mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) {
   188  			t.Logf("server handling request: %s %s", req.Method, req.URL)
   189  
   190  			_, err := w.Write(response)
   191  			if err != nil {
   192  				t.Fatal(err)
   193  			}
   194  		})
   195  	}
   196  }
   197  
// testID is the application identification passed to NewGetter by the tests in this file.
var testID = clio.Identification{
	Name:    "test-app",
	Version: "v0.5.3",
}
   202  
   203  func newTestServer(t *testing.T, muxOptions ...muxOption) *httptest.Server {
   204  	t.Helper()
   205  
   206  	mux := http.NewServeMux()
   207  	for _, option := range muxOptions {
   208  		option(mux)
   209  	}
   210  
   211  	server := httptest.NewTLSServer(mux)
   212  	t.Logf("new TLS server listening at %s", getHost(t, server))
   213  
   214  	return server
   215  }
   216  
   217  func createRequestURL(t *testing.T, server *httptest.Server, path string) string {
   218  	t.Helper()
   219  
   220  	// TODO: Figure out how to get this value from the server without hardcoding it here
   221  	const testServerCertificateName = "example.com"
   222  
   223  	serverURL, err := url.Parse(server.URL)
   224  	if err != nil {
   225  		t.Fatal(err)
   226  	}
   227  
   228  	// Set URL hostname to value from TLS certificate
   229  	serverURL.Host = fmt.Sprintf("%s:%s", testServerCertificateName, serverURL.Port())
   230  
   231  	serverURL.Path = path
   232  
   233  	return serverURL.String()
   234  }
   235  
   236  // getClient returns an http.Client that can be used to contact the test TLS server.
   237  func getClient(t *testing.T, server *httptest.Server) *http.Client {
   238  	t.Helper()
   239  
   240  	httpClient := server.Client()
   241  	transport := httpClient.Transport.(*http.Transport)
   242  
   243  	serverHost := getHost(t, server)
   244  
   245  	transport.DialContext = func(_ context.Context, _, addr string) (net.Conn, error) {
   246  		t.Logf("client dialing %q for host %q", serverHost, addr)
   247  
   248  		// Ensure the client dials our test server
   249  		return net.Dial("tcp", serverHost)
   250  	}
   251  
   252  	return httpClient
   253  }
   254  
   255  // getHost extracts the host value from a server URL string.
   256  // e.g. given a server with URL "http://1.2.3.4:5000/foo", getHost returns "1.2.3.4:5000"
   257  func getHost(t *testing.T, server *httptest.Server) string {
   258  	t.Helper()
   259  
   260  	u, err := url.Parse(server.URL)
   261  	if err != nil {
   262  		t.Fatal(err)
   263  	}
   264  
   265  	return u.Hostname() + ":" + u.Port()
   266  }
   267  
// testFileContent is the payload used by the tests in this file: served as-is
// for single-file downloads and packed into the tarball for directory downloads.
var testFileContent = []byte("This is the content of a test file!\n")