github.com/rkt/rkt@v1.30.1-0.20200224141603-171c416fac02/tests/testutils/aci-server/server.go (about)

     1  // Copyright 2015 The rkt Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aci
    16  
    17  import (
    18  	"crypto/sha512"
    19  	"crypto/tls"
    20  	"encoding/base64"
    21  	"fmt"
    22  	"io/ioutil"
    23  	"net"
    24  	"net/http"
    25  	"net/http/httptest"
    26  	"path/filepath"
    27  	"strings"
    28  	"time"
    29  )
    30  
    31  type PortType int
    32  
    33  const (
    34  	PortFixed PortType = iota
    35  	PortRandom
    36  )
    37  
    38  type ProtocolType int
    39  
    40  const (
    41  	ProtocolHttps ProtocolType = iota
    42  	ProtocolHttp
    43  )
    44  
    45  type AuthType int
    46  
    47  const (
    48  	AuthNone AuthType = iota
    49  	AuthBasic
    50  	AuthOauth
    51  )
    52  
    53  type ServerType int
    54  
    55  const (
    56  	ServerOrdinary ServerType = iota
    57  	ServerQuay
    58  )
    59  
    60  type httpError struct {
    61  	code    int
    62  	message string
    63  }
    64  
    65  func (e *httpError) Error() string {
    66  	return fmt.Sprintf("%d: %s", e.code, e.message)
    67  }
    68  
    69  type servedFile struct {
    70  	path string
    71  	etag string
    72  }
    73  
    74  func newServedFile(path string) (*servedFile, error) {
    75  	contents, err := ioutil.ReadFile(path)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	checksum := sha512.Sum512(contents)
    80  	sf := &servedFile{
    81  		path: path,
    82  		etag: fmt.Sprintf("%x", checksum),
    83  	}
    84  	return sf, nil
    85  }
    86  
