github.com/blixtra/rkt@v0.8.1-0.20160204105720-ab0d1add1a43/tests/testutils/aci-server/server.go (about)

     1  // Copyright 2015 The rkt Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aci
    16  
    17  import (
    18  	"crypto/sha512"
    19  	"crypto/tls"
    20  	"encoding/base64"
    21  	"fmt"
    22  	"io/ioutil"
    23  	"net"
    24  	"net/http"
    25  	"net/http/httptest"
    26  	"path/filepath"
    27  	"strings"
    28  )
    29  
// PortType selects how the test server picks the port it listens on.
type PortType int

const (
	// PortFixed makes the server listen on the well-known port of the
	// chosen protocol: 80 for HTTP, 443 for HTTPS.
	PortFixed PortType = iota
	// PortRandom makes the server listen on port 0, which leaves the
	// choice of a free port to the operating system.
	PortRandom
)
    36  
// ProtocolType selects whether the test server is started with or
// without TLS.
type ProtocolType int

const (
	// ProtocolHttps starts the server with TLS.
	ProtocolHttps ProtocolType = iota
	// ProtocolHttp starts the server as plain HTTP.
	ProtocolHttp
)
    43  
// AuthType selects which credentials, if any, the test server demands
// from its clients.
type AuthType int

const (
	// AuthNone serves every request without looking at credentials.
	AuthNone AuthType = iota
	// AuthBasic requires a "Basic" Authorization header carrying the
	// user "bar" with the password "baz".
	AuthBasic
	// AuthOauth requires a "Bearer" Authorization header carrying the
	// token "sometoken".
	AuthOauth
)
    51  
// ServerType selects the serving behaviour of the test server.
type ServerType int

const (
	// ServerOrdinary serves every known file as soon as it is requested.
	ServerOrdinary ServerType = iota
	// ServerQuay answers a request for a "<image>.asc" file with
	// 202 Accepted until the corresponding image has been served.
	// NOTE(review): presumably this mimics the Quay registry - confirm.
	ServerQuay
)
    58  
    59  type httpError struct {
    60  	code    int
    61  	message string
    62  }
    63  
    64  func (e *httpError) Error() string {
    65  	return fmt.Sprintf("%d: %s", e.code, e.message)
    66  }
    67  
    68  type servedFile struct {
    69  	path string
    70  	etag string
    71  }
    72  
    73  func newServedFile(path string) (*servedFile, error) {
    74  	contents, err := ioutil.ReadFile(path)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	checksum := sha512.Sum512(contents)
    79  	sf := &servedFile{
    80  		path: path,
    81  		etag: fmt.Sprintf("%x", checksum),
    82  	}
    83  	return sf, nil
    84  }
    85  
