github.com/demonoid81/containerd@v1.3.4/remotes/docker/resolver_test.go (about)

     1  /*
     2     Copyright The containerd Authors.
     3  
     4     Licensed under the Apache License, Version 2.0 (the "License");
     5     you may not use this file except in compliance with the License.
     6     You may obtain a copy of the License at
     7  
     8         http://www.apache.org/licenses/LICENSE-2.0
     9  
    10     Unless required by applicable law or agreed to in writing, software
    11     distributed under the License is distributed on an "AS IS" BASIS,
    12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13     See the License for the specific language governing permissions and
    14     limitations under the License.
    15  */
    16  
    17  package docker
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"encoding/json"
    24  	"fmt"
    25  	"io"
    26  	"io/ioutil"
    27  	"net/http"
    28  	"net/http/httptest"
    29  	"strconv"
    30  	"strings"
    31  	"testing"
    32  	"time"
    33  
    34  	"github.com/containerd/containerd/remotes"
    35  	digest "github.com/opencontainers/go-digest"
    36  	specs "github.com/opencontainers/image-spec/specs-go"
    37  	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
    38  	"github.com/pkg/errors"
    39  )
    40  
    41  func TestHTTPResolver(t *testing.T) {
    42  	s := func(h http.Handler) (string, ResolverOptions, func()) {
    43  		s := httptest.NewServer(h)
    44  
    45  		options := ResolverOptions{}
    46  		base := s.URL[7:] // strip "http://"
    47  		return base, options, s.Close
    48  	}
    49  
    50  	runBasicTest(t, "testname", s)
    51  }
    52  
    53  func TestHTTPSResolver(t *testing.T) {
    54  	runBasicTest(t, "testname", tlsServer)
    55  }
    56  
    57  func TestBasicResolver(t *testing.T) {
    58  	basicAuth := func(h http.Handler) (string, ResolverOptions, func()) {
    59  		// Wrap with basic auth
    60  		wrapped := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
    61  			username, password, ok := r.BasicAuth()
    62  			if !ok || username != "user1" || password != "password1" {
    63  				rw.Header().Set("WWW-Authenticate", "Basic realm=localhost")
    64  				rw.WriteHeader(http.StatusUnauthorized)
    65  				return
    66  			}
    67  			h.ServeHTTP(rw, r)
    68  		})
    69  
    70  		base, options, close := tlsServer(wrapped)
    71  		options.Hosts = ConfigureDefaultRegistries(
    72  			WithClient(options.Client),
    73  			WithAuthorizer(NewAuthorizer(options.Client, func(string) (string, string, error) {
    74  				return "user1", "password1", nil
    75  			})),
    76  		)
    77  		return base, options, close
    78  	}
    79  	runBasicTest(t, "testname", basicAuth)
    80  }
    81  
    82  func TestAnonymousTokenResolver(t *testing.T) {
    83  	th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
    84  		if r.Method != http.MethodGet {
    85  			rw.WriteHeader(http.StatusMethodNotAllowed)
    86  			return
    87  		}
    88  		rw.Header().Set("Content-Type", "application/json")
    89  		rw.WriteHeader(http.StatusOK)
    90  		rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`))
    91  	})
    92  
    93  	runBasicTest(t, "testname", withTokenServer(th, nil))
    94  }
    95  
    96  func TestBasicAuthTokenResolver(t *testing.T) {
    97  	th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
    98  		if r.Method != http.MethodGet {
    99  			rw.WriteHeader(http.StatusMethodNotAllowed)
   100  			return
   101  		}
   102  		rw.Header().Set("Content-Type", "application/json")
   103  		rw.WriteHeader(http.StatusOK)
   104  		username, password, ok := r.BasicAuth()
   105  		if !ok || username != "user1" || password != "password1" {
   106  			rw.Write([]byte(`{"access_token":"insufficientscope"}`))
   107  		} else {
   108  			rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`))
   109  		}
   110  	})
   111  	creds := func(string) (string, string, error) {
   112  		return "user1", "password1", nil
   113  	}
   114  
   115  	runBasicTest(t, "testname", withTokenServer(th, creds))
   116  }
   117  
   118  func TestRefreshTokenResolver(t *testing.T) {
   119  	th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   120  		if r.Method != http.MethodPost {
   121  			rw.WriteHeader(http.StatusMethodNotAllowed)
   122  			return
   123  		}
   124  		rw.Header().Set("Content-Type", "application/json")
   125  		rw.WriteHeader(http.StatusOK)
   126  
   127  		r.ParseForm()
   128  		if r.PostForm.Get("grant_type") != "refresh_token" || r.PostForm.Get("refresh_token") != "somerefreshtoken" {
   129  			rw.Write([]byte(`{"access_token":"insufficientscope"}`))
   130  		} else {
   131  			rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`))
   132  		}
   133  	})
   134  	creds := func(string) (string, string, error) {
   135  		return "", "somerefreshtoken", nil
   136  	}
   137  
   138  	runBasicTest(t, "testname", withTokenServer(th, creds))
   139  }
   140  
   141  func TestPostBasicAuthTokenResolver(t *testing.T) {
   142  	th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   143  		if r.Method != http.MethodPost {
   144  			rw.WriteHeader(http.StatusMethodNotAllowed)
   145  			return
   146  		}
   147  		rw.Header().Set("Content-Type", "application/json")
   148  		rw.WriteHeader(http.StatusOK)
   149  
   150  		r.ParseForm()
   151  		if r.PostForm.Get("grant_type") != "password" || r.PostForm.Get("username") != "user1" || r.PostForm.Get("password") != "password1" {
   152  			rw.Write([]byte(`{"access_token":"insufficientscope"}`))
   153  		} else {
   154  			rw.Write([]byte(`{"access_token":"perfectlyvalidopaquetoken"}`))
   155  		}
   156  	})
   157  	creds := func(string) (string, string, error) {
   158  		return "user1", "password1", nil
   159  	}
   160  
   161  	runBasicTest(t, "testname", withTokenServer(th, creds))
   162  }
   163  
   164  func TestBadTokenResolver(t *testing.T) {
   165  	th := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   166  		if r.Method != http.MethodPost {
   167  			rw.WriteHeader(http.StatusMethodNotAllowed)
   168  			return
   169  		}
   170  		rw.Header().Set("Content-Type", "application/json")
   171  		rw.WriteHeader(http.StatusOK)
   172  		rw.Write([]byte(`{"access_token":"insufficientscope"}`))
   173  	})
   174  	creds := func(string) (string, string, error) {
   175  		return "", "somerefreshtoken", nil
   176  	}
   177  
   178  	ctx := context.Background()
   179  	h := newContent(ocispec.MediaTypeImageManifest, []byte("not anything parse-able"))
   180  
   181  	base, ro, close := withTokenServer(th, creds)(logHandler{t, h})
   182  	defer close()
   183  
   184  	resolver := NewResolver(ro)
   185  	image := fmt.Sprintf("%s/doesntmatter:sometatg", base)
   186  
   187  	_, _, err := resolver.Resolve(ctx, image)
   188  	if err == nil {
   189  		t.Fatal("Expected error getting token with inssufficient scope")
   190  	}
   191  	if errors.Cause(err) != ErrInvalidAuthorization {
   192  		t.Fatal(err)
   193  	}
   194  }
   195  
   196  func TestHostFailureFallbackResolver(t *testing.T) {
   197  	sf := func(h http.Handler) (string, ResolverOptions, func()) {
   198  		s := httptest.NewServer(h)
   199  		base := s.URL[7:] // strip "http://"
   200  
   201  		options := ResolverOptions{}
   202  		createHost := func(host string) RegistryHost {
   203  			return RegistryHost{
   204  				Client: &http.Client{
   205  					// Set the timeout so we timeout waiting for the non-responsive HTTP server
   206  					Timeout: 500 * time.Millisecond,
   207  				},
   208  				Host:         host,
   209  				Scheme:       "http",
   210  				Path:         "/v2",
   211  				Capabilities: HostCapabilityPull | HostCapabilityResolve | HostCapabilityPush,
   212  			}
   213  		}
   214  
   215  		// Create an unstarted HTTP server. We use this to generate a random port.
   216  		notRunning := httptest.NewUnstartedServer(nil)
   217  		notRunningBase := notRunning.Listener.Addr().String()
   218  
   219  		// Override hosts with two hosts
   220  		options.Hosts = func(host string) ([]RegistryHost, error) {
   221  			return []RegistryHost{
   222  				createHost(notRunningBase), // This host IS running, but with a non-responsive HTTP server
   223  				createHost(base),           // This host IS running
   224  			}, nil
   225  		}
   226  
   227  		return base, options, s.Close
   228  	}
   229  
   230  	runBasicTest(t, "testname", sf)
   231  }
   232  
   233  func TestHostTLSFailureFallbackResolver(t *testing.T) {
   234  	sf := func(h http.Handler) (string, ResolverOptions, func()) {
   235  		// Start up two servers
   236  		server := httptest.NewServer(h)
   237  		httpBase := server.URL[7:] // strip "http://"
   238  
   239  		tlsServer := httptest.NewUnstartedServer(h)
   240  		tlsServer.StartTLS()
   241  		httpsBase := tlsServer.URL[8:] // strip "https://"
   242  
   243  		capool := x509.NewCertPool()
   244  		cert, _ := x509.ParseCertificate(tlsServer.TLS.Certificates[0].Certificate[0])
   245  		capool.AddCert(cert)
   246  
   247  		client := &http.Client{
   248  			Transport: &http.Transport{
   249  				TLSClientConfig: &tls.Config{
   250  					RootCAs: capool,
   251  				},
   252  			},
   253  		}
   254  
   255  		options := ResolverOptions{}
   256  		createHost := func(host string) RegistryHost {
   257  			return RegistryHost{
   258  				Client:       client,
   259  				Host:         host,
   260  				Scheme:       "https",
   261  				Path:         "/v2",
   262  				Capabilities: HostCapabilityPull | HostCapabilityResolve | HostCapabilityPush,
   263  			}
   264  		}
   265  
   266  		// Override hosts with two hosts
   267  		options.Hosts = func(host string) ([]RegistryHost, error) {
   268  			return []RegistryHost{
   269  				createHost(httpBase),  // This host is serving plain HTTP
   270  				createHost(httpsBase), // This host is serving TLS
   271  			}, nil
   272  		}
   273  
   274  		return httpBase, options, func() {
   275  			server.Close()
   276  			tlsServer.Close()
   277  		}
   278  	}
   279  
   280  	runBasicTest(t, "testname", sf)
   281  }
   282  
   283  func withTokenServer(th http.Handler, creds func(string) (string, string, error)) func(h http.Handler) (string, ResolverOptions, func()) {
   284  	return func(h http.Handler) (string, ResolverOptions, func()) {
   285  		s := httptest.NewUnstartedServer(th)
   286  		s.StartTLS()
   287  
   288  		cert, _ := x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
   289  		tokenBase := s.URL + "/token"
   290  
   291  		// Wrap with token auth
   292  		wrapped := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   293  			auth := strings.ToLower(r.Header.Get("Authorization"))
   294  			if auth != "bearer perfectlyvalidopaquetoken" {
   295  				authHeader := fmt.Sprintf("Bearer realm=%q,service=registry,scope=\"repository:testname:pull,pull\"", tokenBase)
   296  				if strings.HasPrefix(auth, "bearer ") {
   297  					authHeader = authHeader + ",error=" + auth[7:]
   298  				}
   299  				rw.Header().Set("WWW-Authenticate", authHeader)
   300  				rw.WriteHeader(http.StatusUnauthorized)
   301  				return
   302  			}
   303  			h.ServeHTTP(rw, r)
   304  		})
   305  
   306  		base, options, close := tlsServer(wrapped)
   307  		options.Hosts = ConfigureDefaultRegistries(
   308  			WithClient(options.Client),
   309  			WithAuthorizer(NewDockerAuthorizer(
   310  				WithAuthClient(options.Client),
   311  				WithAuthCreds(creds),
   312  			)),
   313  		)
   314  		options.Client.Transport.(*http.Transport).TLSClientConfig.RootCAs.AddCert(cert)
   315  		return base, options, func() {
   316  			s.Close()
   317  			close()
   318  		}
   319  	}
   320  }
   321  
   322  func tlsServer(h http.Handler) (string, ResolverOptions, func()) {
   323  	s := httptest.NewUnstartedServer(h)
   324  	s.StartTLS()
   325  
   326  	capool := x509.NewCertPool()
   327  	cert, _ := x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
   328  	capool.AddCert(cert)
   329  
   330  	client := &http.Client{
   331  		Transport: &http.Transport{
   332  			TLSClientConfig: &tls.Config{
   333  				RootCAs: capool,
   334  			},
   335  		},
   336  	}
   337  	options := ResolverOptions{
   338  		Hosts: ConfigureDefaultRegistries(WithClient(client)),
   339  		// Set deprecated field for tests to use for configuration
   340  		Client: client,
   341  	}
   342  	base := s.URL[8:] // strip "https://"
   343  	return base, options, s.Close
   344  }
   345  
   346  type logHandler struct {
   347  	t       *testing.T
   348  	handler http.Handler
   349  }
   350  
   351  func (h logHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
   352  	h.handler.ServeHTTP(rw, r)
   353  }
   354  
   355  func runBasicTest(t *testing.T, name string, sf func(h http.Handler) (string, ResolverOptions, func())) {
   356  	var (
   357  		ctx = context.Background()
   358  		tag = "latest"
   359  		r   = http.NewServeMux()
   360  	)
   361  
   362  	m := newManifest(
   363  		newContent(ocispec.MediaTypeImageConfig, []byte("1")),
   364  		newContent(ocispec.MediaTypeImageLayerGzip, []byte("2")),
   365  	)
   366  	mc := newContent(ocispec.MediaTypeImageManifest, m.OCIManifest())
   367  	m.RegisterHandler(r, name)
   368  	r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, tag), mc)
   369  	r.Handle(fmt.Sprintf("/v2/%s/manifests/%s", name, mc.Digest()), mc)
   370  
   371  	base, ro, close := sf(logHandler{t, r})
   372  	defer close()
   373  
   374  	resolver := NewResolver(ro)
   375  	image := fmt.Sprintf("%s/%s:%s", base, name, tag)
   376  
   377  	_, d, err := resolver.Resolve(ctx, image)
   378  	if err != nil {
   379  		t.Fatal(err)
   380  	}
   381  	f, err := resolver.Fetcher(ctx, image)
   382  	if err != nil {
   383  		t.Fatal(err)
   384  	}
   385  
   386  	refs, err := testocimanifest(ctx, f, d)
   387  	if err != nil {
   388  		t.Fatal(err)
   389  	}
   390  
   391  	if len(refs) != 2 {
   392  		t.Fatalf("Unexpected number of references: %d, expected 2", len(refs))
   393  	}
   394  
   395  	for _, ref := range refs {
   396  		if err := testFetch(ctx, f, ref); err != nil {
   397  			t.Fatal(err)
   398  		}
   399  	}
   400  }
   401  
   402  func testFetch(ctx context.Context, f remotes.Fetcher, desc ocispec.Descriptor) error {
   403  	r, err := f.Fetch(ctx, desc)
   404  	if err != nil {
   405  		return err
   406  	}
   407  	dgstr := desc.Digest.Algorithm().Digester()
   408  	io.Copy(dgstr.Hash(), r)
   409  	if dgstr.Digest() != desc.Digest {
   410  		return errors.Errorf("content mismatch: %s != %s", dgstr.Digest(), desc.Digest)
   411  	}
   412  
   413  	return nil
   414  }
   415  
   416  func testocimanifest(ctx context.Context, f remotes.Fetcher, desc ocispec.Descriptor) ([]ocispec.Descriptor, error) {
   417  	r, err := f.Fetch(ctx, desc)
   418  	if err != nil {
   419  		return nil, errors.Wrapf(err, "failed to fetch %s", desc.Digest)
   420  	}
   421  	p, err := ioutil.ReadAll(r)
   422  	if err != nil {
   423  		return nil, err
   424  	}
   425  	if dgst := desc.Digest.Algorithm().FromBytes(p); dgst != desc.Digest {
   426  		return nil, errors.Errorf("digest mismatch: %s != %s", dgst, desc.Digest)
   427  	}
   428  
   429  	var manifest ocispec.Manifest
   430  	if err := json.Unmarshal(p, &manifest); err != nil {
   431  		return nil, err
   432  	}
   433  
   434  	var descs []ocispec.Descriptor
   435  
   436  	descs = append(descs, manifest.Config)
   437  	descs = append(descs, manifest.Layers...)
   438  
   439  	return descs, nil
   440  }
   441  
   442  type testContent struct {
   443  	mediaType string
   444  	content   []byte
   445  }
   446  
   447  func newContent(mediaType string, b []byte) testContent {
   448  	return testContent{
   449  		mediaType: mediaType,
   450  		content:   b,
   451  	}
   452  }
   453  
   454  func (tc testContent) Descriptor() ocispec.Descriptor {
   455  	return ocispec.Descriptor{
   456  		MediaType: tc.mediaType,
   457  		Digest:    digest.FromBytes(tc.content),
   458  		Size:      int64(len(tc.content)),
   459  	}
   460  }
   461  
   462  func (tc testContent) Digest() digest.Digest {
   463  	return digest.FromBytes(tc.content)
   464  }
   465  
   466  func (tc testContent) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   467  	w.Header().Add("Content-Type", tc.mediaType)
   468  	w.Header().Add("Content-Length", strconv.Itoa(len(tc.content)))
   469  	w.Header().Add("Docker-Content-Digest", tc.Digest().String())
   470  	w.WriteHeader(http.StatusOK)
   471  	w.Write(tc.content)
   472  }
   473  
   474  type testManifest struct {
   475  	config     testContent
   476  	references []testContent
   477  }
   478  
   479  func newManifest(config testContent, refs ...testContent) testManifest {
   480  	return testManifest{
   481  		config:     config,
   482  		references: refs,
   483  	}
   484  }
   485  
   486  func (m testManifest) OCIManifest() []byte {
   487  	manifest := ocispec.Manifest{
   488  		Versioned: specs.Versioned{
   489  			SchemaVersion: 1,
   490  		},
   491  		Config: m.config.Descriptor(),
   492  		Layers: make([]ocispec.Descriptor, len(m.references)),
   493  	}
   494  	for i, c := range append(m.references) {
   495  		manifest.Layers[i] = c.Descriptor()
   496  	}
   497  	b, _ := json.Marshal(manifest)
   498  	return b
   499  }
   500  
   501  func (m testManifest) RegisterHandler(r *http.ServeMux, name string) {
   502  	for _, c := range append(m.references, m.config) {
   503  		r.Handle(fmt.Sprintf("/v2/%s/blobs/%s", name, c.Digest()), c)
   504  	}
   505  }