github.com/git-lfs/git-lfs@v2.5.2+incompatible/t/cmd/lfstest-gitserver.go

// +build testtools

package main

import (
	"bufio"
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"math"
	"math/big"
	"net/http"
	"net/http/httptest"
	"net/textproto"
	"os"
	"os/exec"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/ThomsonReutersEikon/go-ntlm/ntlm"
)

var (
	repoDir          string
	largeObjects     = newLfsStorage()
	server           *httptest.Server
	serverTLS        *httptest.Server
	serverClientCert *httptest.Server

	// maps OIDs to content strings. Both the LFS and Storage test servers below
	// see OIDs.
	oidHandlers map[string]string

	// These magic strings tell the test LFS server to change its behavior so the
	// integration tests can check those use cases. Tests will create objects with
	// the magic strings as the contents.
	//
	//   printf "status:lfs:404" > 404.dat
	//
	contentHandlers = []string{
		"status-batch-403", "status-batch-404", "status-batch-410", "status-batch-422", "status-batch-500",
		"status-storage-403", "status-storage-404", "status-storage-410", "status-storage-422", "status-storage-500", "status-storage-503",
		"status-batch-resume-206", "batch-resume-fail-fallback", "return-expired-action", "return-expired-action-forever", "return-invalid-size",
		"object-authenticated", "storage-download-retry", "storage-upload-retry", "unknown-oid",
		"send-verify-action", "send-deprecated-links",
	}
)

func main() {
	repoDir = os.Getenv("LFSTEST_DIR")

	mux := http.NewServeMux()
	server = httptest.NewServer(mux)
	serverTLS = httptest.NewTLSServer(mux)
	serverClientCert = httptest.NewUnstartedServer(mux)

	// set up the client cert server
	rootKey, rootCert := generateCARootCertificates()
	_, clientCertPEM, clientKeyPEM := generateClientCertificates(rootCert, rootKey)

	certPool := x509.NewCertPool()
	certPool.AddCert(rootCert)

	serverClientCert.TLS = &tls.Config{
		Certificates: []tls.Certificate{serverTLS.TLS.Certificates[0]},
		ClientAuth:   tls.RequireAndVerifyClientCert,
		ClientCAs:    certPool,
	}
	serverClientCert.StartTLS()

	ntlmSession, err := ntlm.CreateServerSession(ntlm.Version2, ntlm.ConnectionOrientedMode)
	if err != nil {
		fmt.Println("Error creating ntlm session:", err)
		os.Exit(1)
	}
	ntlmSession.SetUserInfo("ntlmuser", "ntlmpass", "NTLMDOMAIN")

	stopch := make(chan bool)

	mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) {
		stopch <- true
	})

	mux.HandleFunc("/storage/", storageHandler)
	mux.HandleFunc("/verify", verifyHandler)
	mux.HandleFunc("/redirect307/", redirect307Handler)
	mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "%s\n", time.Now().String())
	})
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		id, ok := reqId(w)
		if !ok {
			return
		}

		if strings.Contains(r.URL.Path, "/info/lfs") {
			if !skipIfBadAuth(w, r, id, ntlmSession) {
				lfsHandler(w, r, id)
			}

			return
		}

		debug(id, "git http-backend %s %s", r.Method, r.URL)
		gitHandler(w, r)
	})

	urlname := writeTestStateFile([]byte(server.URL), "LFSTEST_URL", "lfstest-gitserver")
	defer os.RemoveAll(urlname)

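	// Write the remaining server URLs and TLS material to state files as well,
	// so the test suite can locate them via the LFSTEST_* variables.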
	sslurlname := writeTestStateFile([]byte(serverTLS.URL), "LFSTEST_SSL_URL", "lfstest-gitserver-ssl")
	defer os.RemoveAll(sslurlname)

	clientCertUrlname := writeTestStateFile([]byte(serverClientCert.URL), "LFSTEST_CLIENT_CERT_URL", "lfstest-gitserver-ssl")
	defer os.RemoveAll(clientCertUrlname)

	block := &pem.Block{}
	block.Type = "CERTIFICATE"
	block.Bytes = serverTLS.TLS.Certificates[0].Certificate[0]
	pembytes := pem.EncodeToMemory(block)

	certname := writeTestStateFile(pembytes, "LFSTEST_CERT", "lfstest-gitserver-cert")
	defer os.RemoveAll(certname)

	cccertname := writeTestStateFile(clientCertPEM, "LFSTEST_CLIENT_CERT", "lfstest-gitserver-client-cert")
	defer os.RemoveAll(cccertname)

	ckcertname := writeTestStateFile(clientKeyPEM, "LFSTEST_CLIENT_KEY", "lfstest-gitserver-client-key")
	defer os.RemoveAll(ckcertname)

	debug("init", "server url: %s", server.URL)
	debug("init", "server tls url: %s", serverTLS.URL)
	debug("init", "server client cert url: %s", serverClientCert.URL)

	<-stopch
	debug("init", "git server done")
}

// writeTestStateFile writes contents to either the file referenced by the
// environment variable envVar, or defaultFilename if that's not set. Returns
// the filename that was used.
func writeTestStateFile(contents []byte, envVar, defaultFilename string) string {
	f := os.Getenv(envVar)
	if len(f) == 0 {
		f = defaultFilename
	}
	file, err := os.Create(f)
	if err != nil {
		log.Fatalln(err)
	}
	file.Write(contents)
	file.Close()
	return f
}

type lfsObject struct {
	Oid           string              `json:"oid,omitempty"`
	Size          int64               `json:"size,omitempty"`
	Authenticated bool                `json:"authenticated,omitempty"`
	Actions       map[string]*lfsLink `json:"actions,omitempty"`
	Links         map[string]*lfsLink `json:"_links,omitempty"`
	Err           *lfsError           `json:"error,omitempty"`
}

type lfsLink struct {
	Href      string            `json:"href"`
	Header    map[string]string `json:"header,omitempty"`
	ExpiresAt time.Time         `json:"expires_at,omitempty"`
	ExpiresIn int               `json:"expires_in,omitempty"`
}

type lfsError struct {
	Code    int    `json:"code,omitempty"`
	Message string `json:"message"`
}

func writeLFSError(w http.ResponseWriter, code int, msg string) {
	by, err := json.Marshal(&lfsError{Message: msg})
	if err != nil {
		http.Error(w, "json encoding error: "+err.Error(), 500)
		return
	}

	w.Header().Set("Content-Type", "application/vnd.git-lfs+json")
	w.WriteHeader(code)
	w.Write(by)
}

// handles any requests with "{name}.server.git/info/lfs" in the path
func lfsHandler(w http.ResponseWriter, r *http.Request, id string) {
	repo, err := repoFromLfsUrl(r.URL.Path)
	if err != nil {
		w.WriteHeader(500)
		w.Write([]byte(err.Error()))
		return
	}

	debug(id, "git lfs %s %s repo: %s", r.Method, r.URL, repo)
	w.Header().Set("Content-Type", "application/vnd.git-lfs+json")
	switch r.Method {
	case "POST":
		if strings.HasSuffix(r.URL.String(), "batch") {
			lfsBatchHandler(w, r, id, repo)
		} else {
			locksHandler(w, r, repo)
		}
	case "DELETE":
		lfsDeleteHandler(w, r, id, repo)
	case "GET":
		if strings.Contains(r.URL.String(), "/locks") {
			locksHandler(w, r, repo)
		} else {
			w.WriteHeader(404)
			w.Write([]byte("lock request"))
		}
	default:
		w.WriteHeader(405)
	}
}

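// lfsUrl builds the storage URL handed out in batch responses; the repo name
// travels in the "r" query parameter so storageHandler can find the object.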
func lfsUrl(repo, oid string) string {
	return server.URL + "/storage/" + oid + "?r=" + repo
}

var (
	retries   = make(map[string]uint32)
	retriesMu sync.Mutex
)

func incrementRetriesFor(api, direction, repo, oid string, check bool) (after uint32, ok bool) {
	// fmtStr formats a string like "<api>-<direction>-[check]-<retry>",
	// i.e., "legacy-upload-check-retry", or "storage-download-retry".
	var fmtStr string
	if check {
		fmtStr = "%s-%s-check-retry"
	} else {
		fmtStr = "%s-%s-retry"
	}

	if oidHandlers[oid] != fmt.Sprintf(fmtStr, api, direction) {
		return 0, false
	}

	retriesMu.Lock()
	defer retriesMu.Unlock()

	retryKey := strings.Join([]string{direction, repo, oid}, ":")

	retries[retryKey]++
	retries := retries[retryKey]

	return retries, true
}

func lfsDeleteHandler(w http.ResponseWriter, r *http.Request, id, repo string) {
	parts := strings.Split(r.URL.Path, "/")
	oid := parts[len(parts)-1]

	largeObjects.Delete(repo, oid)
	debug(id, "DELETE:", oid)
	w.WriteHeader(200)
}

type batchReq struct {
	Transfers []string    `json:"transfers"`
	Operation string      `json:"operation"`
	Objects   []lfsObject `json:"objects"`
	Ref       *Ref        `json:"ref,omitempty"`
}

func (r *batchReq) RefName() string {
	if r.Ref == nil {
		return ""
	}
	return r.Ref.Name
}

type batchResp struct {
	Transfer string      `json:"transfer,omitempty"`
	Objects  []lfsObject `json:"objects"`
}

func lfsBatchHandler(w http.ResponseWriter, r *http.Request, id, repo string) {
	checkingObject := r.Header.Get("X-Check-Object") == "1"
	if !checkingObject && repo == "batchunsupported" {
		w.WriteHeader(404)
		return
	}

	if !checkingObject && repo == "badbatch" {
		w.WriteHeader(203)
		return
	}

	if repo == "netrctest" {
		user, pass, err := extractAuth(r.Header.Get("Authorization"))
		if err != nil || (user != "netrcuser" || pass != "netrcpass") {
			w.WriteHeader(403)
			return
		}
	}

	if missingRequiredCreds(w, r, repo) {
		return
	}

	buf := &bytes.Buffer{}
	tee := io.TeeReader(r.Body, buf)
	objs := &batchReq{}
	err := json.NewDecoder(tee).Decode(objs)
	io.Copy(ioutil.Discard, r.Body)
	r.Body.Close()

	debug(id, "REQUEST")
	debug(id, buf.String())

	if err != nil {
		log.Fatal(err)
	}

	if strings.HasSuffix(repo, "branch-required") {
		parts := strings.Split(repo, "-")
		lenParts := len(parts)
		if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != objs.RefName() {
			w.WriteHeader(403)
			json.NewEncoder(w).Encode(struct {
				Message string `json:"message"`
			}{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], objs.RefName())})
			return
		}
	}

	res := []lfsObject{}
	testingChunked := testingChunkedTransferEncoding(r)
	testingTus := testingTusUploadInBatchReq(r)
	testingTusInterrupt := testingTusUploadInterruptedInBatchReq(r)
	testingCustomTransfer := testingCustomTransfer(r)
	var transferChoice string
	var searchForTransfer string
	if testingTus {
		searchForTransfer = "tus"
	} else if testingCustomTransfer {
		searchForTransfer = "testcustom"
	}
	if len(searchForTransfer) > 0 {
		for _, t := range objs.Transfers {
			if t == searchForTransfer {
				transferChoice = searchForTransfer
				break
			}
		}
	}
	for _, obj := range objs.Objects {
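		// An object's content may be one of the magic strings above; the handler
		// looked up from its OID controls how this object is answered.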
		handler := oidHandlers[obj.Oid]
		action := objs.Operation

		o := lfsObject{
			Size:    obj.Size,
			Actions: make(map[string]*lfsLink),
		}

		// Clobber the OID if told to do so.
		if handler == "unknown-oid" {
			o.Oid = "unknown-oid"
		} else {
			o.Oid = obj.Oid
		}

		exists := largeObjects.Has(repo, obj.Oid)
		addAction := true
		if action == "download" {
			if !exists {
				o.Err = &lfsError{Code: 404, Message: fmt.Sprintf("Object %v does not exist", obj.Oid)}
				addAction = false
			}
		} else {
			if exists {
				// not an error, but don't add an action
				addAction = false
			}
		}

		if handler == "object-authenticated" {
			o.Authenticated = true
		}

		switch handler {
		case "status-batch-403":
			o.Err = &lfsError{Code: 403, Message: "welp"}
		case "status-batch-404":
			o.Err = &lfsError{Code: 404, Message: "welp"}
		case "status-batch-410":
			o.Err = &lfsError{Code: 410, Message: "welp"}
		case "status-batch-422":
			o.Err = &lfsError{Code: 422, Message: "welp"}
		case "status-batch-500":
			o.Err = &lfsError{Code: 500, Message: "welp"}
		default: // regular 200 response
			if handler == "return-invalid-size" {
				o.Size = -1
			}

			if handler == "send-deprecated-links" {
				o.Links = make(map[string]*lfsLink)
			}

			if addAction {
				a := &lfsLink{
					Href:   lfsUrl(repo, obj.Oid),
					Header: map[string]string{},
				}
				a = serveExpired(a, repo, handler)

				if handler == "send-deprecated-links" {
					o.Links[action] = a
				} else {
					o.Actions[action] = a
				}
			}

			if handler == "send-verify-action" {
				o.Actions["verify"] = &lfsLink{
					Href: server.URL + "/verify",
					Header: map[string]string{
						"repo": repo,
					},
				}
			}
		}

		if testingChunked && addAction {
			if handler == "send-deprecated-links" {
				o.Links[action].Header["Transfer-Encoding"] = "chunked"
			} else {
				o.Actions[action].Header["Transfer-Encoding"] = "chunked"
			}
		}
		if testingTusInterrupt && addAction {
			if handler == "send-deprecated-links" {
				o.Links[action].Header["Lfs-Tus-Interrupt"] = "true"
			} else {
				o.Actions[action].Header["Lfs-Tus-Interrupt"] = "true"
			}
		}

		res = append(res, o)
	}

	ores := batchResp{Transfer: transferChoice, Objects: res}

	by, err := json.Marshal(ores)
	if err != nil {
		log.Fatal(err)
	}

	debug(id, "RESPONSE: 200")
	debug(id, string(by))

	w.WriteHeader(200)
	w.Write(by)
}

// emu guards expiredRepos
var emu sync.Mutex

// expiredRepos is a map keyed by repository name, recording whether that
// repository has already served an expired object.
var expiredRepos = map[string]bool{}

// serveExpired marks the given repo as having served an expired object, so
// that the same repository cannot serve another expired object in the future.
func serveExpired(a *lfsLink, repo, handler string) *lfsLink {
	var (
		dur = -5 * time.Minute
		at  = time.Now().Add(dur)
	)

	if handler == "return-expired-action-forever" ||
		(handler == "return-expired-action" && canServeExpired(repo)) {

		emu.Lock()
		expiredRepos[repo] = true
		emu.Unlock()

		a.ExpiresAt = at
		return a
	}

	switch repo {
	case "expired-absolute":
		a.ExpiresAt = at
	case "expired-relative":
		a.ExpiresIn = -5
	case "expired-both":
		a.ExpiresAt = at
		a.ExpiresIn = -5
	}

	return a
}

// canServeExpired returns whether or not a repository is capable of serving an
// expired object. In other words, canServeExpired returns whether or not the
// given repo has yet to serve an expired object.
func canServeExpired(repo string) bool {
	emu.Lock()
	defer emu.Unlock()

	return !expiredRepos[repo]
}

// Persistent state across requests
var batchResumeFailFallbackStorageAttempts = 0
var tusStorageAttempts = 0

var (
	vmu           sync.Mutex
	verifyCounts  = make(map[string]int)
	verifyRetryRe = regexp.MustCompile(`verify-fail-(\d+)-times?$`)
)

func verifyHandler(w http.ResponseWriter, r *http.Request) {
	repo := r.Header.Get("repo")
	var payload struct {
		Oid  string `json:"oid"`
		Size int64  `json:"size"`
	}

	if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
		writeLFSError(w, http.StatusUnprocessableEntity, err.Error())
		return
	}

	var max int
	if matches := verifyRetryRe.FindStringSubmatch(repo); len(matches) < 2 {
		return
	} else {
		max, _ = strconv.Atoi(matches[1])
	}

	key := strings.Join([]string{repo, payload.Oid}, ":")

	vmu.Lock()
	verifyCounts[key] = verifyCounts[key] + 1
	count := verifyCounts[key]
	vmu.Unlock()

	if count < max {
		writeLFSError(w, http.StatusServiceUnavailable, fmt.Sprintf(
			"intentionally failing verify request %d (out of %d)", count, max,
		))
		return
	}
}

// handles any /storage/{oid} requests
func storageHandler(w http.ResponseWriter, r *http.Request) {
	id, ok := reqId(w)
	if !ok {
		return
	}

	repo := r.URL.Query().Get("r")
	parts := strings.Split(r.URL.Path, "/")
	oid := parts[len(parts)-1]
	if missingRequiredCreds(w, r, repo) {
		return
	}

	debug(id, "storage %s %s repo: %s", r.Method, oid, repo)
	switch r.Method {
	case "PUT":
		switch oidHandlers[oid] {
		case "status-storage-403":
			w.WriteHeader(403)
			return
		case "status-storage-404":
			w.WriteHeader(404)
			return
		case "status-storage-410":
			w.WriteHeader(410)
			return
		case "status-storage-422":
			w.WriteHeader(422)
			return
		case "status-storage-500":
			w.WriteHeader(500)
			return
		case "status-storage-503":
			writeLFSError(w, 503, "LFS is temporarily unavailable")
			return
		case "object-authenticated":
			if len(r.Header.Get("Authorization")) > 0 {
				w.WriteHeader(400)
				w.Write([]byte("Should not send authentication"))
			}
			return
		case "storage-upload-retry":
			if retries, ok := incrementRetriesFor("storage", "upload", repo, oid, false); ok && retries < 3 {
				w.WriteHeader(500)
w.Write([]byte("malformed content")) 620 621 return 622 } 623 } 624 625 if testingChunkedTransferEncoding(r) { 626 valid := false 627 for _, value := range r.TransferEncoding { 628 if value == "chunked" { 629 valid = true 630 break 631 } 632 } 633 if !valid { 634 debug(id, "Chunked transfer encoding expected") 635 } 636 } 637 638 hash := sha256.New() 639 buf := &bytes.Buffer{} 640 641 io.Copy(io.MultiWriter(hash, buf), r.Body) 642 oid := hex.EncodeToString(hash.Sum(nil)) 643 if !strings.HasSuffix(r.URL.Path, "/"+oid) { 644 w.WriteHeader(403) 645 return 646 } 647 648 largeObjects.Set(repo, oid, buf.Bytes()) 649 650 case "GET": 651 parts := strings.Split(r.URL.Path, "/") 652 oid := parts[len(parts)-1] 653 statusCode := 200 654 byteLimit := 0 655 resumeAt := int64(0) 656 657 if by, ok := largeObjects.Get(repo, oid); ok { 658 if len(by) == len("storage-download-retry") && string(by) == "storage-download-retry" { 659 if retries, ok := incrementRetriesFor("storage", "download", repo, oid, false); ok && retries < 3 { 660 statusCode = 500 661 by = []byte("malformed content") 662 } 663 } else if len(by) == len("status-batch-resume-206") && string(by) == "status-batch-resume-206" { 664 // Resume if header includes range, otherwise deliberately interrupt 665 if rangeHdr := r.Header.Get("Range"); rangeHdr != "" { 666 regex := regexp.MustCompile(`bytes=(\d+)\-.*`) 667 match := regex.FindStringSubmatch(rangeHdr) 668 if match != nil && len(match) > 1 { 669 statusCode = 206 670 resumeAt, _ = strconv.ParseInt(match[1], 10, 32) 671 w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", resumeAt, len(by), resumeAt-int64(len(by)))) 672 } 673 } else { 674 byteLimit = 10 675 } 676 } else if len(by) == len("batch-resume-fail-fallback") && string(by) == "batch-resume-fail-fallback" { 677 // Fail any Range: request even though we said we supported it 678 // To make sure client can fall back 679 if rangeHdr := r.Header.Get("Range"); rangeHdr != "" { 680 w.WriteHeader(416) 681 return 682 } 683 if batchResumeFailFallbackStorageAttempts == 0 { 684 // Truncate output on FIRST attempt to cause resume 685 // Second attempt (without range header) is fallback, complete successfully 686 byteLimit = 8 687 batchResumeFailFallbackStorageAttempts++ 688 } 689 } 690 w.WriteHeader(statusCode) 691 if byteLimit > 0 { 692 w.Write(by[0:byteLimit]) 693 } else if resumeAt > 0 { 694 w.Write(by[resumeAt:]) 695 } else { 696 w.Write(by) 697 } 698 return 699 } 700 701 w.WriteHeader(404) 702 case "HEAD": 703 // tus.io 704 if !validateTusHeaders(r, id) { 705 w.WriteHeader(400) 706 return 707 } 708 parts := strings.Split(r.URL.Path, "/") 709 oid := parts[len(parts)-1] 710 var offset int64 711 if by, ok := largeObjects.GetIncomplete(repo, oid); ok { 712 offset = int64(len(by)) 713 } 714 w.Header().Set("Upload-Offset", strconv.FormatInt(offset, 10)) 715 w.WriteHeader(200) 716 case "PATCH": 717 // tus.io 718 if !validateTusHeaders(r, id) { 719 w.WriteHeader(400) 720 return 721 } 722 parts := strings.Split(r.URL.Path, "/") 723 oid := parts[len(parts)-1] 724 725 offsetHdr := r.Header.Get("Upload-Offset") 726 offset, err := strconv.ParseInt(offsetHdr, 10, 64) 727 if err != nil { 728 log.Fatal("Unable to parse Upload-Offset header in request: ", err) 729 w.WriteHeader(400) 730 return 731 } 732 hash := sha256.New() 733 buf := &bytes.Buffer{} 734 out := io.MultiWriter(hash, buf) 735 736 if by, ok := largeObjects.GetIncomplete(repo, oid); ok { 737 if offset != int64(len(by)) { 738 log.Fatal(fmt.Sprintf("Incorrect offset in request, got %d expected 
%d", offset, len(by))) 739 w.WriteHeader(400) 740 return 741 } 742 _, err := out.Write(by) 743 if err != nil { 744 log.Fatal("Error reading incomplete bytes from store: ", err) 745 w.WriteHeader(500) 746 return 747 } 748 largeObjects.DeleteIncomplete(repo, oid) 749 debug(id, "Resuming upload of %v at byte %d", oid, offset) 750 } 751 752 // As a test, we intentionally break the upload from byte 0 by only 753 // reading some bytes the quitting & erroring, this forces a resume 754 // any offset > 0 will work ok 755 var copyErr error 756 if r.Header.Get("Lfs-Tus-Interrupt") == "true" && offset == 0 { 757 chdr := r.Header.Get("Content-Length") 758 contentLen, err := strconv.ParseInt(chdr, 10, 64) 759 if err != nil { 760 log.Fatal(fmt.Sprintf("Invalid Content-Length %q", chdr)) 761 w.WriteHeader(400) 762 return 763 } 764 truncated := contentLen / 3 765 _, _ = io.CopyN(out, r.Body, truncated) 766 r.Body.Close() 767 copyErr = fmt.Errorf("Simulated copy error") 768 } else { 769 _, copyErr = io.Copy(out, r.Body) 770 } 771 if copyErr != nil { 772 b := buf.Bytes() 773 if len(b) > 0 { 774 debug(id, "Incomplete upload of %v, %d bytes", oid, len(b)) 775 largeObjects.SetIncomplete(repo, oid, b) 776 } 777 w.WriteHeader(500) 778 } else { 779 checkoid := hex.EncodeToString(hash.Sum(nil)) 780 if checkoid != oid { 781 log.Fatal(fmt.Sprintf("Incorrect oid after calculation, got %q expected %q", checkoid, oid)) 782 w.WriteHeader(403) 783 return 784 } 785 786 b := buf.Bytes() 787 largeObjects.Set(repo, oid, b) 788 w.Header().Set("Upload-Offset", strconv.FormatInt(int64(len(b)), 10)) 789 w.WriteHeader(204) 790 } 791 792 default: 793 w.WriteHeader(405) 794 } 795 } 796 797 func validateTusHeaders(r *http.Request, id string) bool { 798 if len(r.Header.Get("Tus-Resumable")) == 0 { 799 debug(id, "Missing Tus-Resumable header in request") 800 return false 801 } 802 return true 803 } 804 805 func gitHandler(w http.ResponseWriter, r *http.Request) { 806 defer func() { 807 io.Copy(ioutil.Discard, r.Body) 808 r.Body.Close() 809 }() 810 811 cmd := exec.Command("git", "http-backend") 812 cmd.Env = []string{ 813 fmt.Sprintf("GIT_PROJECT_ROOT=%s", repoDir), 814 fmt.Sprintf("GIT_HTTP_EXPORT_ALL="), 815 fmt.Sprintf("PATH_INFO=%s", r.URL.Path), 816 fmt.Sprintf("QUERY_STRING=%s", r.URL.RawQuery), 817 fmt.Sprintf("REQUEST_METHOD=%s", r.Method), 818 fmt.Sprintf("CONTENT_TYPE=%s", r.Header.Get("Content-Type")), 819 } 820 821 buffer := &bytes.Buffer{} 822 cmd.Stdin = r.Body 823 cmd.Stdout = buffer 824 cmd.Stderr = os.Stderr 825 826 if err := cmd.Run(); err != nil { 827 log.Fatal(err) 828 } 829 830 text := textproto.NewReader(bufio.NewReader(buffer)) 831 832 code, _, _ := text.ReadCodeLine(-1) 833 834 if code != 0 { 835 w.WriteHeader(code) 836 } 837 838 headers, _ := text.ReadMIMEHeader() 839 head := w.Header() 840 for key, values := range headers { 841 for _, value := range values { 842 head.Add(key, value) 843 } 844 } 845 846 io.Copy(w, text.R) 847 } 848 849 func redirect307Handler(w http.ResponseWriter, r *http.Request) { 850 id, ok := reqId(w) 851 if !ok { 852 return 853 } 854 855 // Send a redirect to info/lfs 856 // Make it either absolute or relative depending on subpath 857 parts := strings.Split(r.URL.Path, "/") 858 // first element is always blank since rooted 859 var redirectTo string 860 if parts[2] == "rel" { 861 redirectTo = "/" + strings.Join(parts[3:], "/") 862 } else if parts[2] == "abs" { 863 redirectTo = server.URL + "/" + strings.Join(parts[3:], "/") 864 } else { 865 debug(id, "Invalid URL for redirect: %v", r.URL) 
		w.WriteHeader(404)
		return
	}
	w.Header().Set("Location", redirectTo)
	w.WriteHeader(307)
}

type User struct {
	Name string `json:"name"`
}

type Lock struct {
	Id       string    `json:"id"`
	Path     string    `json:"path"`
	Owner    User      `json:"owner"`
	LockedAt time.Time `json:"locked_at"`
}

type LockRequest struct {
	Path string `json:"path"`
	Ref  *Ref   `json:"ref,omitempty"`
}

func (r *LockRequest) RefName() string {
	if r.Ref == nil {
		return ""
	}
	return r.Ref.Name
}

type LockResponse struct {
	Lock    *Lock  `json:"lock"`
	Message string `json:"message,omitempty"`
}

type UnlockRequest struct {
	Force bool `json:"force"`
	Ref   *Ref `json:"ref,omitempty"`
}

func (r *UnlockRequest) RefName() string {
	if r.Ref == nil {
		return ""
	}
	return r.Ref.Name
}

type UnlockResponse struct {
	Lock    *Lock  `json:"lock"`
	Message string `json:"message,omitempty"`
}

type LockList struct {
	Locks      []Lock `json:"locks"`
	NextCursor string `json:"next_cursor,omitempty"`
	Message    string `json:"message,omitempty"`
}

type Ref struct {
	Name string `json:"name,omitempty"`
}

type VerifiableLockRequest struct {
	Ref    *Ref   `json:"ref,omitempty"`
	Cursor string `json:"cursor,omitempty"`
	Limit  int    `json:"limit,omitempty"`
}

func (r *VerifiableLockRequest) RefName() string {
	if r.Ref == nil {
		return ""
	}
	return r.Ref.Name
}

type VerifiableLockList struct {
	Ours       []Lock `json:"ours"`
	Theirs     []Lock `json:"theirs"`
	NextCursor string `json:"next_cursor,omitempty"`
	Message    string `json:"message,omitempty"`
}

var (
	lmu       sync.RWMutex
	repoLocks = map[string][]Lock{}
)

func addLocks(repo string, l ...Lock) {
	lmu.Lock()
	defer lmu.Unlock()
	repoLocks[repo] = append(repoLocks[repo], l...)
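	// Keep each repo's locks ordered by creation time so cursor-based paging
	// in getFilteredLocks is deterministic.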
	sort.Sort(LocksByCreatedAt(repoLocks[repo]))
}

func getLocks(repo string) []Lock {
	lmu.RLock()
	defer lmu.RUnlock()

	locks := repoLocks[repo]
	cp := make([]Lock, len(locks))
	for i, l := range locks {
		cp[i] = l
	}

	return cp
}

func getFilteredLocks(repo, path, cursor, limit string) ([]Lock, string, error) {
	locks := getLocks(repo)
	if cursor != "" {
		lastSeen := -1
		for i, l := range locks {
			if l.Id == cursor {
				lastSeen = i
				break
			}
		}

		if lastSeen > -1 {
			locks = locks[lastSeen:]
		} else {
			return nil, "", fmt.Errorf("cursor (%s) not found", cursor)
		}
	}

	if path != "" {
		var filtered []Lock
		for _, l := range locks {
			if l.Path == path {
				filtered = append(filtered, l)
			}
		}

		locks = filtered
	}

	if limit != "" {
		size, err := strconv.Atoi(limit)
		if err != nil {
			return nil, "", errors.New("unable to parse limit amount")
		}

		size = int(math.Min(float64(len(locks)), 3))
		if size < 0 {
			return nil, "", nil
		}

		if size+1 < len(locks) {
			return locks[:size], locks[size+1].Id, nil
		}
	}

	return locks, "", nil
}

func delLock(repo string, id string) *Lock {
	lmu.RLock()
	defer lmu.RUnlock()

	var deleted *Lock
	locks := make([]Lock, 0, len(repoLocks[repo]))
	for _, l := range repoLocks[repo] {
		if l.Id == id {
			deleted = &l
			continue
		}
		locks = append(locks, l)
	}
	repoLocks[repo] = locks
	return deleted
}

type LocksByCreatedAt []Lock

func (c LocksByCreatedAt) Len() int           { return len(c) }
func (c LocksByCreatedAt) Less(i, j int) bool { return c[i].LockedAt.Before(c[j].LockedAt) }
func (c LocksByCreatedAt) Swap(i, j int)      { c[i], c[j] = c[j], c[i] }

var (
	lockRe   = regexp.MustCompile(`/locks/?$`)
	unlockRe = regexp.MustCompile(`locks/([^/]+)/unlock\z`)
)

func locksHandler(w http.ResponseWriter, r *http.Request, repo string) {
	dec := json.NewDecoder(r.Body)
	enc := json.NewEncoder(w)

	switch r.Method {
	case "GET":
		if !lockRe.MatchString(r.URL.Path) {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusNotFound)
			w.Write([]byte(`{"message":"unknown path: ` + r.URL.Path + `"}`))
			return
		}

		if err := r.ParseForm(); err != nil {
			http.Error(w, "could not parse form values", http.StatusInternalServerError)
			return
		}

		if strings.HasSuffix(repo, "branch-required") {
			parts := strings.Split(repo, "-")
			lenParts := len(parts)
			if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != r.FormValue("refspec") {
				w.WriteHeader(403)
				enc.Encode(struct {
					Message string `json:"message"`
				}{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], r.FormValue("refspec"))})
				return
			}
		}

		ll := &LockList{}
		w.Header().Set("Content-Type", "application/json")
		locks, nextCursor, err := getFilteredLocks(repo,
			r.FormValue("path"),
			r.FormValue("cursor"),
			r.FormValue("limit"))

		if err != nil {
			ll.Message = err.Error()
		} else {
			ll.Locks = locks
			ll.NextCursor = nextCursor
		}

		enc.Encode(ll)
		return
	case "POST":
		w.Header().Set("Content-Type", "application/json")
		if strings.HasSuffix(r.URL.Path, "unlock") {
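			// Extract the lock ID from a ".../locks/{id}/unlock" path.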
			var lockId string
			if matches := unlockRe.FindStringSubmatch(r.URL.Path); len(matches) > 1 {
				lockId = matches[1]
			}

			if len(lockId) == 0 {
				enc.Encode(&UnlockResponse{Message: "Invalid lock"})
			}

			unlockRequest := &UnlockRequest{}
			if err := dec.Decode(unlockRequest); err != nil {
				enc.Encode(&UnlockResponse{Message: err.Error()})
				return
			}

			if strings.HasSuffix(repo, "branch-required") {
				parts := strings.Split(repo, "-")
				lenParts := len(parts)
				if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != unlockRequest.RefName() {
					w.WriteHeader(403)
					enc.Encode(struct {
						Message string `json:"message"`
					}{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], unlockRequest.RefName())})
					return
				}
			}

			if l := delLock(repo, lockId); l != nil {
				enc.Encode(&UnlockResponse{Lock: l})
			} else {
				enc.Encode(&UnlockResponse{Message: "unable to find lock"})
			}
			return
		}

		if strings.HasSuffix(r.URL.Path, "/locks/verify") {
			if strings.HasSuffix(repo, "verify-5xx") {
				w.WriteHeader(500)
				return
			}
			if strings.HasSuffix(repo, "verify-501") {
				w.WriteHeader(501)
				return
			}
			if strings.HasSuffix(repo, "verify-403") {
				w.WriteHeader(403)
				return
			}

			switch repo {
			case "pre_push_locks_verify_404":
				w.WriteHeader(http.StatusNotFound)
				w.Write([]byte(`{"message":"pre_push_locks_verify_404"}`))
				return
			case "pre_push_locks_verify_410":
				w.WriteHeader(http.StatusGone)
				w.Write([]byte(`{"message":"pre_push_locks_verify_410"}`))
				return
			}

			reqBody := &VerifiableLockRequest{}
			if err := dec.Decode(reqBody); err != nil {
				w.WriteHeader(http.StatusBadRequest)
				enc.Encode(struct {
					Message string `json:"message"`
				}{"json decode error: " + err.Error()})
				return
			}

			if strings.HasSuffix(repo, "branch-required") {
				parts := strings.Split(repo, "-")
				lenParts := len(parts)
				if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != reqBody.RefName() {
					w.WriteHeader(403)
					enc.Encode(struct {
						Message string `json:"message"`
					}{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], reqBody.RefName())})
					return
				}
			}

			ll := &VerifiableLockList{}
			locks, nextCursor, err := getFilteredLocks(repo, "",
				reqBody.Cursor,
				strconv.Itoa(reqBody.Limit))
			if err != nil {
				ll.Message = err.Error()
			} else {
				ll.NextCursor = nextCursor

				for _, l := range locks {
					if strings.Contains(l.Path, "theirs") {
						ll.Theirs = append(ll.Theirs, l)
					} else {
						ll.Ours = append(ll.Ours, l)
					}
				}
			}

			enc.Encode(ll)
			return
		}

		if strings.HasSuffix(r.URL.Path, "/locks") {
			lockRequest := &LockRequest{}
			if err := dec.Decode(lockRequest); err != nil {
				enc.Encode(&LockResponse{Message: err.Error()})
			}

			if strings.HasSuffix(repo, "branch-required") {
				parts := strings.Split(repo, "-")
				lenParts := len(parts)
				if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != lockRequest.RefName() {
					w.WriteHeader(403)
					enc.Encode(struct {
						Message string `json:"message"`
					}{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], lockRequest.RefName())})
					return
				}
			}

			for _, l := range getLocks(repo) {
				if l.Path == lockRequest.Path {
					enc.Encode(&LockResponse{Message: "lock already created"})
					return
				}
			}

			var id [20]byte
			rand.Read(id[:])

			lock := &Lock{
				Id:       fmt.Sprintf("%x", id[:]),
				Path:     lockRequest.Path,
				Owner:    User{Name: "Git LFS Tests"},
				LockedAt: time.Now(),
			}

			addLocks(repo, *lock)

			// TODO(taylor): commit_needed case
			// TODO(taylor): err case

			enc.Encode(&LockResponse{
				Lock: lock,
			})
			return
		}
	}

	http.NotFound(w, r)
}

func missingRequiredCreds(w http.ResponseWriter, r *http.Request, repo string) bool {
	if !strings.HasPrefix(repo, "requirecreds") {
		return false
	}

	auth := r.Header.Get("Authorization")
	user, pass, err := extractAuth(auth)
	if err != nil {
		writeLFSError(w, 403, err.Error())
		return true
	}

	if user != "requirecreds" || pass != "pass" {
		writeLFSError(w, 403, fmt.Sprintf("Got: '%s' => '%s' : '%s'", auth, user, pass))
		return true
	}

	return false
}

func testingChunkedTransferEncoding(r *http.Request) bool {
	return strings.HasPrefix(r.URL.String(), "/test-chunked-transfer-encoding")
}

func testingTusUploadInBatchReq(r *http.Request) bool {
	return strings.HasPrefix(r.URL.String(), "/test-tus-upload")
}

func testingTusUploadInterruptedInBatchReq(r *http.Request) bool {
	return strings.HasPrefix(r.URL.String(), "/test-tus-upload-interrupt")
}

func testingCustomTransfer(r *http.Request) bool {
	return strings.HasPrefix(r.URL.String(), "/test-custom-transfer")
}

var lfsUrlRE = regexp.MustCompile(`\A/?([^/]+)/info/lfs`)

func repoFromLfsUrl(urlpath string) (string, error) {
	matches := lfsUrlRE.FindStringSubmatch(urlpath)
	if len(matches) != 2 {
		return "", fmt.Errorf("LFS url '%s' does not match %v", urlpath, lfsUrlRE)
	}

	repo := matches[1]
	if strings.HasSuffix(repo, ".git") {
		return repo[0 : len(repo)-4], nil
	}
	return repo, nil
}

type lfsStorage struct {
	objects    map[string]map[string][]byte
	incomplete map[string]map[string][]byte
	mutex      *sync.Mutex
}

func (s *lfsStorage) Get(repo, oid string) ([]byte, bool) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	repoObjects, ok := s.objects[repo]
	if !ok {
		return nil, ok
	}

	by, ok := repoObjects[oid]
	return by, ok
}

func (s *lfsStorage) Has(repo, oid string) bool {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	repoObjects, ok := s.objects[repo]
	if !ok {
		return false
	}

	_, ok = repoObjects[oid]
	return ok
}

func (s *lfsStorage) Set(repo, oid string, by []byte) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	repoObjects, ok := s.objects[repo]
	if !ok {
		repoObjects = make(map[string][]byte)
		s.objects[repo] = repoObjects
	}
	repoObjects[oid] = by
}

func (s *lfsStorage) Delete(repo, oid string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	repoObjects, ok := s.objects[repo]
	if ok {
		delete(repoObjects, oid)
	}
}

func (s *lfsStorage) GetIncomplete(repo, oid string) ([]byte, bool) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	repoObjects, ok := s.incomplete[repo]
	if !ok {
		return nil, ok
	}

	by, ok := repoObjects[oid]
	return by, ok
}

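// SetIncomplete stores a partial upload for oid so that a later tus PATCH
// request can resume from those bytes.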
func (s *lfsStorage) SetIncomplete(repo, oid string, by []byte) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	repoObjects, ok := s.incomplete[repo]
	if !ok {
		repoObjects = make(map[string][]byte)
		s.incomplete[repo] = repoObjects
	}
	repoObjects[oid] = by
}

func (s *lfsStorage) DeleteIncomplete(repo, oid string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	repoObjects, ok := s.incomplete[repo]
	if ok {
		delete(repoObjects, oid)
	}
}

func newLfsStorage() *lfsStorage {
	return &lfsStorage{
		objects:    make(map[string]map[string][]byte),
		incomplete: make(map[string]map[string][]byte),
		mutex:      &sync.Mutex{},
	}
}

func extractAuth(auth string) (string, string, error) {
	if strings.HasPrefix(auth, "Basic ") {
		decodeBy, err := base64.StdEncoding.DecodeString(auth[6:len(auth)])
		decoded := string(decodeBy)

		if err != nil {
			return "", "", err
		}

		parts := strings.SplitN(decoded, ":", 2)
		if len(parts) == 2 {
			return parts[0], parts[1], nil
		}
		return "", "", nil
	}

	return "", "", nil
}

func skipIfBadAuth(w http.ResponseWriter, r *http.Request, id string, ntlmSession ntlm.ServerSession) bool {
	auth := r.Header.Get("Authorization")
	if strings.Contains(r.URL.Path, "ntlm") {
		return false
	}

	if auth == "" {
		w.WriteHeader(401)
		return true
	}

	user, pass, err := extractAuth(auth)
	if err != nil {
		w.WriteHeader(403)
		debug(id, "Error decoding auth: %s", err)
		return true
	}

	switch user {
	case "user":
		if pass == "pass" {
			return false
		}
	case "netrcuser", "requirecreds":
		return false
	case "path":
		if strings.HasPrefix(r.URL.Path, "/"+pass) {
			return false
		}
		debug(id, "auth attempt against: %q", r.URL.Path)
	}

	w.WriteHeader(403)
	debug(id, "Bad auth: %q", auth)
	return true
}

func handleNTLM(w http.ResponseWriter, r *http.Request, authHeader string, session ntlm.ServerSession) {
	if strings.HasPrefix(strings.ToUpper(authHeader), "BASIC ") {
		authHeader = ""
	}

	switch authHeader {
	case "":
		w.Header().Set("Www-Authenticate", "ntlm")
		w.WriteHeader(401)

	// ntlmNegotiateMessage from httputil pkg
	case "NTLM TlRMTVNTUAABAAAAB7IIogwADAAzAAAACwALACgAAAAKAAAoAAAAD1dJTExISS1NQUlOTk9SVEhBTUVSSUNB":
		ch, err := session.GenerateChallengeMessage()
		if err != nil {
			writeLFSError(w, 500, err.Error())
			return
		}

		chMsg := base64.StdEncoding.EncodeToString(ch.Bytes())
		w.Header().Set("Www-Authenticate", "ntlm "+chMsg)
		w.WriteHeader(401)

	default:
		if !strings.HasPrefix(strings.ToUpper(authHeader), "NTLM ") {
			writeLFSError(w, 500, "bad authorization header: "+authHeader)
			return
		}

		auth := authHeader[5:] // strip "ntlm " prefix
		val, err := base64.StdEncoding.DecodeString(auth)
		if err != nil {
			writeLFSError(w, 500, "base64 decode error: "+err.Error())
			return
		}

		_, err = ntlm.ParseAuthenticateMessage(val, 2)
		if err != nil {
			writeLFSError(w, 500, "auth parse error: "+err.Error())
			return
		}
	}
}

func init() {
	oidHandlers = make(map[string]string)
	for _, content := range contentHandlers {
		h := sha256.New()
		h.Write([]byte(content))
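		// Register each magic content string under its SHA-256 so incoming OIDs
		// can be mapped back to the behavior they request.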
		oidHandlers[hex.EncodeToString(h.Sum(nil))] = content
	}
}

func debug(reqid, msg string, args ...interface{}) {
	fullargs := make([]interface{}, len(args)+1)
	fullargs[0] = reqid
	for i, a := range args {
		fullargs[i+1] = a
	}
	log.Printf("[%s] "+msg+"\n", fullargs...)
}

func reqId(w http.ResponseWriter) (string, bool) {
	b := make([]byte, 16)
	_, err := rand.Read(b)
	if err != nil {
		http.Error(w, "error generating id: "+err.Error(), 500)
		return "", false
	}
	return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]), true
}

// https://ericchiang.github.io/post/go-tls/
func generateCARootCertificates() (rootKey *rsa.PrivateKey, rootCert *x509.Certificate) {

	// generate a new key-pair
	rootKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		log.Fatalf("generating random key: %v", err)
	}

	rootCertTmpl, err := CertTemplate()
	if err != nil {
		log.Fatalf("creating cert template: %v", err)
	}
	// describe what the certificate will be used for
	rootCertTmpl.IsCA = true
	rootCertTmpl.KeyUsage = x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature
	rootCertTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}
	// rootCertTmpl.IPAddresses = []net.IP{net.ParseIP("127.0.0.1")}

	rootCert, _, err = CreateCert(rootCertTmpl, rootCertTmpl, &rootKey.PublicKey, rootKey)

	return
}

func generateClientCertificates(rootCert *x509.Certificate, rootKey interface{}) (clientKey *rsa.PrivateKey, clientCertPEM []byte, clientKeyPEM []byte) {

	// create a key-pair for the client
	clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		log.Fatalf("generating random key: %v", err)
	}

	// create a template for the client
	clientCertTmpl, err1 := CertTemplate()
	if err1 != nil {
		log.Fatalf("creating cert template: %v", err1)
	}
	clientCertTmpl.KeyUsage = x509.KeyUsageDigitalSignature
	clientCertTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}

	// the root cert signs the client cert by again providing its private key
	_, clientCertPEM, err2 := CreateCert(clientCertTmpl, rootCert, &clientKey.PublicKey, rootKey)
	if err2 != nil {
		log.Fatalf("error creating cert: %v", err2)
	}

	// encode and load the cert and private key for the client
	clientKeyPEM = pem.EncodeToMemory(&pem.Block{
		Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
	})

	return
}

// CertTemplate is a helper function to create a cert template with a serial
// number and other required fields.
func CertTemplate() (*x509.Certificate, error) {
	// generate a random serial number (a real cert authority would have some logic behind this)
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return nil, errors.New("failed to generate serial number: " + err.Error())
	}

	tmpl := x509.Certificate{
		SerialNumber:          serialNumber,
		Subject:               pkix.Name{Organization: []string{"Yhat, Inc."}},
		SignatureAlgorithm:    x509.SHA256WithRSA,
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour), // valid for an hour
		BasicConstraintsValid: true,
	}
	return &tmpl, nil
}

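// CreateCert signs template with parentPriv on behalf of parent and returns
// the parsed certificate along with its PEM encoding.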
func CreateCert(template, parent *x509.Certificate, pub interface{}, parentPriv interface{}) (
	cert *x509.Certificate, certPEM []byte, err error) {

	certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pub, parentPriv)
	if err != nil {
		return
	}
	// parse the resulting certificate so we can use it again
	cert, err = x509.ParseCertificate(certDER)
	if err != nil {
		return
	}
	// PEM encode the certificate (this is a standard TLS encoding)
	b := pem.Block{Type: "CERTIFICATE", Bytes: certDER}
	certPEM = pem.EncodeToMemory(&b)
	return
}