// serverHandler is the http.Handler behind the test server. It holds
// the behaviour switches, the set of files it may serve and the
// channel on which it reports what it is doing.
type serverHandler struct {
	// server is consulted by canServe to decide whether a key
	// download must be deferred.
	server       ServerType
	// auth selects the credential check performed by handleAuth.
	auth         AuthType
	// protocol is stored by NewServer; no method in this file reads it.
	protocol     ProtocolType
	// msg receives progress messages; sends never block (see sendMsg).
	msg          chan<- string
	// fileSet maps the base name of a request path to the file served
	// for it.
	fileSet      map[string]*servedFile
	// servedImages is a set of request names: an entry is added when a
	// non-".asc" file is served and removed when the matching ".asc"
	// file is served.
	servedImages map[string]struct{}
	// serverURL is the base URL embedded in the discovery meta tag.
	serverURL    string
}
    95  
    96  func (h *serverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    97  	if r.Method != "GET" {
    98  		w.WriteHeader(http.StatusMethodNotAllowed)
    99  		return
   100  	}
   101  	if authOk := h.handleAuth(w, r); !authOk {
   102  		return
   103  	}
   104  	h.sendMsg(fmt.Sprintf("Trying to serve %q", r.URL.String()))
   105  	h.handleRequest(w, r)
   106  }
   107  
   108  func (h *serverHandler) handleAuth(w http.ResponseWriter, r *http.Request) bool {
   109  	switch h.auth {
   110  	case AuthNone:
   111  		// no auth to do.
   112  		return true
   113  	case AuthBasic:
   114  		return h.handleBasicAuth(w, r)
   115  	case AuthOauth:
   116  		return h.handleOauthAuth(w, r)
   117  	default:
   118  		panic("Woe is me!")
   119  	}
   120  }
   121  
   122  func (h *serverHandler) handleBasicAuth(w http.ResponseWriter, r *http.Request) bool {
   123  	payload, httpErr := getAuthPayload(r, "Basic")
   124  	if httpErr != nil {
   125  		w.WriteHeader(httpErr.code)
   126  		h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message))
   127  		return false
   128  	}
   129  	creds, err := base64.StdEncoding.DecodeString(string(payload))
   130  	if err != nil {
   131  		w.WriteHeader(http.StatusBadRequest)
   132  		h.sendMsg(`Badly formed "Authorization" header`)
   133  		return false
   134  	}
   135  	parts := strings.Split(string(creds), ":")
   136  	if len(parts) != 2 {
   137  		w.WriteHeader(http.StatusBadRequest)
   138  		h.sendMsg(`Badly formed "Authorization" header (2)`)
   139  		return false
   140  	}
   141  	user := parts[0]
   142  	password := parts[1]
   143  	if user != "bar" || password != "baz" {
   144  		w.WriteHeader(http.StatusUnauthorized)
   145  		h.sendMsg(fmt.Sprintf("Bad credentials: %q", string(creds)))
   146  		return false
   147  	}
   148  	return true
   149  }
   150  
   151  func (h *serverHandler) handleOauthAuth(w http.ResponseWriter, r *http.Request) bool {
   152  	payload, httpErr := getAuthPayload(r, "Bearer")
   153  	if httpErr != nil {
   154  		w.WriteHeader(httpErr.code)
   155  		h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message))
   156  		return false
   157  	}
   158  	if payload != "sometoken" {
   159  		w.WriteHeader(http.StatusUnauthorized)
   160  		h.sendMsg(fmt.Sprintf(`Bad token: %q`, payload))
   161  		return false
   162  	}
   163  	return true
   164  }
   165  
   166  func getAuthPayload(r *http.Request, authType string) (string, *httpError) {
   167  	auth := r.Header.Get("Authorization")
   168  	if auth == "" {
   169  		err := &httpError{
   170  			code:    http.StatusUnauthorized,
   171  			message: "No auth",
   172  		}
   173  		return "", err
   174  	}
   175  	parts := strings.Split(auth, " ")
   176  	if len(parts) != 2 {
   177  		err := &httpError{
   178  			code:    http.StatusBadRequest,
   179  			message: "Malformed auth",
   180  		}
   181  		return "", err
   182  	}
   183  	if parts[0] != authType {
   184  		err := &httpError{
   185  			code:    http.StatusUnauthorized,
   186  			message: "Wrong auth",
   187  		}
   188  		return "", err
   189  	}
   190  	return parts[1], nil
   191  }
   192  
   193  func (h *serverHandler) handleRequest(w http.ResponseWriter, r *http.Request) {
   194  	path := filepath.Base(r.URL.Path)
   195  	switch path {
   196  	case "/":
   197  		h.sendAcDiscovery(w)
   198  	default:
   199  		h.handleFile(w, path, r.Header)
   200  	}
   201  }
   202  
   203  func (h *serverHandler) sendAcDiscovery(w http.ResponseWriter) {
   204  	// TODO(krnowak): When appc spec gets the discovery over
   205  	// custom port feature, possibly take it into account here
   206  	indexHTML := fmt.Sprintf(`<meta name="ac-discovery" content="localhost %s/{name}.{ext}">`, h.serverURL)
   207  	w.Write([]byte(indexHTML))
   208  	h.sendMsg("  done.")
   209  }
   210  
   211  func (h *serverHandler) handleFile(w http.ResponseWriter, reqPath string, headers http.Header) {
   212  	sf, ok := h.fileSet[reqPath]
   213  	if !ok {
   214  		w.WriteHeader(http.StatusNotFound)
   215  		h.sendMsg("  not found.")
   216  		return
   217  	}
   218  	if !h.canServe(reqPath, w) {
   219  		return
   220  	}
   221  	if headers.Get("If-None-Match") == sf.etag {
   222  		addCacheHeaders(w, sf)
   223  		w.WriteHeader(http.StatusNotModified)
   224  		h.sendMsg("  not modified, done.")
   225  		return
   226  	}
   227  	contents, err := ioutil.ReadFile(sf.path)
   228  	if err != nil {
   229  		w.WriteHeader(http.StatusInternalServerError)
   230  		h.sendMsg("  not found, but specified in fileset; bug?")
   231  		return
   232  	}
   233  	addCacheHeaders(w, sf)
   234  	w.Write(contents)
   235  	reqImagePath, isAsc := isPathAnImageKey(reqPath)
   236  	if isAsc {
   237  		delete(h.servedImages, reqImagePath)
   238  	} else {
   239  		h.servedImages[reqPath] = struct{}{}
   240  	}
   241  	h.sendMsg("  done.")
   242  }
   243  
   244  func (h *serverHandler) canServe(reqPath string, w http.ResponseWriter) bool {
   245  	if h.server != ServerQuay {
   246  		return true
   247  	}
   248  	reqImagePath, isAsc := isPathAnImageKey(reqPath)
   249  	if !isAsc {
   250  		return true
   251  	}
   252  	if _, imageAlreadyServed := h.servedImages[reqImagePath]; imageAlreadyServed {
   253  		return true
   254  	}
   255  	w.WriteHeader(http.StatusAccepted)
   256  	h.sendMsg("  asking to defer the download")
   257  	return false
   258  }
   259  
   260  func addCacheHeaders(w http.ResponseWriter, sf *servedFile) {
   261  	w.Header().Set("ETag", sf.etag)
   262  	w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", 60*60*24)) // one day
   263  }
   264  
   265  func (h *serverHandler) sendMsg(msg string) {
   266  	select {
   267  	case h.msg <- msg:
   268  	default:
   269  	}
   270  }
   271  
   272  func isPathAnImageKey(path string) (string, bool) {
   273  	if strings.HasSuffix(path, ".asc") {
   274  		imagePath := strings.TrimSuffix(path, ".asc")
   275  		return imagePath, true
   276  	}
   277  	return "", false
   278  }
   279  
