github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/dashboard/app/util_test.go (about) 1 // Copyright 2017 syzkaller project authors. All rights reserved. 2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4 // The test uses aetest package that starts local dev_appserver and handles all requests locally: 5 // https://cloud.google.com/appengine/docs/standard/go/tools/localunittesting/reference 6 7 package main 8 9 import ( 10 "bytes" 11 "context" 12 "errors" 13 "fmt" 14 "io" 15 "net/http" 16 "net/http/httptest" 17 "net/url" 18 "os" 19 "os/exec" 20 "path/filepath" 21 "reflect" 22 "runtime" 23 "strings" 24 "sync" 25 "testing" 26 "time" 27 28 "github.com/google/go-cmp/cmp" 29 "github.com/google/syzkaller/dashboard/api" 30 "github.com/google/syzkaller/dashboard/dashapi" 31 "github.com/google/syzkaller/pkg/coveragedb/spannerclient" 32 "github.com/google/syzkaller/pkg/covermerger" 33 "github.com/google/syzkaller/pkg/email" 34 "github.com/google/syzkaller/pkg/subsystem" 35 "google.golang.org/appengine/v2/aetest" 36 db "google.golang.org/appengine/v2/datastore" 37 "google.golang.org/appengine/v2/log" 38 aemail "google.golang.org/appengine/v2/mail" 39 ) 40 41 type Ctx struct { 42 t *testing.T 43 inst aetest.Instance 44 ctx context.Context 45 mockedTime time.Time 46 emailSink chan *aemail.Message 47 transformContext func(context.Context) context.Context 48 client *apiClient 49 client2 *apiClient 50 publicClient *apiClient 51 } 52 53 var skipDevAppserverTests = func() bool { 54 _, err := exec.LookPath("dev_appserver.py") 55 // Don't silently skip tests on CI, we should have gcloud sdk installed there. 56 return err != nil && os.Getenv("SYZ_ENV") == "" || 57 os.Getenv("SYZ_SKIP_DEV_APPSERVER_TESTS") != "" 58 }() 59 60 func NewCtx(t *testing.T) *Ctx { 61 if skipDevAppserverTests { 62 t.Skip("skipping test (no dev_appserver.py)") 63 } 64 t.Parallel() 65 inst, err := aetest.NewInstance(&aetest.Options{ 66 StartupTimeout: 120 * time.Second, 67 // Without this option datastore queries return data with slight delay, 68 // which fails reporting tests. 69 StronglyConsistentDatastore: true, 70 }) 71 if err != nil { 72 t.Fatal(err) 73 } 74 r, err := inst.NewRequest("GET", "", nil) 75 if err != nil { 76 t.Fatal(err) 77 } 78 c := &Ctx{ 79 t: t, 80 inst: inst, 81 mockedTime: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), 82 emailSink: make(chan *aemail.Message, 100), 83 transformContext: func(c context.Context) context.Context { return c }, 84 } 85 c.client = c.makeClient(client1, password1, true) 86 c.client2 = c.makeClient(client2, password2, true) 87 c.publicClient = c.makeClient(clientPublicEmail, keyPublicEmail, true) 88 c.ctx = registerRequest(r, c).Context() 89 return c 90 } 91 92 func (c *Ctx) config() *GlobalConfig { 93 return getConfig(c.ctx) 94 } 95 96 func (c *Ctx) expectOK(err error) { 97 if err != nil { 98 c.t.Helper() 99 c.t.Fatalf("expected OK, got error: %v", err) 100 } 101 } 102 103 func (c *Ctx) expectFail(msg string, err error) { 104 c.t.Helper() 105 if err == nil { 106 c.t.Fatalf("expected to fail, but it does not") 107 } 108 if !strings.Contains(err.Error(), msg) { 109 c.t.Fatalf("expected to fail with %q, but failed with %q", msg, err) 110 } 111 } 112 113 func expectFailureStatus(t *testing.T, err error, code int) { 114 t.Helper() 115 if err == nil { 116 t.Fatalf("expected to fail as %d, but it does not", code) 117 } 118 var httpErr *HTTPError 119 if !errors.As(err, &httpErr) || httpErr.Code != code { 120 t.Fatalf("expected to fail as %d, but it failed as %v", code, err) 121 } 122 } 123 124 func (c *Ctx) expectBadReqest(err error) { 125 expectFailureStatus(c.t, err, http.StatusBadRequest) 126 } 127 128 func (c *Ctx) expectEQ(got, want interface{}) { 129 if diff := cmp.Diff(got, want); diff != "" { 130 c.t.Helper() 131 c.t.Fatal(diff) 132 } 133 } 134 135 func (c *Ctx) expectNE(got, want interface{}) { 136 if reflect.DeepEqual(got, want) { 137 c.t.Helper() 138 c.t.Fatalf("equal: %#v", got) 139 } 140 } 141 142 func (c *Ctx) expectTrue(v bool) { 143 if !v { 144 c.t.Helper() 145 c.t.Fatal("failed") 146 } 147 } 148 149 func caller(skip int) string { 150 pcs := make([]uintptr, 10) 151 n := runtime.Callers(skip+3, pcs) 152 pcs = pcs[:n] 153 frames := runtime.CallersFrames(pcs) 154 stack := "" 155 for { 156 frame, more := frames.Next() 157 if strings.HasPrefix(frame.Function, "testing.") { 158 break 159 } 160 stack = fmt.Sprintf("%v:%v\n", filepath.Base(frame.File), frame.Line) + stack 161 if !more { 162 break 163 } 164 } 165 if stack != "" { 166 stack = stack[:len(stack)-1] 167 } 168 return stack 169 } 170 171 func (c *Ctx) Close() { 172 defer c.inst.Close() 173 if !c.t.Failed() { 174 // To avoid per-day reporting limits for left-over emails. 175 c.advanceTime(25 * time.Hour) 176 // Ensure that we can render main page and all bugs in the final test state. 177 _, err := c.GET("/test1") 178 c.expectOK(err) 179 _, err = c.GET("/test2") 180 c.expectOK(err) 181 _, err = c.GET("/test1/fixed") 182 c.expectOK(err) 183 _, err = c.GET("/test2/fixed") 184 c.expectOK(err) 185 _, err = c.GET("/admin") 186 c.expectOK(err) 187 var bugs []*Bug 188 keys, err := db.NewQuery("Bug").GetAll(c.ctx, &bugs) 189 if err != nil { 190 c.t.Errorf("ERROR: failed to query bugs: %v", err) 191 } 192 for _, key := range keys { 193 _, err = c.GET(fmt.Sprintf("/bug?id=%v", key.StringID())) 194 c.expectOK(err) 195 } 196 // No pending emails (tests need to consume them). 197 _, err = c.GET("/cron/email_poll") 198 c.expectOK(err) 199 for len(c.emailSink) != 0 { 200 c.t.Errorf("ERROR: leftover email: %v", (<-c.emailSink).Body) 201 } 202 // No pending external reports (tests need to consume them). 203 resp, _ := c.client.ReportingPollBugs("test") 204 for _, rep := range resp.Reports { 205 c.t.Errorf("ERROR: leftover external report:\n%#v", rep) 206 } 207 } 208 unregisterContext(c) 209 validateGlobalConfig() 210 } 211 212 func (c *Ctx) advanceTime(d time.Duration) { 213 c.mockedTime = c.mockedTime.Add(d) 214 } 215 216 func (c *Ctx) setSubsystems(ns string, list []*subsystem.Subsystem, rev int) { 217 c.transformContext = func(c context.Context) context.Context { 218 newConfig := replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 219 ret := *cfg 220 ret.Subsystems.Service = subsystem.MustMakeService(list, rev) 221 return &ret 222 }) 223 return contextWithConfig(c, newConfig) 224 } 225 } 226 227 func (c *Ctx) setCoverageMocks(ns string, dbClientMock spannerclient.SpannerClient, 228 fileProvMock covermerger.FileVersProvider) { 229 c.transformContext = func(ctx context.Context) context.Context { 230 newConfig := replaceNamespaceConfig(ctx, ns, func(cfg *Config) *Config { 231 ret := *cfg 232 ret.Coverage = &CoverageConfig{WebGitURI: "test-git"} 233 return &ret 234 }) 235 ctxWithSpanner := setCoverageDBClient(ctx, dbClientMock) 236 ctxWithSpannerAndFileProvider := setWebGit(ctxWithSpanner, fileProvMock) 237 return contextWithConfig(ctxWithSpannerAndFileProvider, newConfig) 238 } 239 } 240 241 func (c *Ctx) setKernelRepos(ns string, list []KernelRepo) { 242 c.transformContext = func(c context.Context) context.Context { 243 newConfig := replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 244 ret := *cfg 245 ret.Repos = list 246 return &ret 247 }) 248 return contextWithConfig(c, newConfig) 249 } 250 } 251 252 func (c *Ctx) setNoObsoletions() { 253 c.transformContext = func(c context.Context) context.Context { 254 return contextWithNoObsoletions(c) 255 } 256 } 257 258 func (c *Ctx) updateReporting(ns, name string, f func(Reporting) Reporting) { 259 c.transformContext = func(c context.Context) context.Context { 260 return contextWithConfig(c, replaceReporting(c, ns, name, f)) 261 } 262 } 263 264 func (c *Ctx) decommissionManager(ns, oldManager, newManager string) { 265 c.transformContext = func(c context.Context) context.Context { 266 newConfig := replaceManagerConfig(c, ns, oldManager, func(cfg ConfigManager) ConfigManager { 267 cfg.Decommissioned = true 268 cfg.DelegatedTo = newManager 269 return cfg 270 }) 271 return contextWithConfig(c, newConfig) 272 } 273 } 274 275 func (c *Ctx) decommission(ns string) { 276 c.transformContext = func(c context.Context) context.Context { 277 newConfig := replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 278 ret := *cfg 279 ret.Decommissioned = true 280 return &ret 281 }) 282 return contextWithConfig(c, newConfig) 283 } 284 } 285 286 func (c *Ctx) setWaitForRepro(ns string, d time.Duration) { 287 c.transformContext = func(c context.Context) context.Context { 288 newConfig := replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 289 ret := *cfg 290 ret.WaitForRepro = d 291 return &ret 292 }) 293 return contextWithConfig(c, newConfig) 294 } 295 } 296 297 // GET sends admin-authorized HTTP GET request to the app. 298 func (c *Ctx) GET(url string) ([]byte, error) { 299 return c.AuthGET(AccessAdmin, url) 300 } 301 302 // AuthGET sends HTTP GET request to the app with the specified authorization. 303 func (c *Ctx) AuthGET(access AccessLevel, url string) ([]byte, error) { 304 w, err := c.httpRequest("GET", url, "", "", access) 305 if err != nil { 306 return nil, err 307 } 308 return w.Body.Bytes(), nil 309 } 310 311 // POST sends admin-authorized HTTP POST requestd to the app. 312 func (c *Ctx) POST(url, body string) ([]byte, error) { 313 w, err := c.httpRequest("POST", url, body, "", AccessAdmin) 314 if err != nil { 315 return nil, err 316 } 317 return w.Body.Bytes(), nil 318 } 319 320 // POST sends an admin-authorized HTTP POST form to the app. 321 func (c *Ctx) POSTForm(url string, form url.Values) ([]byte, error) { 322 w, err := c.httpRequest("POST", url, form.Encode(), 323 "application/x-www-form-urlencoded", AccessAdmin) 324 if err != nil { 325 return nil, err 326 } 327 return w.Body.Bytes(), nil 328 } 329 330 // ContentType returns the response Content-Type header value. 331 func (c *Ctx) ContentType(url string) (string, error) { 332 w, err := c.httpRequest("HEAD", url, "", "", AccessAdmin) 333 if err != nil { 334 return "", err 335 } 336 values := w.Header()["Content-Type"] 337 if len(values) == 0 { 338 return "", fmt.Errorf("no Content-Type") 339 } 340 return values[0], nil 341 } 342 343 func (c *Ctx) httpRequest(method, url, body, contentType string, 344 access AccessLevel) (*httptest.ResponseRecorder, error) { 345 c.t.Logf("%v: %v", method, url) 346 r, err := c.inst.NewRequest(method, url, strings.NewReader(body)) 347 if err != nil { 348 c.t.Fatal(err) 349 } 350 r.Header.Add("X-Appengine-User-IP", "127.0.0.1") 351 if contentType != "" { 352 r.Header.Add("Content-Type", contentType) 353 } 354 r = registerRequest(r, c) 355 r = r.WithContext(c.transformContext(r.Context())) 356 switch access { 357 case AccessAdmin: 358 aetest.Login(makeUser(AuthorizedAdmin), r) 359 case AccessUser: 360 aetest.Login(makeUser(AuthorizedUser), r) 361 } 362 w := httptest.NewRecorder() 363 http.DefaultServeMux.ServeHTTP(w, r) 364 c.t.Logf("REPLY: %v", w.Code) 365 if w.Code != http.StatusOK { 366 return nil, &HTTPError{w.Code, w.Body.String(), w.Result().Header} 367 } 368 return w, nil 369 } 370 371 type HTTPError struct { 372 Code int 373 Body string 374 Headers http.Header 375 } 376 377 func (err *HTTPError) Error() string { 378 return fmt.Sprintf("%v: %v", err.Code, err.Body) 379 } 380 381 func (c *Ctx) loadBug(extID string) (*Bug, *Crash, *Build) { 382 bug, _, err := findBugByReportingID(c.ctx, extID) 383 if err != nil { 384 c.t.Fatalf("failed to load bug: %v", err) 385 } 386 return c.loadBugInfo(bug) 387 } 388 389 func (c *Ctx) loadBugByHash(hash string) (*Bug, *Crash, *Build) { 390 bug := new(Bug) 391 bugKey := db.NewKey(c.ctx, "Bug", hash, 0, nil) 392 c.expectOK(db.Get(c.ctx, bugKey, bug)) 393 return c.loadBugInfo(bug) 394 } 395 396 func (c *Ctx) loadBugInfo(bug *Bug) (*Bug, *Crash, *Build) { 397 crash, _, err := findCrashForBug(c.ctx, bug) 398 if err != nil { 399 c.t.Fatalf("failed to load crash: %v", err) 400 } 401 build := c.loadBuild(bug.Namespace, crash.BuildID) 402 return bug, crash, build 403 } 404 405 func (c *Ctx) loadJob(extID string) (*Job, *Build, *Crash) { 406 jobKey, err := jobID2Key(c.ctx, extID) 407 if err != nil { 408 c.t.Fatalf("failed to create job key: %v", err) 409 } 410 job := new(Job) 411 if err := db.Get(c.ctx, jobKey, job); err != nil { 412 c.t.Fatalf("failed to get job %v: %v", extID, err) 413 } 414 build := c.loadBuild(job.Namespace, job.BuildID) 415 crash := new(Crash) 416 crashKey := db.NewKey(c.ctx, "Crash", "", job.CrashID, jobKey.Parent()) 417 if err := db.Get(c.ctx, crashKey, crash); err != nil { 418 c.t.Fatalf("failed to load crash for job: %v", err) 419 } 420 return job, build, crash 421 } 422 423 func (c *Ctx) loadBuild(ns, id string) *Build { 424 build, err := loadBuild(c.ctx, ns, id) 425 c.expectOK(err) 426 return build 427 } 428 429 func (c *Ctx) loadManager(ns, name string) (*Manager, *Build) { 430 mgr, err := loadManager(c.ctx, ns, name) 431 c.expectOK(err) 432 build := c.loadBuild(ns, mgr.CurrentBuild) 433 return mgr, build 434 } 435 436 func (c *Ctx) loadSingleBug() (*Bug, *db.Key) { 437 var bugs []*Bug 438 keys, err := db.NewQuery("Bug").GetAll(c.ctx, &bugs) 439 c.expectEQ(err, nil) 440 c.expectEQ(len(bugs), 1) 441 442 return bugs[0], keys[0] 443 } 444 445 func (c *Ctx) loadSingleJob() (*Job, *db.Key) { 446 var jobs []*Job 447 keys, err := db.NewQuery("Job").GetAll(c.ctx, &jobs) 448 c.expectEQ(err, nil) 449 c.expectEQ(len(jobs), 1) 450 451 return jobs[0], keys[0] 452 } 453 454 func (c *Ctx) checkURLContents(url string, want []byte) { 455 c.t.Helper() 456 got, err := c.AuthGET(AccessAdmin, url) 457 if err != nil { 458 c.t.Fatalf("%v request failed: %v", url, err) 459 } 460 if !bytes.Equal(got, want) { 461 c.t.Fatalf("url %v: got:\n%s\nwant:\n%s\n", url, got, want) 462 } 463 } 464 465 func (c *Ctx) pollEmailBug() *aemail.Message { 466 _, err := c.GET("/cron/email_poll") 467 c.expectOK(err) 468 if len(c.emailSink) == 0 { 469 c.t.Helper() 470 c.t.Fatal("got no emails") 471 } 472 return <-c.emailSink 473 } 474 475 func (c *Ctx) pollEmailExtID() string { 476 c.t.Helper() 477 _, extBugID := c.pollEmailAndExtID() 478 return extBugID 479 } 480 481 func (c *Ctx) pollEmailAndExtID() (string, string) { 482 c.t.Helper() 483 msg := c.pollEmailBug() 484 _, extBugID, err := email.RemoveAddrContext(msg.Sender) 485 if err != nil { 486 c.t.Fatalf("failed to remove addr context: %v", err) 487 } 488 return msg.Sender, extBugID 489 } 490 491 func (c *Ctx) expectNoEmail() { 492 _, err := c.GET("/cron/email_poll") 493 c.expectOK(err) 494 if len(c.emailSink) != 0 { 495 msg := <-c.emailSink 496 c.t.Helper() 497 c.t.Fatalf("got unexpected email: %v\n%s", msg.Subject, msg.Body) 498 } 499 } 500 501 type apiClient struct { 502 *Ctx 503 *dashapi.Dashboard 504 } 505 506 func (c *Ctx) makeClient(client, key string, failOnErrors bool) *apiClient { 507 logger := func(msg string, args ...interface{}) { 508 c.t.Logf("%v: "+msg, append([]interface{}{caller(3)}, args...)...) 509 } 510 errorHandler := func(err error) { 511 if failOnErrors { 512 c.t.Fatalf("\n%v: %v", caller(2), err) 513 } 514 } 515 dash, err := dashapi.NewCustom(client, "", key, c.inst.NewRequest, c.httpDoer(), logger, errorHandler) 516 if err != nil { 517 panic(fmt.Sprintf("Impossible error: %v", err)) 518 } 519 return &apiClient{ 520 Ctx: c, 521 Dashboard: dash, 522 } 523 } 524 525 func (c *Ctx) makeAPIClient() *api.Client { 526 return api.NewTestClient(c.inst.NewRequest, c.httpDoer()) 527 } 528 529 func (c *Ctx) httpDoer() func(*http.Request) (*http.Response, error) { 530 return func(r *http.Request) (*http.Response, error) { 531 r = registerRequest(r, c) 532 r = r.WithContext(c.transformContext(r.Context())) 533 w := httptest.NewRecorder() 534 http.DefaultServeMux.ServeHTTP(w, r) 535 res := &http.Response{ 536 StatusCode: w.Code, 537 Status: http.StatusText(w.Code), 538 Body: io.NopCloser(w.Result().Body), 539 } 540 return res, nil 541 } 542 } 543 544 func (client *apiClient) pollBugs(expect int) []*dashapi.BugReport { 545 resp, _ := client.ReportingPollBugs("test") 546 if len(resp.Reports) != expect { 547 client.t.Helper() 548 client.t.Fatalf("want %v reports, got %v", expect, len(resp.Reports)) 549 } 550 for _, rep := range resp.Reports { 551 reproLevel := dashapi.ReproLevelNone 552 if len(rep.ReproC) != 0 { 553 reproLevel = dashapi.ReproLevelC 554 } else if len(rep.ReproSyz) != 0 { 555 reproLevel = dashapi.ReproLevelSyz 556 } 557 reply, _ := client.ReportingUpdate(&dashapi.BugUpdate{ 558 ID: rep.ID, 559 JobID: rep.JobID, 560 Status: dashapi.BugStatusOpen, 561 ReproLevel: reproLevel, 562 CrashID: rep.CrashID, 563 }) 564 client.expectEQ(reply.Error, false) 565 client.expectEQ(reply.OK, true) 566 } 567 return resp.Reports 568 } 569 570 func (client *apiClient) pollBug() *dashapi.BugReport { 571 return client.pollBugs(1)[0] 572 } 573 574 func (client *apiClient) pollNotifs(expect int) []*dashapi.BugNotification { 575 resp, _ := client.ReportingPollNotifications("test") 576 if len(resp.Notifications) != expect { 577 client.t.Helper() 578 client.t.Fatalf("want %v notifs, got %v", expect, len(resp.Notifications)) 579 } 580 return resp.Notifications 581 } 582 583 func (client *apiClient) updateBug(extID string, status dashapi.BugStatus, dup string) { 584 reply, _ := client.ReportingUpdate(&dashapi.BugUpdate{ 585 ID: extID, 586 Status: status, 587 DupOf: dup, 588 }) 589 client.expectTrue(reply.OK) 590 } 591 592 func (client *apiClient) pollSpecificJobs(manager string, jobs dashapi.ManagerJobs) *dashapi.JobPollResp { 593 req := &dashapi.JobPollReq{ 594 Managers: map[string]dashapi.ManagerJobs{ 595 manager: jobs, 596 }, 597 } 598 resp, err := client.JobPoll(req) 599 client.expectOK(err) 600 return resp 601 } 602 603 func (client *apiClient) pollJobs(manager string) *dashapi.JobPollResp { 604 return client.pollSpecificJobs(manager, dashapi.ManagerJobs{ 605 TestPatches: true, 606 BisectCause: true, 607 BisectFix: true, 608 }) 609 } 610 611 func (client *apiClient) pollAndFailBisectJob(manager string) { 612 resp := client.pollJobs(manager) 613 client.expectNE(resp.ID, "") 614 client.expectEQ(resp.Type, dashapi.JobBisectCause) 615 done := &dashapi.JobDoneReq{ 616 ID: resp.ID, 617 Error: []byte("pollAndFailBisectJob"), 618 } 619 client.expectOK(client.JobDone(done)) 620 } 621 622 type ( 623 EmailOptMessageID int 624 EmailOptSubject string 625 EmailOptFrom string 626 EmailOptOrigFrom string 627 EmailOptCC []string 628 EmailOptSender string 629 ) 630 631 func (c *Ctx) incomingEmail(to, body string, opts ...interface{}) { 632 id := 0 633 subject := "crash1" 634 from := "default@sender.com" 635 cc := []string{"test@syzkaller.com", "bugs@syzkaller.com", "bugs2@syzkaller.com"} 636 sender := "" 637 origFrom := "" 638 for _, o := range opts { 639 switch opt := o.(type) { 640 case EmailOptMessageID: 641 id = int(opt) 642 case EmailOptSubject: 643 subject = string(opt) 644 case EmailOptFrom: 645 from = string(opt) 646 case EmailOptSender: 647 sender = string(opt) 648 case EmailOptCC: 649 cc = []string(opt) 650 case EmailOptOrigFrom: 651 origFrom = fmt.Sprintf("\nX-Original-From: %v", string(opt)) 652 } 653 } 654 if sender == "" { 655 sender = from 656 } 657 email := fmt.Sprintf(`Sender: %v 658 Date: Tue, 15 Aug 2017 14:59:00 -0700 659 Message-ID: <%v> 660 Subject: %v 661 From: %v 662 Cc: %v 663 To: %v%v 664 Content-Type: text/plain 665 666 %v 667 `, sender, id, subject, from, strings.Join(cc, ","), to, origFrom, body) 668 log.Infof(c.ctx, "sending %s", email) 669 _, err := c.POST("/_ah/mail/"+to, email) 670 c.expectOK(err) 671 } 672 673 func initMocks() { 674 // Mock time as some functionality relies on real time. 675 timeNow = func(c context.Context) time.Time { 676 return getRequestContext(c).mockedTime 677 } 678 sendEmail = func(c context.Context, msg *aemail.Message) error { 679 getRequestContext(c).emailSink <- msg 680 return nil 681 } 682 maxCrashes = func() int { 683 // dev_appserver is very slow, so let's make tests smaller. 684 const maxCrashesDuringTest = 20 685 return maxCrashesDuringTest 686 } 687 } 688 689 // Machinery to associate mocked time with requests. 690 type RequestMapping struct { 691 id int 692 ctx *Ctx 693 } 694 695 var ( 696 requestMu sync.Mutex 697 requestNum int 698 requestContexts []RequestMapping 699 ) 700 701 func registerRequest(r *http.Request, c *Ctx) *http.Request { 702 requestMu.Lock() 703 defer requestMu.Unlock() 704 705 requestNum++ 706 newContext := context.WithValue(r.Context(), requestIDKey{}, requestNum) 707 newRequest := r.WithContext(newContext) 708 requestContexts = append(requestContexts, RequestMapping{requestNum, c}) 709 return newRequest 710 } 711 712 func getRequestContext(c context.Context) *Ctx { 713 requestMu.Lock() 714 defer requestMu.Unlock() 715 reqID := getRequestID(c) 716 for _, m := range requestContexts { 717 if m.id == reqID { 718 return m.ctx 719 } 720 } 721 panic(fmt.Sprintf("no context for: %#v", c)) 722 } 723 724 func unregisterContext(c *Ctx) { 725 requestMu.Lock() 726 defer requestMu.Unlock() 727 n := 0 728 for _, m := range requestContexts { 729 if m.ctx == c { 730 continue 731 } 732 requestContexts[n] = m 733 n++ 734 } 735 requestContexts = requestContexts[:n] 736 } 737 738 type requestIDKey struct{} 739 740 func getRequestID(c context.Context) int { 741 val, ok := c.Value(requestIDKey{}).(int) 742 if !ok { 743 panic("the context did not come from a test") 744 } 745 return val 746 } 747 748 // Create a shallow copy of GlobalConfig with a replaced namespace config. 749 func replaceNamespaceConfig(c context.Context, ns string, f func(*Config) *Config) *GlobalConfig { 750 ret := *getConfig(c) 751 newNsMap := map[string]*Config{} 752 for name, nsCfg := range ret.Namespaces { 753 if name == ns { 754 nsCfg = f(nsCfg) 755 } 756 newNsMap[name] = nsCfg 757 } 758 ret.Namespaces = newNsMap 759 return &ret 760 } 761 762 func replaceManagerConfig(c context.Context, ns, mgr string, f func(ConfigManager) ConfigManager) *GlobalConfig { 763 return replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 764 ret := *cfg 765 newMgrMap := map[string]ConfigManager{} 766 for name, mgrCfg := range ret.Managers { 767 if name == mgr { 768 mgrCfg = f(mgrCfg) 769 } 770 newMgrMap[name] = mgrCfg 771 } 772 ret.Managers = newMgrMap 773 return &ret 774 }) 775 } 776 777 func replaceReporting(c context.Context, ns, name string, f func(Reporting) Reporting) *GlobalConfig { 778 return replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 779 ret := *cfg 780 var newReporting []Reporting 781 for _, cfg := range ret.Reporting { 782 if cfg.Name == name { 783 cfg = f(cfg) 784 } 785 newReporting = append(newReporting, cfg) 786 } 787 ret.Reporting = newReporting 788 return &ret 789 }) 790 }