github.com/coreos/rocket@v1.30.1-0.20200224141603-171c416fac02/rkt/fetch_test.go (about)

     1  // Copyright 2014 The rkt Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
    17  import (
    18  	"archive/tar"
    19  	"bytes"
    20  	"encoding/base64"
    21  	"fmt"
    22  	"io"
    23  	"io/ioutil"
    24  	"net/http"
    25  	"net/http/httptest"
    26  	"net/url"
    27  	"os"
    28  	"path/filepath"
    29  	"strings"
    30  	"testing"
    31  
    32  	"github.com/rkt/rkt/pkg/aci"
    33  	"github.com/rkt/rkt/pkg/aci/acitest"
    34  	dist "github.com/rkt/rkt/pkg/distribution"
    35  	"github.com/rkt/rkt/pkg/keystore"
    36  	"github.com/rkt/rkt/pkg/keystore/keystoretest"
    37  	"github.com/rkt/rkt/rkt/config"
    38  	rktflag "github.com/rkt/rkt/rkt/flag"
    39  	"github.com/rkt/rkt/rkt/image"
    40  	"github.com/rkt/rkt/store/imagestore"
    41  )
    42  
    43  type httpError struct {
    44  	code    int
    45  	message string
    46  }
    47  
    48  func (e *httpError) Error() string {
    49  	return fmt.Sprintf("%d: %s", e.code, e.message)
    50  }
    51  
// serverHandler is the http.Handler behind the test servers started by
// TestDownloading. It serves a fixed body after applying the authentication
// scheme selected by auth.
type serverHandler struct {
	// body is written verbatim as the response once the request is accepted.
	body []byte
	// t is the test that owns the handler; ServeHTTP does not read it.
	t    *testing.T
	// auth selects the scheme enforced by ServeHTTP: "none", "deny", "basic"
	// or "bearer". Any other value makes ServeHTTP panic.
	auth string
}
    57  
    58  func getSecFlags(defOpts string) *rktflag.SecFlags {
    59  	sf, err := rktflag.NewSecFlags(defOpts)
    60  	if err != nil {
    61  		panic(fmt.Sprintf("fetch-test: problem initializing flags: %v", err))
    62  	}
    63  
    64  	return sf
    65  }
    66  