// Server is a running ACI test server as created by NewServer.
type Server struct {
	// Msg is the receiving end of the channel on which the request
	// handler reports its progress.
	Msg     <-chan string
	// Conf is an rkt auth configuration snippet with credentials for
	// this server. NewServer sets it for AuthBasic and AuthOauth and
	// leaves it empty for AuthNone.
	Conf    string
	// URL is the URL of the underlying httptest server.
	URL     string
	// handler is the request handler installed in the HTTP server.
	handler *serverHandler
	// http is the underlying httptest server.
	http    *httptest.Server
}
   287  
// ServerSetup collects the options NewServer uses to build a Server.
type ServerSetup struct {
	// Port selects a fixed or an OS-chosen listening port.
	Port        PortType
	// Protocol selects HTTP or HTTPS.
	Protocol    ProtocolType
	// Server selects the serving behaviour.
	Server      ServerType
	// Auth selects the credentials demanded from clients.
	Auth        AuthType
	// MsgCapacity is the buffer size of the Server's message channel.
	MsgCapacity int
}
   295  
   296  func GetDefaultServerSetup() *ServerSetup {
   297  	return &ServerSetup{
   298  		Port:        PortFixed,
   299  		Protocol:    ProtocolHttps,
   300  		Server:      ServerOrdinary,
   301  		Auth:        AuthNone,
   302  		MsgCapacity: 20,
   303  	}
   304  }
   305  
   306  func (s *Server) Close() {
   307  	s.http.Close()
   308  	close(s.handler.msg)
   309  }
   310  
   311  func (s *Server) UpdateFileSet(fileSet map[string]string) error {
   312  	s.handler.fileSet = make(map[string]*servedFile, len(fileSet))
   313  	for base, path := range fileSet {
   314  		sf, err := newServedFile(path)
   315  		if err != nil {
   316  			return err
   317  		}
   318  		s.handler.fileSet[base] = sf
   319  	}
   320  	return nil
   321  }
   322  
   323  func NewServer(setup *ServerSetup) *Server {
   324  	msg := make(chan string, setup.MsgCapacity)
   325  	server := &Server{
   326  		Msg: msg,
   327  		handler: &serverHandler{
   328  			auth:         setup.Auth,
   329  			msg:          msg,
   330  			server:       setup.Server,
   331  			protocol:     setup.Protocol,
   332  			fileSet:      make(map[string]*servedFile),
   333  			servedImages: make(map[string]struct{}),
   334  		},
   335  	}
   336  	server.http = httptest.NewUnstartedServer(server.handler)
   337  	// We use our own listener, so we can override a port number
   338  	// without using a "httptest.serve" flag. Using the
   339  	// "httptest.serve" flag together with an HTTP protocol
   340  	// results in blocking for debugging purposes as described in
   341  	// https://golang.org/src/net/http/httptest/server.go#L74.
   342  	// Here, we lose the ability, but we don't need it.
   343  	server.http.Listener = newLocalListener(setup.Port, setup.Protocol)
   344  	switch setup.Protocol {
   345  	case ProtocolHttp:
   346  		server.http.Start()
   347  	case ProtocolHttps:
   348  		server.http.TLS = &tls.Config{InsecureSkipVerify: true}
   349  		server.http.StartTLS()
   350  	default:
   351  		panic("Woe is me!")
   352  	}
   353  	server.URL = server.http.URL
   354  	server.handler.serverURL = server.http.URL
   355  	host := server.http.Listener.Addr().String()
   356  	switch setup.Auth {
   357  	case AuthNone:
   358  		// nothing to do
   359  	case AuthBasic:
   360  		creds := `"user": "bar",
   361  		"password": "baz"`
   362  		server.Conf = sprintCreds(host, "basic", creds)
   363  	case AuthOauth:
   364  		creds := `"token": "sometoken"`
   365  		server.Conf = sprintCreds(host, "oauth", creds)
   366  	default:
   367  		panic("Woe is me!")
   368  	}
   369  	return server
   370  }
   371  
   372  // based on code from net/http/httptest
   373  func newLocalListener(port PortType, protocol ProtocolType) net.Listener {
   374  	portNumber := 0
   375  	if port == PortFixed {
   376  		switch protocol {
   377  		case ProtocolHttp:
   378  			portNumber = 80
   379  		case ProtocolHttps:
   380  			portNumber = 443
   381  		}
   382  	}
   383  	l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", portNumber))
   384  	if err != nil {
   385  		if l, err = net.Listen("tcp6", fmt.Sprintf("[::1]:%d", portNumber)); err != nil {
   386  			panic(fmt.Sprintf("aci test server: failed to listen on a port: %v", err))
   387  		}
   388  	}
   389  	return l
   390  }
   391  
   392  func sprintCreds(host, auth, creds string) string {
   393  	return fmt.Sprintf(`
   394  {
   395  	"rktKind": "auth",
   396  	"rktVersion": "v1",
   397  	"domains": ["%s"],
   398  	"type": "%s",
   399  	"credentials":
   400  	{
   401  		%s
   402  	}
   403  }
   404  
   405  `, host, auth, creds)
   406  }