// serverHandler is the http.Handler behind Server. It enforces the
// configured authentication, answers discovery requests, serves the files
// in fileSet and reports progress on msg.
//
// NOTE(review): fileSet and servedImages are read and written without any
// locking. If requests can be handled concurrently (or UpdateFileSet runs
// while requests are in flight) this is a data race — confirm the tests
// fetch sequentially.
type serverHandler struct {
	// server selects the emulated server flavor; with ServerQuay, canServe
	// defers ".asc" downloads.
	server ServerType

	// auth is the authentication every request must pass (see handleAuth).
	auth AuthType

	// protocol records setup.Protocol; no handler method shown here reads it.
	protocol ProtocolType

	// msg receives progress messages; sendMsg never blocks on it.
	msg chan<- string

	// fileSet maps a request's base name to the file served for it.
	fileSet map[string]*servedFile

	// servedImages holds image paths downloaded since their ".asc" key was
	// last downloaded; canServe consults it for ServerQuay.
	servedImages map[string]struct{}

	// serverURL is the base URL advertised by sendAcDiscovery.
	serverURL string
}
    96  
    97  func (h *serverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    98  	if r.Method != "GET" {
    99  		w.WriteHeader(http.StatusMethodNotAllowed)
   100  		return
   101  	}
   102  	if authOk := h.handleAuth(w, r); !authOk {
   103  		return
   104  	}
   105  	h.sendMsg(fmt.Sprintf("Trying to serve %q", r.URL.String()))
   106  	h.handleRequest(w, r)
   107  }
   108  
   109  func (h *serverHandler) handleAuth(w http.ResponseWriter, r *http.Request) bool {
   110  	switch h.auth {
   111  	case AuthNone:
   112  		// no auth to do.
   113  		return true
   114  	case AuthBasic:
   115  		return h.handleBasicAuth(w, r)
   116  	case AuthOauth:
   117  		return h.handleOauthAuth(w, r)
   118  	default:
   119  		panic("Woe is me!")
   120  	}
   121  }
   122  
   123  func (h *serverHandler) handleBasicAuth(w http.ResponseWriter, r *http.Request) bool {
   124  	payload, httpErr := getAuthPayload(r, "Basic")
   125  	if httpErr != nil {
   126  		w.WriteHeader(httpErr.code)
   127  		h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message))
   128  		return false
   129  	}
   130  	creds, err := base64.StdEncoding.DecodeString(string(payload))
   131  	if err != nil {
   132  		w.WriteHeader(http.StatusBadRequest)
   133  		h.sendMsg(`Badly formed "Authorization" header`)
   134  		return false
   135  	}
   136  	parts := strings.Split(string(creds), ":")
   137  	if len(parts) != 2 {
   138  		w.WriteHeader(http.StatusBadRequest)
   139  		h.sendMsg(`Badly formed "Authorization" header (2)`)
   140  		return false
   141  	}
   142  	user := parts[0]
   143  	password := parts[1]
   144  	if user != "bar" || password != "baz" {
   145  		w.WriteHeader(http.StatusUnauthorized)
   146  		h.sendMsg(fmt.Sprintf("Bad credentials: %q", string(creds)))
   147  		return false
   148  	}
   149  	return true
   150  }
   151  
   152  func (h *serverHandler) handleOauthAuth(w http.ResponseWriter, r *http.Request) bool {
   153  	payload, httpErr := getAuthPayload(r, "Bearer")
   154  	if httpErr != nil {
   155  		w.WriteHeader(httpErr.code)
   156  		h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message))
   157  		return false
   158  	}
   159  	if payload != "sometoken" {
   160  		w.WriteHeader(http.StatusUnauthorized)
   161  		h.sendMsg(fmt.Sprintf(`Bad token: %q`, payload))
   162  		return false
   163  	}
   164  	return true
   165  }
   166  
   167  func getAuthPayload(r *http.Request, authType string) (string, *httpError) {
   168  	auth := r.Header.Get("Authorization")
   169  	if auth == "" {
   170  		err := &httpError{
   171  			code:    http.StatusUnauthorized,
   172  			message: "No auth",
   173  		}
   174  		return "", err
   175  	}
   176  	parts := strings.Split(auth, " ")
   177  	if len(parts) != 2 {
   178  		err := &httpError{
   179  			code:    http.StatusBadRequest,
   180  			message: "Malformed auth",
   181  		}
   182  		return "", err
   183  	}
   184  	if parts[0] != authType {
   185  		err := &httpError{
   186  			code:    http.StatusUnauthorized,
   187  			message: "Wrong auth",
   188  		}
   189  		return "", err
   190  	}
   191  	return parts[1], nil
   192  }
   193  
   194  func (h *serverHandler) handleRequest(w http.ResponseWriter, r *http.Request) {
   195  	path := filepath.Base(r.URL.Path)
   196  	switch path {
   197  	case "/":
   198  		h.sendAcDiscovery(w)
   199  	default:
   200  		h.handleFile(w, path, r.Header)
   201  	}
   202  }
   203  
   204  func (h *serverHandler) sendAcDiscovery(w http.ResponseWriter) {
   205  	// TODO(krnowak): When appc spec gets the discovery over
   206  	// custom port feature, possibly take it into account here
   207  	indexHTML := fmt.Sprintf(`<meta name="ac-discovery" content="localhost %s/{name}.{ext}">`, h.serverURL)
   208  	w.Write([]byte(indexHTML))
   209  	h.sendMsg("  done.")
   210  }
   211  
   212  func (h *serverHandler) handleFile(w http.ResponseWriter, reqPath string, headers http.Header) {
   213  	sf, ok := h.fileSet[reqPath]
   214  	if !ok {
   215  		w.WriteHeader(http.StatusNotFound)
   216  		h.sendMsg("  not found.")
   217  		return
   218  	}
   219  	if !h.canServe(reqPath, w) {
   220  		return
   221  	}
   222  	if headers.Get("If-None-Match") == sf.etag {
   223  		addCacheHeaders(w, sf)
   224  		w.WriteHeader(http.StatusNotModified)
   225  		h.sendMsg("  not modified, done.")
   226  		return
   227  	}
   228  	contents, err := ioutil.ReadFile(sf.path)
   229  	if err != nil {
   230  		w.WriteHeader(http.StatusInternalServerError)
   231  		h.sendMsg("  not found, but specified in fileset; bug?")
   232  		return
   233  	}
   234  	addCacheHeaders(w, sf)
   235  	w.Write(contents)
   236  	reqImagePath, isAsc := isPathAnImageKey(reqPath)
   237  	if isAsc {
   238  		delete(h.servedImages, reqImagePath)
   239  	} else {
   240  		h.servedImages[reqPath] = struct{}{}
   241  	}
   242  	h.sendMsg("  done.")
   243  }
   244  
   245  func (h *serverHandler) canServe(reqPath string, w http.ResponseWriter) bool {
   246  	if h.server != ServerQuay {
   247  		return true
   248  	}
   249  	reqImagePath, isAsc := isPathAnImageKey(reqPath)
   250  	if !isAsc {
   251  		return true
   252  	}
   253  	if _, imageAlreadyServed := h.servedImages[reqImagePath]; imageAlreadyServed {
   254  		return true
   255  	}
   256  	w.WriteHeader(http.StatusAccepted)
   257  	h.sendMsg("  asking to defer the download")
   258  	return false
   259  }
   260  
   261  func addCacheHeaders(w http.ResponseWriter, sf *servedFile) {
   262  	w.Header().Set("ETag", sf.etag)
   263  	w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", 60*60*24)) // one day
   264  }
   265  
   266  func (h *serverHandler) sendMsg(msg string) {
   267  	select {
   268  	case h.msg <- msg:
   269  	default:
   270  	}
   271  }
   272  
   273  func isPathAnImageKey(path string) (string, bool) {
   274  	if strings.HasSuffix(path, ".asc") {
   275  		imagePath := strings.TrimSuffix(path, ".asc")
   276  		return imagePath, true
   277  	}
   278  	return "", false
   279  }
   280  