// Security flag fixtures shared by the fetch tests.
//
// insecureFlags is built from "image,tls" and is used by TestDownloading,
// whose servers are httptest TLS servers (presumably this disables image
// signature and TLS certificate verification; confirm against rktflag).
// secureFlags is built from "none" and is used by the tests that fetch a
// signed image together with its detached signature.
var (
	insecureFlags = getSecFlags("image,tls")
	secureFlags   = getSecFlags("none")
)
    71  
    72  func (h *serverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    73  	switch h.auth {
    74  	case "deny":
    75  		if _, ok := r.Header[http.CanonicalHeaderKey("Authorization")]; ok {
    76  			w.WriteHeader(http.StatusBadRequest)
    77  			return
    78  		}
    79  	case "none":
    80  		// no auth to do.
    81  	case "basic":
    82  		payload, httpErr := getAuthPayload(r, "Basic")
    83  		if httpErr != nil {
    84  			w.WriteHeader(httpErr.code)
    85  			return
    86  		}
    87  		creds, err := base64.StdEncoding.DecodeString(string(payload))
    88  		if err != nil {
    89  			w.WriteHeader(http.StatusBadRequest)
    90  			return
    91  		}
    92  		parts := strings.Split(string(creds), ":")
    93  		if len(parts) != 2 {
    94  			w.WriteHeader(http.StatusBadRequest)
    95  			return
    96  		}
    97  		user := parts[0]
    98  		password := parts[1]
    99  		if user != "bar" || password != "baz" {
   100  			w.WriteHeader(http.StatusUnauthorized)
   101  			return
   102  		}
   103  	case "bearer":
   104  		payload, httpErr := getAuthPayload(r, "Bearer")
   105  		if httpErr != nil {
   106  			w.WriteHeader(httpErr.code)
   107  			return
   108  		}
   109  		if payload != "sometoken" {
   110  			w.WriteHeader(http.StatusUnauthorized)
   111  			return
   112  		}
   113  	default:
   114  		panic("bug in test")
   115  	}
   116  	w.Write(h.body)
   117  }
   118  
   119  func getAuthPayload(r *http.Request, authType string) (string, *httpError) {
   120  	auth := r.Header.Get("Authorization")
   121  	if auth == "" {
   122  		err := &httpError{
   123  			code:    http.StatusUnauthorized,
   124  			message: "No auth",
   125  		}
   126  		return "", err
   127  	}
   128  	parts := strings.Split(auth, " ")
   129  	if len(parts) != 2 {
   130  		err := &httpError{
   131  			code:    http.StatusBadRequest,
   132  			message: "Malformed auth",
   133  		}
   134  		return "", err
   135  	}
   136  	if parts[0] != authType {
   137  		err := &httpError{
   138  			code:    http.StatusUnauthorized,
   139  			message: "Wrong auth",
   140  		}
   141  		return "", err
   142  	}
   143  	return parts[1], nil
   144  }
   145  
   146  type testHeaderer struct {
   147  	h http.Header
   148  }
   149  
   150  func (h *testHeaderer) GetHeader() http.Header {
   151  	return h.h
   152  }
   153  
   154  func (h *testHeaderer) SignRequest(r *http.Request) *http.Request {
   155  	r.Header.Set("Authorization", h.GetHeader().Get("Authorization"))
   156  	return r
   157  }
   158  
   159  func TestDownloading(t *testing.T) {
   160  	dir, err := ioutil.TempDir("", "download-image")
   161  	if err != nil {
   162  		t.Fatalf("error creating tempdir: %v", err)
   163  	}
   164  	defer os.RemoveAll(dir)
   165  
   166  	imj, err := acitest.ImageManifestString(nil)
   167  	if err != nil {
   168  		t.Fatalf("unexpected error: %v", err)
   169  	}
   170  
   171  	entries := []*aci.ACIEntry{
   172  		// An empty file
   173  		{
   174  			Contents: "hello",
   175  			Header: &tar.Header{
   176  				Name: "rootfs/file01.txt",
   177  				Size: 5,
   178  			},
   179  		},
   180  	}
   181  
   182  	aci, err := aci.NewACI(dir, imj, entries)
   183  	if err != nil {
   184  		t.Fatalf("error creating test tar: %v", err)
   185  	}
   186  
   187  	// Rewind the ACI
   188  	if _, err := aci.Seek(0, 0); err != nil {
   189  		t.Fatalf("unexpected error: %v", err)
   190  	}
   191  	body, err := ioutil.ReadAll(aci)
   192  	if err != nil {
   193  		t.Fatalf("unexpected error: %v", err)
   194  	}
   195  	noauthServer := &serverHandler{
   196  		body: body,
   197  		t:    t,
   198  		auth: "none",
   199  	}
   200  	basicServer := &serverHandler{
   201  		body: body,
   202  		t:    t,
   203  		auth: "basic",
   204  	}
   205  	oauthServer := &serverHandler{
   206  		body: body,
   207  		t:    t,
   208  		auth: "bearer",
   209  	}
   210  	denyServer := &serverHandler{
   211  		body: body,
   212  		t:    t,
   213  		auth: "deny",
   214  	}
   215  	noAuthTS := httptest.NewTLSServer(noauthServer)
   216  	defer noAuthTS.Close()
   217  	basicTS := httptest.NewTLSServer(basicServer)
   218  	defer basicTS.Close()
   219  	oauthTS := httptest.NewTLSServer(oauthServer)
   220  	defer oauthTS.Close()
   221  	denyAuthTS := httptest.NewServer(denyServer)
   222  	noAuth := http.Header{}
   223  	// YmFyOmJheg== is base64(bar:baz)
   224  	basicAuth := http.Header{"Authorization": {"Basic YmFyOmJheg=="}}
   225  	bearerAuth := http.Header{"Authorization": {"Bearer sometoken"}}
   226  	urlToName := map[string]string{
   227  		noAuthTS.URL:   "no auth",
   228  		basicTS.URL:    "basic",
   229  		oauthTS.URL:    "oauth",
   230  		denyAuthTS.URL: "deny auth",
   231  	}
   232  	tests := []struct {
   233  		aciURL       string
   234  		remoteExists bool
   235  		options      http.Header
   236  		authFail     bool
   237  	}{
   238  		{noAuthTS.URL, false, noAuth, false},
   239  		{noAuthTS.URL, true, noAuth, false},
   240  		{noAuthTS.URL, true, bearerAuth, false},
   241  		{noAuthTS.URL, true, basicAuth, false},
   242  
   243  		{basicTS.URL, false, noAuth, true},
   244  		{basicTS.URL, false, bearerAuth, true},
   245  		{basicTS.URL, false, basicAuth, false},
   246  
   247  		{oauthTS.URL, false, noAuth, true},
   248  		{oauthTS.URL, false, basicAuth, true},
   249  		{oauthTS.URL, false, bearerAuth, false},
   250  
   251  		{denyAuthTS.URL, false, basicAuth, false},
   252  		{denyAuthTS.URL, true, bearerAuth, false},
   253  		{denyAuthTS.URL, true, noAuth, false},
   254  	}
   255  
   256  	s, err := imagestore.NewStore(dir)
   257  	if err != nil {
   258  		t.Fatalf("unexpected error %v", err)
   259  	}
   260  
   261  	for _, tt := range tests {
   262  		_, err := s.GetRemote(tt.aciURL)
   263  		if err != nil {
   264  			if err != imagestore.ErrRemoteNotFound {
   265  				t.Fatalf("unexpected err: %v", err)
   266  			}
   267  
   268  			if tt.remoteExists {
   269  				t.Fatalf("should've found the remote, got %v", err)
   270  			}
   271  		} else if !tt.remoteExists {
   272  			t.Fatalf("should've gotten a remote not found error")
   273  		}
   274  
   275  		parsed, err := url.Parse(tt.aciURL)
   276  		if err != nil {
   277  			panic(fmt.Sprintf("Invalid url from test server: %s", tt.aciURL))
   278  		}
   279  		headers := map[string]config.Headerer{
   280  			parsed.Host: &testHeaderer{tt.options},
   281  		}
   282  		ft := &image.Fetcher{
   283  			S:             s,
   284  			Headers:       headers,
   285  			InsecureFlags: insecureFlags,
   286  		}
   287  		u, err := url.Parse(tt.aciURL)
   288  		if err != nil {
   289  			t.Fatalf("unexpected error %v", err)
   290  		}
   291  		d, err := dist.NewACIArchiveFromTransportURL(u)
   292  		if err != nil {
   293  			t.Fatalf("unexpected error %v", err)
   294  		}
   295  		_, err = ft.FetchImage(d, tt.aciURL, "")
   296  		if err != nil && !tt.authFail {
   297  			t.Fatalf("expected download to succeed, it failed: %v (server: %q, headers: `%v`)", err, urlToName[tt.aciURL], tt.options)
   298  		}
   299  		if err == nil && tt.authFail {
   300  			t.Fatalf("expected download to fail, it succeeded (server: %q, headers: `%v`)", urlToName[tt.aciURL], tt.options)
   301  		}
   302  	}
   303  
   304  	s.Dump(false)
   305  }
   306  
   307  func TestFetchImage(t *testing.T) {
   308  	dir, err := ioutil.TempDir("", "fetch-image")
   309  	if err != nil {
   310  		t.Fatalf("error creating tempdir: %v", err)
   311  	}
   312  	defer os.RemoveAll(dir)
   313  	s, err := imagestore.NewStore(dir)
   314  	if err != nil {
   315  		t.Fatalf("unexpected error %v", err)
   316  	}
   317  	defer s.Dump(false)
   318  
   319  	ks, ksPath, err := keystore.NewTestKeystore()
   320  	if err != nil {
   321  		t.Errorf("unexpected error %v", err)
   322  	}
   323  	defer os.RemoveAll(ksPath)
   324  
   325  	key := keystoretest.KeyMap["example.com/app"]
   326  	if _, err := ks.StoreTrustedKeyPrefix("example.com/app", bytes.NewBufferString(key.ArmoredPublicKey)); err != nil {
   327  		t.Fatalf("unexpected error %v", err)
   328  	}
   329  	a, err := aci.NewBasicACI(dir, "example.com/app")
   330  	defer a.Close()
   331  	if err != nil {
   332  		t.Fatalf("unexpected error %v", err)
   333  	}
   334  
   335  	// Rewind the ACI
   336  	if _, err := a.Seek(0, 0); err != nil {
   337  		t.Fatalf("unexpected error %v", err)
   338  	}
   339  
   340  	asc, err := aci.NewDetachedSignature(key.ArmoredPrivateKey, a)
   341  	if err != nil {
   342  		t.Fatalf("unexpected error %v", err)
   343  	}
   344  
   345  	// Rewind the ACI.
   346  	if _, err := a.Seek(0, 0); err != nil {
   347  		t.Fatalf("unexpected error %v", err)
   348  	}
   349  
   350  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   351  		switch filepath.Ext(r.URL.Path) {
   352  		case ".aci":
   353  			io.Copy(w, a)
   354  			return
   355  		case ".asc":
   356  			io.Copy(w, asc)
   357  			return
   358  		default:
   359  			t.Fatalf("unknown extension %v", r.URL.Path)
   360  		}
   361  	}))
   362  	defer ts.Close()
   363  	ft := &image.Fetcher{
   364  		S:             s,
   365  		Ks:            ks,
   366  		InsecureFlags: secureFlags,
   367  	}
   368  
   369  	u, err := url.Parse(fmt.Sprintf("%s/app.aci", ts.URL))
   370  	if err != nil {
   371  		t.Fatalf("unexpected error %v", err)
   372  	}
   373  	d, err := dist.NewACIArchiveFromTransportURL(u)
   374  	if err != nil {
   375  		t.Fatalf("unexpected error %v", err)
   376  	}
   377  	_, err = ft.FetchImage(d, u.String(), "")
   378  	if err != nil {
   379  		t.Fatalf("unexpected error: %v", err)
   380  	}
   381  }
   382  
   383  func TestGetStoreKeyFromApp(t *testing.T) {
   384  	dir, err := ioutil.TempDir("", "fetch-image")
   385  	if err != nil {
   386  		t.Fatalf("error creating tempdir: %v", err)
   387  	}
   388  	defer os.RemoveAll(dir)
   389  	s, err := imagestore.NewStore(dir)
   390  	if err != nil {
   391  		t.Fatalf("unexpected error %v", err)
   392  	}
   393  	defer s.Dump(false)
   394  
   395  	// Test an aci without os and arch labels
   396  	a, err := aci.NewBasicACI(dir, "example.com/app")
   397  	defer a.Close()
   398  	if err != nil {
   399  		t.Fatalf("unexpected error %v", err)
   400  	}
   401  	// Rewind the ACI
   402  	if _, err := a.Seek(0, 0); err != nil {
   403  		t.Fatalf("unexpected error %v", err)
   404  	}
   405  	_, err = s.WriteACI(a, imagestore.ACIFetchInfo{Latest: false})
   406  	if err != nil {
   407  		t.Fatalf("unexpected error %v", err)
   408  	}
   409  
   410  	_, err = getStoreKeyFromApp(s, "example.com/app")
   411  	if err != nil {
   412  		t.Fatalf("unexpected error: %v", err)
   413  	}
   414  }
   415  
   416  type redirectingServerHandler struct {
   417  	destServer string
   418  }
   419  
   420  func (h *redirectingServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   421  	w.Header().Set("Location", fmt.Sprintf("%s/%s", h.destServer, r.URL.Path))
   422  	w.WriteHeader(http.StatusTemporaryRedirect)
   423  }
   424  
   425  type cachingServerHandler struct {
   426  	aciBody []byte
   427  	ascBody []byte
   428  	etag    string
   429  	maxAge  int
   430  	t       *testing.T
   431  }
   432  
   433  func (h *cachingServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   434  	switch filepath.Ext(r.URL.Path) {
   435  	case ".aci":
   436  		if h.maxAge > 0 {
   437  			w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", h.maxAge))
   438  		}
   439  		if h.etag != "" {
   440  			w.Header().Set("ETag", h.etag)
   441  			if cc := r.Header.Get("If-None-Match"); cc == h.etag {
   442  				w.WriteHeader(http.StatusNotModified)
   443  				return
   444  			}
   445  		}
   446  		w.Write(h.aciBody)
   447  		return
   448  	case ".asc":
   449  		w.Write(h.ascBody)
   450  		return
   451  	}
   452  }
   453  
   454  func TestFetchImageCache(t *testing.T) {
   455  	dir, err := ioutil.TempDir("", "fetch-image-cache")
   456  	if err != nil {
   457  		t.Fatalf("error creating tempdir: %v", err)
   458  	}
   459  	defer os.RemoveAll(dir)
   460  	s, err := imagestore.NewStore(dir)
   461  	if err != nil {
   462  		t.Fatalf("unexpected error %v", err)
   463  	}
   464  	defer s.Dump(false)
   465  
   466  	ks, ksPath, err := keystore.NewTestKeystore()
   467  	if err != nil {
   468  		t.Errorf("unexpected error %v", err)
   469  	}
   470  	defer os.RemoveAll(ksPath)
   471  
   472  	key := keystoretest.KeyMap["example.com/app"]
   473  	if _, err := ks.StoreTrustedKeyPrefix("example.com/app", bytes.NewBufferString(key.ArmoredPublicKey)); err != nil {
   474  		t.Fatalf("unexpected error %v", err)
   475  	}
   476  	a, err := aci.NewBasicACI(dir, "example.com/app")
   477  	defer a.Close()
   478  	if err != nil {
   479  		t.Fatalf("unexpected error %v", err)
   480  	}
   481  	// Rewind the ACI
   482  	if _, err := a.Seek(0, 0); err != nil {
   483  		t.Fatalf("unexpected error %v", err)
   484  	}
   485  	asc, err := aci.NewDetachedSignature(key.ArmoredPrivateKey, a)
   486  	if err != nil {
   487  		t.Fatalf("unexpected error %v", err)
   488  	}
   489  	// Rewind the ACI
   490  	if _, err := a.Seek(0, 0); err != nil {
   491  		t.Fatalf("unexpected error %v", err)
   492  	}
   493  	aciBody, err := ioutil.ReadAll(a)
   494  	if err != nil {
   495  		t.Fatalf("unexpected error: %v", err)
   496  	}
   497  	ascBody, err := ioutil.ReadAll(asc)
   498  	if err != nil {
   499  		t.Fatalf("unexpected error: %v", err)
   500  	}
   501  
   502  	nocacheServer := &cachingServerHandler{
   503  		aciBody: aciBody,
   504  		ascBody: ascBody,
   505  		etag:    "",
   506  		maxAge:  0,
   507  		t:       t,
   508  	}
   509  	etagServer := &cachingServerHandler{
   510  		aciBody: aciBody,
   511  		ascBody: ascBody,
   512  		etag:    "123456789",
   513  		maxAge:  0,
   514  		t:       t,
   515  	}
   516  	maxAgeServer := &cachingServerHandler{
   517  		aciBody: aciBody,
   518  		ascBody: ascBody,
   519  		etag:    "",
   520  		maxAge:  10,
   521  		t:       t,
   522  	}
   523  	etagMaxAgeServer := &cachingServerHandler{
   524  		aciBody: aciBody,
   525  		ascBody: ascBody,
   526  		etag:    "123456789",
   527  		maxAge:  10,
   528  		t:       t,
   529  	}
   530  
   531  	nocacheTS := httptest.NewServer(nocacheServer)
   532  	defer nocacheTS.Close()
   533  	etagTS := httptest.NewServer(etagServer)
   534  	defer etagTS.Close()
   535  	maxAgeTS := httptest.NewServer(maxAgeServer)
   536  	defer maxAgeTS.Close()
   537  	etagMaxAgeTS := httptest.NewServer(etagMaxAgeServer)
   538  	defer etagMaxAgeTS.Close()
   539  
   540  	type testData struct {
   541  		URL             string
   542  		etag            string
   543  		cacheMaxAge     int
   544  		shouldUseCached bool
   545  	}
   546  	tests := []testData{
   547  		{nocacheTS.URL, "", 0, false},
   548  		{etagTS.URL, "123456789", 0, true},
   549  		{maxAgeTS.URL, "", 10, true},
   550  		{etagMaxAgeTS.URL, "123456789", 10, true},
   551  	}
   552  	testFn := func(tt testData, useRedirect bool) {
   553  		aciURL := fmt.Sprintf("%s/app.aci", tt.URL)
   554  		if useRedirect {
   555  			redirectingTS := httptest.NewServer(&redirectingServerHandler{destServer: tt.URL})
   556  			defer redirectingTS.Close()
   557  			aciURL = fmt.Sprintf("%s/app.aci", redirectingTS.URL)
   558  		}
   559  		ft := &image.Fetcher{
   560  			S:             s,
   561  			Ks:            ks,
   562  			InsecureFlags: secureFlags,
   563  			// Skip local store
   564  			PullPolicy: image.PullPolicyUpdate,
   565  		}
   566  		u, err := url.Parse(aciURL)
   567  		if err != nil {
   568  			t.Fatalf("unexpected error %v", err)
   569  		}
   570  		d, err := dist.NewACIArchiveFromTransportURL(u)
   571  		if err != nil {
   572  			t.Fatalf("unexpected error %v", err)
   573  		}
   574  		_, err = ft.FetchImage(d, u.String(), "")
   575  		if err != nil {
   576  			t.Fatalf("unexpected error: %v", err)
   577  		}
   578  		rem, err := s.GetRemote(aciURL)
   579  		if err != nil {
   580  			t.Fatalf("Error getting remote info: %v\n", err)
   581  		}
   582  		if rem.ETag != tt.etag {
   583  			t.Errorf("expected remote to have a ETag header argument")
   584  		}
   585  		if rem.CacheMaxAge != tt.cacheMaxAge {
   586  			t.Errorf("expected max-age header argument to be %q", tt.cacheMaxAge)
   587  		}
   588  
   589  		downloadTime := rem.DownloadTime
   590  		_, err = ft.FetchImage(d, u.String(), "")
   591  		if err != nil {
   592  			t.Fatalf("unexpected error: %v", err)
   593  		}
   594  		rem, err = s.GetRemote(aciURL)
   595  		if err != nil {
   596  			t.Fatalf("Error getting remote info: %v\n", err)
   597  		}
   598  		if rem.ETag != tt.etag {
   599  			t.Errorf("expected remote to have a ETag header argument")
   600  		}
   601  		if rem.CacheMaxAge != tt.cacheMaxAge {
   602  			t.Errorf("expected max-age header argument to be %q", tt.cacheMaxAge)
   603  		}
   604  		if tt.shouldUseCached {
   605  			if downloadTime != rem.DownloadTime {
   606  				t.Errorf("expected current download time to be the same as the previous one (no download) but they differ")
   607  			}
   608  		} else {
   609  			if downloadTime == rem.DownloadTime {
   610  				t.Errorf("expected current download time to be different from the previous one (new image download) but they are the same")
   611  			}
   612  		}
   613  
   614  		if err := s.RemoveACI(rem.BlobKey); err != nil {
   615  			t.Fatalf("unexpected error: %v", err)
   616  		}
   617  	}
   618  
   619  	// repeat the tests with and without a redirecting server
   620  	for i := 0; i <= 1; i++ {
   621  		useRedirect := false
   622  		if i == 1 {
   623  			useRedirect = true
   624  		}
   625  		for _, tt := range tests {
   626  			testFn(tt, useRedirect)
   627  		}
   628  	}
   629  }