github.com/stackdocker/rkt@v0.10.1-0.20151109095037-1aa827478248/tests/test-auth-server/aci/server.go (about)

     1  // Copyright 2015 The rkt Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aci
    16  
    17  import (
    18  	"crypto/tls"
    19  	"encoding/base64"
    20  	"fmt"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"os/exec"
    24  	"path/filepath"
    25  	"strings"
    26  )
    27  
// Type selects which authentication scheme the test server enforces on GET
// requests.
type Type int

// Authentication modes understood by the server.
const (
	// None disables authentication; every GET request is let through.
	None Type = iota
	// Basic requires an "Authorization: Basic ..." header whose decoded
	// credentials are user "bar" with password "baz".
	Basic
	// Oauth requires an "Authorization: Bearer sometoken" header.
	Oauth
)
    35  
    36  type httpError struct {
    37  	code    int
    38  	message string
    39  }
    40  
    41  func (e *httpError) Error() string {
    42  	return fmt.Sprintf("%d: %s", e.code, e.message)
    43  }
    44  
// serverHandler is the http.Handler behind Server. It enforces the configured
// authentication mode on GET requests and reports what it is doing over msg.
type serverHandler struct {
	// auth is the authentication scheme checked on every GET request.
	auth  Type
	// stop receives one value for every POST request handled.
	stop  chan<- struct{}
	// msg receives human-readable progress messages; sends are non-blocking
	// (see sendMsg), so messages may be dropped.
	msg   chan<- string
	// tools provides prepareACI, whose result is served for "prog.aci".
	tools *aciToolkit
}
    51  
    52  func (h *serverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    53  	switch r.Method {
    54  	case "POST":
    55  		w.WriteHeader(http.StatusOK)
    56  		h.stop <- struct{}{}
    57  		return
    58  	case "GET":
    59  		// handled later
    60  	default:
    61  		w.WriteHeader(http.StatusMethodNotAllowed)
    62  		return
    63  	}
    64  	switch h.auth {
    65  	case None:
    66  		// no auth to do.
    67  	case Basic:
    68  		payload, httpErr := getAuthPayload(r, "Basic")
    69  		if httpErr != nil {
    70  			w.WriteHeader(httpErr.code)
    71  			h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message))
    72  			return
    73  		}
    74  		creds, err := base64.StdEncoding.DecodeString(string(payload))
    75  		if err != nil {
    76  			w.WriteHeader(http.StatusBadRequest)
    77  			h.sendMsg(fmt.Sprintf(`Badly formed "Authorization" header`))
    78  			return
    79  		}
    80  		parts := strings.Split(string(creds), ":")
    81  		if len(parts) != 2 {
    82  			w.WriteHeader(http.StatusBadRequest)
    83  			h.sendMsg(fmt.Sprintf(`Badly formed "Authorization" header (2)`))
    84  			return
    85  		}
    86  		user := parts[0]
    87  		password := parts[1]
    88  		if user != "bar" || password != "baz" {
    89  			w.WriteHeader(http.StatusUnauthorized)
    90  			h.sendMsg(fmt.Sprintf("Bad credentials: %q", string(creds)))
    91  			return
    92  		}
    93  	case Oauth:
    94  		payload, httpErr := getAuthPayload(r, "Bearer")
    95  		if httpErr != nil {
    96  			w.WriteHeader(httpErr.code)
    97  			h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message))
    98  			return
    99  		}
   100  		if payload != "sometoken" {
   101  			w.WriteHeader(http.StatusUnauthorized)
   102  			h.sendMsg(fmt.Sprintf(`Bad token: %q`, payload))
   103  			return
   104  		}
   105  	default:
   106  		panic("Woe is me!")
   107  	}
   108  	h.sendMsg(fmt.Sprintf("Trying to serve %q", r.URL.String()))
   109  	switch filepath.Base(r.URL.Path) {
   110  	case "prog.aci":
   111  		h.sendMsg(fmt.Sprintf("  serving"))
   112  		if data, err := h.tools.prepareACI(); err != nil {
   113  			w.WriteHeader(http.StatusInternalServerError)
   114  			h.sendMsg(fmt.Sprintf("    failed (%v)", err))
   115  		} else {
   116  			w.Write(data)
   117  			h.sendMsg(fmt.Sprintf("    done."))
   118  		}
   119  	default:
   120  		h.sendMsg(fmt.Sprintf("  not found."))
   121  		w.WriteHeader(http.StatusNotFound)
   122  	}
   123  }
   124  
