github.com/rkt/rkt@v1.30.1-0.20200224141603-171c416fac02/tests/testutils/aci-server/server.go (about) 1 // Copyright 2015 The rkt Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aci 16 17 import ( 18 "crypto/sha512" 19 "crypto/tls" 20 "encoding/base64" 21 "fmt" 22 "io/ioutil" 23 "net" 24 "net/http" 25 "net/http/httptest" 26 "path/filepath" 27 "strings" 28 "time" 29 ) 30 31 type PortType int 32 33 const ( 34 PortFixed PortType = iota 35 PortRandom 36 ) 37 38 type ProtocolType int 39 40 const ( 41 ProtocolHttps ProtocolType = iota 42 ProtocolHttp 43 ) 44 45 type AuthType int 46 47 const ( 48 AuthNone AuthType = iota 49 AuthBasic 50 AuthOauth 51 ) 52 53 type ServerType int 54 55 const ( 56 ServerOrdinary ServerType = iota 57 ServerQuay 58 ) 59 60 type httpError struct { 61 code int 62 message string 63 } 64 65 func (e *httpError) Error() string { 66 return fmt.Sprintf("%d: %s", e.code, e.message) 67 } 68 69 type servedFile struct { 70 path string 71 etag string 72 } 73 74 func newServedFile(path string) (*servedFile, error) { 75 contents, err := ioutil.ReadFile(path) 76 if err != nil { 77 return nil, err 78 } 79 checksum := sha512.Sum512(contents) 80 sf := &servedFile{ 81 path: path, 82 etag: fmt.Sprintf("%x", checksum), 83 } 84 return sf, nil 85 } 86 87 type serverHandler struct { 88 server ServerType 89 auth AuthType 90 protocol ProtocolType 91 msg chan<- string 92 fileSet map[string]*servedFile 93 servedImages map[string]struct{} 94 serverURL string 95 } 96 97 func (h *serverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 98 if r.Method != "GET" { 99 w.WriteHeader(http.StatusMethodNotAllowed) 100 return 101 } 102 if authOk := h.handleAuth(w, r); !authOk { 103 return 104 } 105 h.sendMsg(fmt.Sprintf("Trying to serve %q", r.URL.String())) 106 h.handleRequest(w, r) 107 } 108 109 func (h *serverHandler) handleAuth(w http.ResponseWriter, r *http.Request) bool { 110 switch h.auth { 111 case AuthNone: 112 // no auth to do. 113 return true 114 case AuthBasic: 115 return h.handleBasicAuth(w, r) 116 case AuthOauth: 117 return h.handleOauthAuth(w, r) 118 default: 119 panic("Woe is me!") 120 } 121 } 122 123 func (h *serverHandler) handleBasicAuth(w http.ResponseWriter, r *http.Request) bool { 124 payload, httpErr := getAuthPayload(r, "Basic") 125 if httpErr != nil { 126 w.WriteHeader(httpErr.code) 127 h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message)) 128 return false 129 } 130 creds, err := base64.StdEncoding.DecodeString(string(payload)) 131 if err != nil { 132 w.WriteHeader(http.StatusBadRequest) 133 h.sendMsg(`Badly formed "Authorization" header`) 134 return false 135 } 136 parts := strings.Split(string(creds), ":") 137 if len(parts) != 2 { 138 w.WriteHeader(http.StatusBadRequest) 139 h.sendMsg(`Badly formed "Authorization" header (2)`) 140 return false 141 } 142 user := parts[0] 143 password := parts[1] 144 if user != "bar" || password != "baz" { 145 w.WriteHeader(http.StatusUnauthorized) 146 h.sendMsg(fmt.Sprintf("Bad credentials: %q", string(creds))) 147 return false 148 } 149 return true 150 } 151 152 func (h *serverHandler) handleOauthAuth(w http.ResponseWriter, r *http.Request) bool { 153 payload, httpErr := getAuthPayload(r, "Bearer") 154 if httpErr != nil { 155 w.WriteHeader(httpErr.code) 156 h.sendMsg(fmt.Sprintf(`No "Authorization" header: %v`, httpErr.message)) 157 return false 158 } 159 if payload != "sometoken" { 160 w.WriteHeader(http.StatusUnauthorized) 161 h.sendMsg(fmt.Sprintf(`Bad token: %q`, payload)) 162 return false 163 } 164 return true 165 } 166 167 func getAuthPayload(r *http.Request, authType string) (string, *httpError) { 168 auth := r.Header.Get("Authorization") 169 if auth == "" { 170 err := &httpError{ 171 code: http.StatusUnauthorized, 172 message: "No auth", 173 } 174 return "", err 175 } 176 parts := strings.Split(auth, " ") 177 if len(parts) != 2 { 178 err := &httpError{ 179 code: http.StatusBadRequest, 180 message: "Malformed auth", 181 } 182 return "", err 183 } 184 if parts[0] != authType { 185 err := &httpError{ 186 code: http.StatusUnauthorized, 187 message: "Wrong auth", 188 } 189 return "", err 190 } 191 return parts[1], nil 192 } 193 194 func (h *serverHandler) handleRequest(w http.ResponseWriter, r *http.Request) { 195 path := filepath.Base(r.URL.Path) 196 switch path { 197 case "/": 198 h.sendAcDiscovery(w) 199 default: 200 h.handleFile(w, path, r.Header) 201 } 202 } 203 204 func (h *serverHandler) sendAcDiscovery(w http.ResponseWriter) { 205 // TODO(krnowak): When appc spec gets the discovery over 206 // custom port feature, possibly take it into account here 207 indexHTML := fmt.Sprintf(`<meta name="ac-discovery" content="localhost %s/{name}.{ext}">`, h.serverURL) 208 w.Write([]byte(indexHTML)) 209 h.sendMsg(" done.") 210 } 211 212 func (h *serverHandler) handleFile(w http.ResponseWriter, reqPath string, headers http.Header) { 213 sf, ok := h.fileSet[reqPath] 214 if !ok { 215 w.WriteHeader(http.StatusNotFound) 216 h.sendMsg(" not found.") 217 return 218 } 219 if !h.canServe(reqPath, w) { 220 return 221 } 222 if headers.Get("If-None-Match") == sf.etag { 223 addCacheHeaders(w, sf) 224 w.WriteHeader(http.StatusNotModified) 225 h.sendMsg(" not modified, done.") 226 return 227 } 228 contents, err := ioutil.ReadFile(sf.path) 229 if err != nil { 230 w.WriteHeader(http.StatusInternalServerError) 231 h.sendMsg(" not found, but specified in fileset; bug?") 232 return 233 } 234 addCacheHeaders(w, sf) 235 w.Write(contents) 236 reqImagePath, isAsc := isPathAnImageKey(reqPath) 237 if isAsc { 238 delete(h.servedImages, reqImagePath) 239 } else { 240 h.servedImages[reqPath] = struct{}{} 241 } 242 h.sendMsg(" done.") 243 } 244 245 func (h *serverHandler) canServe(reqPath string, w http.ResponseWriter) bool { 246 if h.server != ServerQuay { 247 return true 248 } 249 reqImagePath, isAsc := isPathAnImageKey(reqPath) 250 if !isAsc { 251 return true 252 } 253 if _, imageAlreadyServed := h.servedImages[reqImagePath]; imageAlreadyServed { 254 return true 255 } 256 w.WriteHeader(http.StatusAccepted) 257 h.sendMsg(" asking to defer the download") 258 return false 259 } 260 261 func addCacheHeaders(w http.ResponseWriter, sf *servedFile) { 262 w.Header().Set("ETag", sf.etag) 263 w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", 60*60*24)) // one day 264 } 265 266 func (h *serverHandler) sendMsg(msg string) { 267 select { 268 case h.msg <- msg: 269 default: 270 } 271 } 272 273 func isPathAnImageKey(path string) (string, bool) { 274 if strings.HasSuffix(path, ".asc") { 275 imagePath := strings.TrimSuffix(path, ".asc") 276 return imagePath, true 277 } 278 return "", false 279 } 280 281 type Server struct { 282 Msg <-chan string 283 Conf string 284 URL string 285 handler *serverHandler 286 http *httptest.Server 287 } 288 289 type ServerSetup struct { 290 Port PortType 291 Protocol ProtocolType 292 Server ServerType 293 Auth AuthType 294 MsgCapacity int 295 } 296 297 func GetDefaultServerSetup() *ServerSetup { 298 return &ServerSetup{ 299 Port: PortFixed, 300 Protocol: ProtocolHttps, 301 Server: ServerOrdinary, 302 Auth: AuthNone, 303 MsgCapacity: 20, 304 } 305 } 306 307 func (s *Server) Close() { 308 s.http.Close() 309 close(s.handler.msg) 310 } 311 312 func (s *Server) UpdateFileSet(fileSet map[string]string) error { 313 s.handler.fileSet = make(map[string]*servedFile, len(fileSet)) 314 for base, path := range fileSet { 315 sf, err := newServedFile(path) 316 if err != nil { 317 return err 318 } 319 s.handler.fileSet[base] = sf 320 } 321 return nil 322 } 323 324 func NewServer(setup *ServerSetup) *Server { 325 msg := make(chan string, setup.MsgCapacity) 326 server := &Server{ 327 Msg: msg, 328 handler: &serverHandler{ 329 auth: setup.Auth, 330 msg: msg, 331 server: setup.Server, 332 protocol: setup.Protocol, 333 fileSet: make(map[string]*servedFile), 334 servedImages: make(map[string]struct{}), 335 }, 336 } 337 server.http = httptest.NewUnstartedServer(server.handler) 338 // We use our own listener, so we can override a port number 339 // without using a "httptest.serve" flag. Using the 340 // "httptest.serve" flag together with an HTTP protocol 341 // results in blocking for debugging purposes as described in 342 // https://golang.org/src/net/http/httptest/server.go#L74. 343 // Here, we lose the ability, but we don't need it. 344 server.http.Listener = newLocalListener(setup.Port, setup.Protocol) 345 switch setup.Protocol { 346 case ProtocolHttp: 347 server.http.Start() 348 case ProtocolHttps: 349 server.http.TLS = &tls.Config{InsecureSkipVerify: true} 350 server.http.StartTLS() 351 default: 352 panic("Woe is me!") 353 } 354 server.URL = server.http.URL 355 server.handler.serverURL = server.http.URL 356 host := server.http.Listener.Addr().String() 357 switch setup.Auth { 358 case AuthNone: 359 // nothing to do 360 case AuthBasic: 361 creds := `"user": "bar", 362 "password": "baz"` 363 server.Conf = sprintCreds(host, "basic", creds) 364 case AuthOauth: 365 creds := `"token": "sometoken"` 366 server.Conf = sprintCreds(host, "oauth", creds) 367 default: 368 panic("Woe is me!") 369 } 370 return server 371 } 372 373 func newLocalListener(port PortType, protocol ProtocolType) net.Listener { 374 portNumber := 0 375 if port == PortFixed { 376 switch protocol { 377 case ProtocolHttp: 378 portNumber = 80 379 case ProtocolHttps: 380 portNumber = 443 381 } 382 } 383 addrs, err := net.LookupHost("localhost") 384 if err != nil { 385 panic(`aci test server: failed to look up "localhost", really`) 386 } 387 var lookupErrs []string 388 for try := 0; try < 2; try++ { 389 for _, addr := range addrs { 390 addrport := fmt.Sprintf("%s:%d", addr, portNumber) 391 l, err := net.Listen("tcp", addrport) 392 if err == nil { 393 return l 394 } 395 lookupErrs = append(lookupErrs, fmt.Sprintf("(listen on %s, attempt #%d: %v)", addrport, try+1, err)) 396 } 397 // TODO: When we have discovery on a custom port then 398 // we could drop listening on fixed ports, so we 399 // probably won't get any races between old server 400 // stopping to listen and new server starting to 401 // listen. 402 // https://github.com/appc/spec/pull/110 403 // Might be possible with ABD: 404 // https://github.com/appc/abd 405 time.Sleep(time.Second) 406 } 407 panic(fmt.Sprintf("aci test server: failed to listen on localhost:%d: %v", portNumber, lookupErrs)) 408 } 409 410 func sprintCreds(host, auth, creds string) string { 411 return fmt.Sprintf(` 412 { 413 "rktKind": "auth", 414 "rktVersion": "v1", 415 "domains": ["%s"], 416 "type": "%s", 417 "credentials": 418 { 419 %s 420 } 421 } 422 423 `, host, auth, creds) 424 }