// Server is a running ACI test server created by NewServer.
type Server struct {
	// Msg receives progress messages about handled requests. Sends are
	// non-blocking, so messages are dropped once the buffer (sized by
	// ServerSetup.MsgCapacity) is full. Close closes it.
	Msg <-chan string

	// Conf is an rkt auth configuration (JSON) for this server when it
	// requires AuthBasic or AuthOauth; it is empty for AuthNone.
	Conf string

	// URL is the base URL of the running server.
	URL string

	handler *serverHandler
	http    *httptest.Server
}
   288  
// ServerSetup describes the server NewServer should start.
// GetDefaultServerSetup returns a ready-made default.
type ServerSetup struct {
	// Port selects a fixed well-known port (80/443) or an OS-chosen one.
	Port PortType

	// Protocol selects plain HTTP or HTTPS.
	Protocol ProtocolType

	// Server selects the emulated server flavor.
	Server ServerType

	// Auth selects the authentication the server requires.
	Auth AuthType

	// MsgCapacity is the buffer size of the Server.Msg channel.
	MsgCapacity int
}
   296  
   297  func GetDefaultServerSetup() *ServerSetup {
   298  	return &ServerSetup{
   299  		Port:        PortFixed,
   300  		Protocol:    ProtocolHttps,
   301  		Server:      ServerOrdinary,
   302  		Auth:        AuthNone,
   303  		MsgCapacity: 20,
   304  	}
   305  }
   306  
   307  func (s *Server) Close() {
   308  	s.http.Close()
   309  	close(s.handler.msg)
   310  }
   311  
   312  func (s *Server) UpdateFileSet(fileSet map[string]string) error {
   313  	s.handler.fileSet = make(map[string]*servedFile, len(fileSet))
   314  	for base, path := range fileSet {
   315  		sf, err := newServedFile(path)
   316  		if err != nil {
   317  			return err
   318  		}
   319  		s.handler.fileSet[base] = sf
   320  	}
   321  	return nil
   322  }
   323  
   324  func NewServer(setup *ServerSetup) *Server {
   325  	msg := make(chan string, setup.MsgCapacity)
   326  	server := &Server{
   327  		Msg: msg,
   328  		handler: &serverHandler{
   329  			auth:         setup.Auth,
   330  			msg:          msg,
   331  			server:       setup.Server,
   332  			protocol:     setup.Protocol,
   333  			fileSet:      make(map[string]*servedFile),
   334  			servedImages: make(map[string]struct{}),
   335  		},
   336  	}
   337  	server.http = httptest.NewUnstartedServer(server.handler)
   338  	// We use our own listener, so we can override a port number
   339  	// without using a "httptest.serve" flag. Using the
   340  	// "httptest.serve" flag together with an HTTP protocol
   341  	// results in blocking for debugging purposes as described in
   342  	// https://golang.org/src/net/http/httptest/server.go#L74.
   343  	// Here, we lose the ability, but we don't need it.
   344  	server.http.Listener = newLocalListener(setup.Port, setup.Protocol)
   345  	switch setup.Protocol {
   346  	case ProtocolHttp:
   347  		server.http.Start()
   348  	case ProtocolHttps:
   349  		server.http.TLS = &tls.Config{InsecureSkipVerify: true}
   350  		server.http.StartTLS()
   351  	default:
   352  		panic("Woe is me!")
   353  	}
   354  	server.URL = server.http.URL
   355  	server.handler.serverURL = server.http.URL
   356  	host := server.http.Listener.Addr().String()
   357  	switch setup.Auth {
   358  	case AuthNone:
   359  		// nothing to do
   360  	case AuthBasic:
   361  		creds := `"user": "bar",
   362  		"password": "baz"`
   363  		server.Conf = sprintCreds(host, "basic", creds)
   364  	case AuthOauth:
   365  		creds := `"token": "sometoken"`
   366  		server.Conf = sprintCreds(host, "oauth", creds)
   367  	default:
   368  		panic("Woe is me!")
   369  	}
   370  	return server
   371  }
   372  
   373  func newLocalListener(port PortType, protocol ProtocolType) net.Listener {
   374  	portNumber := 0
   375  	if port == PortFixed {
   376  		switch protocol {
   377  		case ProtocolHttp:
   378  			portNumber = 80
   379  		case ProtocolHttps:
   380  			portNumber = 443
   381  		}
   382  	}
   383  	addrs, err := net.LookupHost("localhost")
   384  	if err != nil {
   385  		panic(`aci test server: failed to look up "localhost", really`)
   386  	}
   387  	var lookupErrs []string
   388  	for try := 0; try < 2; try++ {
   389  		for _, addr := range addrs {
   390  			addrport := fmt.Sprintf("%s:%d", addr, portNumber)
   391  			l, err := net.Listen("tcp", addrport)
   392  			if err == nil {
   393  				return l
   394  			}
   395  			lookupErrs = append(lookupErrs, fmt.Sprintf("(listen on %s, attempt #%d: %v)", addrport, try+1, err))
   396  		}
   397  		// TODO: When we have discovery on a custom port then
   398  		// we could drop listening on fixed ports, so we
   399  		// probably won't get any races between old server
   400  		// stopping to listen and new server starting to
   401  		// listen.
   402  		// https://github.com/appc/spec/pull/110
   403  		// Might be possible with ABD:
   404  		// https://github.com/appc/abd
   405  		time.Sleep(time.Second)
   406  	}
   407  	panic(fmt.Sprintf("aci test server: failed to listen on localhost:%d: %v", portNumber, lookupErrs))
   408  }
   409  
   410  func sprintCreds(host, auth, creds string) string {
   411  	return fmt.Sprintf(`
   412  {
   413  	"rktKind": "auth",
   414  	"rktVersion": "v1",
   415  	"domains": ["%s"],
   416  	"type": "%s",
   417  	"credentials":
   418  	{
   419  		%s
   420  	}
   421  }
   422  
   423  `, host, auth, creds)
   424  }