github.com/stackdocker/rkt@v0.10.1-0.20151109095037-1aa827478248/tests/test-auth-server/aci/server.go (about) 1 // Copyright 2015 The rkt Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aci 16 17 import ( 18 "crypto/tls" 19 "encoding/base64" 20 "fmt" 21 "net/http" 22 "net/http/httptest" 23 "os/exec" 24 "path/filepath" 25 "strings" 26 ) 27 28 type Type int 29 30 const ( 31 None Type = iota 32 Basic 33 Oauth 34 ) 35 36 type httpError struct { 37 code int 38 message string 39 } 40 41 func (e *httpError) Error() string { 42 return fmt.Sprintf("%d: %s", e.code, e.message) 43 } 44 45 type serverHandler struct { 46 auth Type 47 stop chan<- struct{} 48 msg chan<- string 49 tools *aciToolkit 50 } 51 52 func (h *serverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 53 switch r.Method { 54 case "POST": 55 w.WriteHeader(http.StatusOK) 56 h.stop <- struct{}{} 57 return 58 case "GET": 59 // handled later 60 default: 61 w.WriteHeader(http.StatusMethodNotAllowed) 62 return 63 } 64 switch h.auth { 65 case None: 66 // no auth to do. 67 case Basic: 68 payload, httpErr := getAuthPayload(r, "Basic") 69 if httpErr != nil { 70 w.WriteHeader(httpErr.code) 71 h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message)) 72 return 73 } 74 creds, err := base64.StdEncoding.DecodeString(string(payload)) 75 if err != nil { 76 w.WriteHeader(http.StatusBadRequest) 77 h.sendMsg(fmt.Sprintf(`Badly formed "Authorization" header`)) 78 return 79 } 80 parts := strings.Split(string(creds), ":") 81 if len(parts) != 2 { 82 w.WriteHeader(http.StatusBadRequest) 83 h.sendMsg(fmt.Sprintf(`Badly formed "Authorization" header (2)`)) 84 return 85 } 86 user := parts[0] 87 password := parts[1] 88 if user != "bar" || password != "baz" { 89 w.WriteHeader(http.StatusUnauthorized) 90 h.sendMsg(fmt.Sprintf("Bad credentials: %q", string(creds))) 91 return 92 } 93 case Oauth: 94 payload, httpErr := getAuthPayload(r, "Bearer") 95 if httpErr != nil { 96 w.WriteHeader(httpErr.code) 97 h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message)) 98 return 99 } 100 if payload != "sometoken" { 101 w.WriteHeader(http.StatusUnauthorized) 102 h.sendMsg(fmt.Sprintf(`Bad token: %q`, payload)) 103 return 104 } 105 default: 106 panic("Woe is me!") 107 } 108 h.sendMsg(fmt.Sprintf("Trying to serve %q", r.URL.String())) 109 switch filepath.Base(r.URL.Path) { 110 case "prog.aci": 111 h.sendMsg(fmt.Sprintf(" serving")) 112 if data, err := h.tools.prepareACI(); err != nil { 113 w.WriteHeader(http.StatusInternalServerError) 114 h.sendMsg(fmt.Sprintf(" failed (%v)", err)) 115 } else { 116 w.Write(data) 117 h.sendMsg(fmt.Sprintf(" done.")) 118 } 119 default: 120 h.sendMsg(fmt.Sprintf(" not found.")) 121 w.WriteHeader(http.StatusNotFound) 122 } 123 } 124 125 func (h *serverHandler) sendMsg(msg string) { 126 select { 127 case h.msg <- msg: 128 default: 129 } 130 } 131 132 func getAuthPayload(r *http.Request, authType string) (string, *httpError) { 133 auth := r.Header.Get("Authorization") 134 if auth == "" { 135 err := &httpError{ 136 code: http.StatusUnauthorized, 137 message: "No auth", 138 } 139 return "", err 140 } 141 parts := strings.Split(auth, " ") 142 if len(parts) != 2 { 143 err := &httpError{ 144 code: http.StatusBadRequest, 145 message: "Malformed auth", 146 } 147 return "", err 148 } 149 if parts[0] != authType { 150 err := &httpError{ 151 code: http.StatusUnauthorized, 152 message: "Wrong auth", 153 } 154 return "", err 155 } 156 return parts[1], nil 157 } 158 159 type Server struct { 160 Stop <-chan struct{} 161 Msg <-chan string 162 Conf string 163 URL string 164 handler *serverHandler 165 http *httptest.Server 166 } 167 168 func (s *Server) Close() { 169 s.http.Close() 170 close(s.handler.msg) 171 close(s.handler.stop) 172 } 173 174 func NewServer(auth Type, msgCapacity int) (*Server, error) { 175 return NewServerWithPaths(auth, msgCapacity, "actool", "go") 176 } 177 178 func NewServerWithPaths(auth Type, msgCapacity int, acTool, goTool string) (*Server, error) { 179 if !filepath.IsAbs(acTool) { 180 absAcTool, err := getTool(acTool) 181 if err != nil { 182 return nil, err 183 } 184 acTool = absAcTool 185 } 186 if !filepath.IsAbs(goTool) { 187 absGoTool, err := getTool(goTool) 188 if err != nil { 189 return nil, err 190 } 191 goTool = absGoTool 192 } 193 stop := make(chan struct{}) 194 msg := make(chan string, msgCapacity) 195 server := &Server{ 196 Stop: stop, 197 Msg: msg, 198 handler: &serverHandler{ 199 auth: auth, 200 stop: stop, 201 msg: msg, 202 tools: &aciToolkit{ 203 acTool: acTool, 204 goTool: goTool, 205 }, 206 }, 207 } 208 server.http = httptest.NewUnstartedServer(server.handler) 209 server.http.TLS = &tls.Config{InsecureSkipVerify: true} 210 server.http.StartTLS() 211 server.URL = server.http.URL 212 host := server.http.Listener.Addr().String() 213 switch auth { 214 case None: 215 // nothing to do 216 case Basic: 217 creds := `"user": "bar", 218 "password": "baz"` 219 server.Conf = sprintCreds(host, "basic", creds) 220 case Oauth: 221 creds := `"token": "sometoken"` 222 server.Conf = sprintCreds(host, "oauth", creds) 223 default: 224 panic("Woe is me!") 225 } 226 return server, nil 227 } 228 229 func getTool(tool string) (string, error) { 230 toolPath, err := exec.LookPath(tool) 231 if err != nil { 232 return "", fmt.Errorf("failed to find %s in $PATH: %v", tool, err) 233 } 234 absToolPath, err := filepath.Abs(toolPath) 235 if err != nil { 236 return "", fmt.Errorf("failed to get absolute path of %s: %v", tool, err) 237 } 238 return absToolPath, nil 239 } 240 241 func sprintCreds(host, auth, creds string) string { 242 return fmt.Sprintf(` 243 { 244 "rktKind": "auth", 245 "rktVersion": "v1", 246 "domains": ["%s"], 247 "type": "%s", 248 "credentials": 249 { 250 %s 251 } 252 } 253 254 `, host, auth, creds) 255 }