github.com/blixtra/rkt@v0.8.1-0.20160204105720-ab0d1add1a43/tests/testutils/aci-server/server.go (about) 1 // Copyright 2015 The rkt Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aci 16 17 import ( 18 "crypto/sha512" 19 "crypto/tls" 20 "encoding/base64" 21 "fmt" 22 "io/ioutil" 23 "net" 24 "net/http" 25 "net/http/httptest" 26 "path/filepath" 27 "strings" 28 ) 29 30 type PortType int 31 32 const ( 33 PortFixed PortType = iota 34 PortRandom 35 ) 36 37 type ProtocolType int 38 39 const ( 40 ProtocolHttps ProtocolType = iota 41 ProtocolHttp 42 ) 43 44 type AuthType int 45 46 const ( 47 AuthNone AuthType = iota 48 AuthBasic 49 AuthOauth 50 ) 51 52 type ServerType int 53 54 const ( 55 ServerOrdinary ServerType = iota 56 ServerQuay 57 ) 58 59 type httpError struct { 60 code int 61 message string 62 } 63 64 func (e *httpError) Error() string { 65 return fmt.Sprintf("%d: %s", e.code, e.message) 66 } 67 68 type servedFile struct { 69 path string 70 etag string 71 } 72 73 func newServedFile(path string) (*servedFile, error) { 74 contents, err := ioutil.ReadFile(path) 75 if err != nil { 76 return nil, err 77 } 78 checksum := sha512.Sum512(contents) 79 sf := &servedFile{ 80 path: path, 81 etag: fmt.Sprintf("%x", checksum), 82 } 83 return sf, nil 84 } 85 86 type serverHandler struct { 87 server ServerType 88 auth AuthType 89 protocol ProtocolType 90 msg chan<- string 91 fileSet map[string]*servedFile 92 servedImages map[string]struct{} 93 serverURL string 94 } 95 96 func (h *serverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 97 if r.Method != "GET" { 98 w.WriteHeader(http.StatusMethodNotAllowed) 99 return 100 } 101 if authOk := h.handleAuth(w, r); !authOk { 102 return 103 } 104 h.sendMsg(fmt.Sprintf("Trying to serve %q", r.URL.String())) 105 h.handleRequest(w, r) 106 } 107 108 func (h *serverHandler) handleAuth(w http.ResponseWriter, r *http.Request) bool { 109 switch h.auth { 110 case AuthNone: 111 // no auth to do. 112 return true 113 case AuthBasic: 114 return h.handleBasicAuth(w, r) 115 case AuthOauth: 116 return h.handleOauthAuth(w, r) 117 default: 118 panic("Woe is me!") 119 } 120 } 121 122 func (h *serverHandler) handleBasicAuth(w http.ResponseWriter, r *http.Request) bool { 123 payload, httpErr := getAuthPayload(r, "Basic") 124 if httpErr != nil { 125 w.WriteHeader(httpErr.code) 126 h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message)) 127 return false 128 } 129 creds, err := base64.StdEncoding.DecodeString(string(payload)) 130 if err != nil { 131 w.WriteHeader(http.StatusBadRequest) 132 h.sendMsg(`Badly formed "Authorization" header`) 133 return false 134 } 135 parts := strings.Split(string(creds), ":") 136 if len(parts) != 2 { 137 w.WriteHeader(http.StatusBadRequest) 138 h.sendMsg(`Badly formed "Authorization" header (2)`) 139 return false 140 } 141 user := parts[0] 142 password := parts[1] 143 if user != "bar" || password != "baz" { 144 w.WriteHeader(http.StatusUnauthorized) 145 h.sendMsg(fmt.Sprintf("Bad credentials: %q", string(creds))) 146 return false 147 } 148 return true 149 } 150 151 func (h *serverHandler) handleOauthAuth(w http.ResponseWriter, r *http.Request) bool { 152 payload, httpErr := getAuthPayload(r, "Bearer") 153 if httpErr != nil { 154 w.WriteHeader(httpErr.code) 155 h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message)) 156 return false 157 } 158 if payload != "sometoken" { 159 w.WriteHeader(http.StatusUnauthorized) 160 h.sendMsg(fmt.Sprintf(`Bad token: %q`, payload)) 161 return false 162 } 163 return true 164 } 165 166 func getAuthPayload(r *http.Request, authType string) (string, *httpError) { 167 auth := r.Header.Get("Authorization") 168 if auth == "" { 169 err := &httpError{ 170 code: http.StatusUnauthorized, 171 message: "No auth", 172 } 173 return "", err 174 } 175 parts := strings.Split(auth, " ") 176 if len(parts) != 2 { 177 err := &httpError{ 178 code: http.StatusBadRequest, 179 message: "Malformed auth", 180 } 181 return "", err 182 } 183 if parts[0] != authType { 184 err := &httpError{ 185 code: http.StatusUnauthorized, 186 message: "Wrong auth", 187 } 188 return "", err 189 } 190 return parts[1], nil 191 } 192 193 func (h *serverHandler) handleRequest(w http.ResponseWriter, r *http.Request) { 194 path := filepath.Base(r.URL.Path) 195 switch path { 196 case "/": 197 h.sendAcDiscovery(w) 198 default: 199 h.handleFile(w, path, r.Header) 200 } 201 } 202 203 func (h *serverHandler) sendAcDiscovery(w http.ResponseWriter) { 204 // TODO(krnowak): When appc spec gets the discovery over 205 // custom port feature, possibly take it into account here 206 indexHTML := fmt.Sprintf(`<meta name="ac-discovery" content="localhost %s/{name}.{ext}">`, h.serverURL) 207 w.Write([]byte(indexHTML)) 208 h.sendMsg(" done.") 209 } 210 211 func (h *serverHandler) handleFile(w http.ResponseWriter, reqPath string, headers http.Header) { 212 sf, ok := h.fileSet[reqPath] 213 if !ok { 214 w.WriteHeader(http.StatusNotFound) 215 h.sendMsg(" not found.") 216 return 217 } 218 if !h.canServe(reqPath, w) { 219 return 220 } 221 if headers.Get("If-None-Match") == sf.etag { 222 addCacheHeaders(w, sf) 223 w.WriteHeader(http.StatusNotModified) 224 h.sendMsg(" not modified, done.") 225 return 226 } 227 contents, err := ioutil.ReadFile(sf.path) 228 if err != nil { 229 w.WriteHeader(http.StatusInternalServerError) 230 h.sendMsg(" not found, but specified in fileset; bug?") 231 return 232 } 233 addCacheHeaders(w, sf) 234 w.Write(contents) 235 reqImagePath, isAsc := isPathAnImageKey(reqPath) 236 if isAsc { 237 delete(h.servedImages, reqImagePath) 238 } else { 239 h.servedImages[reqPath] = struct{}{} 240 } 241 h.sendMsg(" done.") 242 } 243 244 func (h *serverHandler) canServe(reqPath string, w http.ResponseWriter) bool { 245 if h.server != ServerQuay { 246 return true 247 } 248 reqImagePath, isAsc := isPathAnImageKey(reqPath) 249 if !isAsc { 250 return true 251 } 252 if _, imageAlreadyServed := h.servedImages[reqImagePath]; imageAlreadyServed { 253 return true 254 } 255 w.WriteHeader(http.StatusAccepted) 256 h.sendMsg(" asking to defer the download") 257 return false 258 } 259 260 func addCacheHeaders(w http.ResponseWriter, sf *servedFile) { 261 w.Header().Set("ETag", sf.etag) 262 w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", 60*60*24)) // one day 263 } 264 265 func (h *serverHandler) sendMsg(msg string) { 266 select { 267 case h.msg <- msg: 268 default: 269 } 270 } 271 272 func isPathAnImageKey(path string) (string, bool) { 273 if strings.HasSuffix(path, ".asc") { 274 imagePath := strings.TrimSuffix(path, ".asc") 275 return imagePath, true 276 } 277 return "", false 278 } 279 280 type Server struct { 281 Msg <-chan string 282 Conf string 283 URL string 284 handler *serverHandler 285 http *httptest.Server 286 } 287 288 type ServerSetup struct { 289 Port PortType 290 Protocol ProtocolType 291 Server ServerType 292 Auth AuthType 293 MsgCapacity int 294 } 295 296 func GetDefaultServerSetup() *ServerSetup { 297 return &ServerSetup{ 298 Port: PortFixed, 299 Protocol: ProtocolHttps, 300 Server: ServerOrdinary, 301 Auth: AuthNone, 302 MsgCapacity: 20, 303 } 304 } 305 306 func (s *Server) Close() { 307 s.http.Close() 308 close(s.handler.msg) 309 } 310 311 func (s *Server) UpdateFileSet(fileSet map[string]string) error { 312 s.handler.fileSet = make(map[string]*servedFile, len(fileSet)) 313 for base, path := range fileSet { 314 sf, err := newServedFile(path) 315 if err != nil { 316 return err 317 } 318 s.handler.fileSet[base] = sf 319 } 320 return nil 321 } 322 323 func NewServer(setup *ServerSetup) *Server { 324 msg := make(chan string, setup.MsgCapacity) 325 server := &Server{ 326 Msg: msg, 327 handler: &serverHandler{ 328 auth: setup.Auth, 329 msg: msg, 330 server: setup.Server, 331 protocol: setup.Protocol, 332 fileSet: make(map[string]*servedFile), 333 servedImages: make(map[string]struct{}), 334 }, 335 } 336 server.http = httptest.NewUnstartedServer(server.handler) 337 // We use our own listener, so we can override a port number 338 // without using a "httptest.serve" flag. Using the 339 // "httptest.serve" flag together with an HTTP protocol 340 // results in blocking for debugging purposes as described in 341 // https://golang.org/src/net/http/httptest/server.go#L74. 342 // Here, we lose the ability, but we don't need it. 343 server.http.Listener = newLocalListener(setup.Port, setup.Protocol) 344 switch setup.Protocol { 345 case ProtocolHttp: 346 server.http.Start() 347 case ProtocolHttps: 348 server.http.TLS = &tls.Config{InsecureSkipVerify: true} 349 server.http.StartTLS() 350 default: 351 panic("Woe is me!") 352 } 353 server.URL = server.http.URL 354 server.handler.serverURL = server.http.URL 355 host := server.http.Listener.Addr().String() 356 switch setup.Auth { 357 case AuthNone: 358 // nothing to do 359 case AuthBasic: 360 creds := `"user": "bar", 361 "password": "baz"` 362 server.Conf = sprintCreds(host, "basic", creds) 363 case AuthOauth: 364 creds := `"token": "sometoken"` 365 server.Conf = sprintCreds(host, "oauth", creds) 366 default: 367 panic("Woe is me!") 368 } 369 return server 370 } 371 372 // based on code from net/http/httptest 373 func newLocalListener(port PortType, protocol ProtocolType) net.Listener { 374 portNumber := 0 375 if port == PortFixed { 376 switch protocol { 377 case ProtocolHttp: 378 portNumber = 80 379 case ProtocolHttps: 380 portNumber = 443 381 } 382 } 383 l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", portNumber)) 384 if err != nil { 385 if l, err = net.Listen("tcp6", fmt.Sprintf("[::1]:%d", portNumber)); err != nil { 386 panic(fmt.Sprintf("aci test server: failed to listen on a port: %v", err)) 387 } 388 } 389 return l 390 } 391 392 func sprintCreds(host, auth, creds string) string { 393 return fmt.Sprintf(` 394 { 395 "rktKind": "auth", 396 "rktVersion": "v1", 397 "domains": ["%s"], 398 "type": "%s", 399 "credentials": 400 { 401 %s 402 } 403 } 404 405 `, host, auth, creds) 406 }