// sendMsg offers msg to the handler's message channel without blocking: if
// the send cannot proceed immediately (buffer full or no ready receiver), the
// message is silently dropped so that request handling never stalls on
// reporting.
func (h *serverHandler) sendMsg(msg string) {
	select {
	case h.msg <- msg:
	default:
		// Nobody can take the message right now; drop it.
	}
}
   131  
   132  func getAuthPayload(r *http.Request, authType string) (string, *httpError) {
   133  	auth := r.Header.Get("Authorization")
   134  	if auth == "" {
   135  		err := &httpError{
   136  			code:    http.StatusUnauthorized,
   137  			message: "No auth",
   138  		}
   139  		return "", err
   140  	}
   141  	parts := strings.Split(auth, " ")
   142  	if len(parts) != 2 {
   143  		err := &httpError{
   144  			code:    http.StatusBadRequest,
   145  			message: "Malformed auth",
   146  		}
   147  		return "", err
   148  	}
   149  	if parts[0] != authType {
   150  		err := &httpError{
   151  			code:    http.StatusUnauthorized,
   152  			message: "Wrong auth",
   153  		}
   154  		return "", err
   155  	}
   156  	return parts[1], nil
   157  }
   158  
// Server is a TLS test server that serves an ACI image behind an optional
// authentication scheme. Create one with NewServer or NewServerWithPaths and
// release it with Close.
type Server struct {
	// Stop is the receiving end of the handler's stop channel; a value
	// arrives for every POST request the server handles.
	Stop    <-chan struct{}
	// Msg is the receiving end of the handler's message channel, carrying
	// progress and failure reports.
	Msg     <-chan string
	// Conf is the auth configuration text generated for this server's
	// address; it is left empty when the auth type is None.
	Conf    string
	// URL is the base URL of the underlying httptest server.
	URL     string
	// handler is the request handler installed on the HTTP server.
	handler *serverHandler
	// http is the underlying httptest server.
	http    *httptest.Server
}
   167  
   168  func (s *Server) Close() {
   169  	s.http.Close()
   170  	close(s.handler.msg)
   171  	close(s.handler.stop)
   172  }
   173  
   174  func NewServer(auth Type, msgCapacity int) (*Server, error) {
   175  	return NewServerWithPaths(auth, msgCapacity, "actool", "go")
   176  }
   177  
   178  func NewServerWithPaths(auth Type, msgCapacity int, acTool, goTool string) (*Server, error) {
   179  	if !filepath.IsAbs(acTool) {
   180  		absAcTool, err := getTool(acTool)
   181  		if err != nil {
   182  			return nil, err
   183  		}
   184  		acTool = absAcTool
   185  	}
   186  	if !filepath.IsAbs(goTool) {
   187  		absGoTool, err := getTool(goTool)
   188  		if err != nil {
   189  			return nil, err
   190  		}
   191  		goTool = absGoTool
   192  	}
   193  	stop := make(chan struct{})
   194  	msg := make(chan string, msgCapacity)
   195  	server := &Server{
   196  		Stop: stop,
   197  		Msg:  msg,
   198  		handler: &serverHandler{
   199  			auth: auth,
   200  			stop: stop,
   201  			msg:  msg,
   202  			tools: &aciToolkit{
   203  				acTool: acTool,
   204  				goTool: goTool,
   205  			},
   206  		},
   207  	}
   208  	server.http = httptest.NewUnstartedServer(server.handler)
   209  	server.http.TLS = &tls.Config{InsecureSkipVerify: true}
   210  	server.http.StartTLS()
   211  	server.URL = server.http.URL
   212  	host := server.http.Listener.Addr().String()
   213  	switch auth {
   214  	case None:
   215  		// nothing to do
   216  	case Basic:
   217  		creds := `"user": "bar",
   218  		"password": "baz"`
   219  		server.Conf = sprintCreds(host, "basic", creds)
   220  	case Oauth:
   221  		creds := `"token": "sometoken"`
   222  		server.Conf = sprintCreds(host, "oauth", creds)
   223  	default:
   224  		panic("Woe is me!")
   225  	}
   226  	return server, nil
   227  }
   228  
   229  func getTool(tool string) (string, error) {
   230  	toolPath, err := exec.LookPath(tool)
   231  	if err != nil {
   232  		return "", fmt.Errorf("failed to find %s in $PATH: %v", tool, err)
   233  	}
   234  	absToolPath, err := filepath.Abs(toolPath)
   235  	if err != nil {
   236  		return "", fmt.Errorf("failed to get absolute path of %s: %v", tool, err)
   237  	}
   238  	return absToolPath, nil
   239  }
   240  
   241  func sprintCreds(host, auth, creds string) string {
   242  	return fmt.Sprintf(`
   243  {
   244  	"rktKind": "auth",
   245  	"rktVersion": "v1",
   246  	"domains": ["%s"],
   247  	"type": "%s",
   248  	"credentials":
   249  	{
   250  		%s
   251  	}
   252  }
   253  
   254  `, host, auth, creds)
   255  }