github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/dashboard/app/util_test.go (about) 1 // Copyright 2017 syzkaller project authors. All rights reserved. 2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4 // The test uses aetest package that starts local dev_appserver and handles all requests locally: 5 // https://cloud.google.com/appengine/docs/standard/go/tools/localunittesting/reference 6 7 package main 8 9 import ( 10 "bytes" 11 "context" 12 "errors" 13 "fmt" 14 "io" 15 "net/http" 16 "net/http/httptest" 17 "os" 18 "os/exec" 19 "path/filepath" 20 "reflect" 21 "runtime" 22 "strings" 23 "sync" 24 "testing" 25 "time" 26 27 "github.com/google/go-cmp/cmp" 28 "github.com/google/syzkaller/dashboard/dashapi" 29 "github.com/google/syzkaller/pkg/email" 30 "github.com/google/syzkaller/pkg/subsystem" 31 "google.golang.org/appengine/v2/aetest" 32 db "google.golang.org/appengine/v2/datastore" 33 "google.golang.org/appengine/v2/log" 34 aemail "google.golang.org/appengine/v2/mail" 35 "google.golang.org/appengine/v2/user" 36 ) 37 38 type Ctx struct { 39 t *testing.T 40 inst aetest.Instance 41 ctx context.Context 42 mockedTime time.Time 43 emailSink chan *aemail.Message 44 transformContext func(context.Context) context.Context 45 client *apiClient 46 client2 *apiClient 47 publicClient *apiClient 48 } 49 50 var skipDevAppserverTests = func() bool { 51 _, err := exec.LookPath("dev_appserver.py") 52 // Don't silently skip tests on CI, we should have gcloud sdk installed there. 53 return err != nil && os.Getenv("SYZ_ENV") == "" || 54 os.Getenv("SYZ_SKIP_DEV_APPSERVER_TESTS") != "" 55 }() 56 57 func NewCtx(t *testing.T) *Ctx { 58 if skipDevAppserverTests { 59 t.Skip("skipping test (no dev_appserver.py)") 60 } 61 t.Parallel() 62 inst, err := aetest.NewInstance(&aetest.Options{ 63 // Without this option datastore queries return data with slight delay, 64 // which fails reporting tests. 65 StronglyConsistentDatastore: true, 66 }) 67 if err != nil { 68 t.Fatal(err) 69 } 70 r, err := inst.NewRequest("GET", "", nil) 71 if err != nil { 72 t.Fatal(err) 73 } 74 c := &Ctx{ 75 t: t, 76 inst: inst, 77 mockedTime: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), 78 emailSink: make(chan *aemail.Message, 100), 79 transformContext: func(c context.Context) context.Context { return c }, 80 } 81 c.client = c.makeClient(client1, password1, true) 82 c.client2 = c.makeClient(client2, password2, true) 83 c.publicClient = c.makeClient(clientPublicEmail, keyPublicEmail, true) 84 c.ctx = registerRequest(r, c).Context() 85 return c 86 } 87 88 func (c *Ctx) config() *GlobalConfig { 89 return getConfig(c.ctx) 90 } 91 92 func (c *Ctx) expectOK(err error) { 93 if err != nil { 94 c.t.Helper() 95 c.t.Fatalf("expected OK, got error: %v", err) 96 } 97 } 98 99 func (c *Ctx) expectFail(msg string, err error) { 100 c.t.Helper() 101 if err == nil { 102 c.t.Fatalf("expected to fail, but it does not") 103 } 104 if !strings.Contains(err.Error(), msg) { 105 c.t.Fatalf("expected to fail with %q, but failed with %q", msg, err) 106 } 107 } 108 109 func (c *Ctx) expectFailureStatus(err error, code int) { 110 c.t.Helper() 111 if err == nil { 112 c.t.Fatalf("expected to fail as %d, but it does not", code) 113 } 114 var httpErr *HTTPError 115 if !errors.As(err, &httpErr) || httpErr.Code != code { 116 c.t.Fatalf("expected to fail as %d, but it failed as %v", code, err) 117 } 118 } 119 120 func (c *Ctx) expectForbidden(err error) { 121 c.expectFailureStatus(err, http.StatusForbidden) 122 } 123 124 func (c *Ctx) expectBadReqest(err error) { 125 c.expectFailureStatus(err, http.StatusBadRequest) 126 } 127 128 func (c *Ctx) expectEQ(got, want interface{}) { 129 if diff := cmp.Diff(got, want); diff != "" { 130 c.t.Helper() 131 c.t.Fatal(diff) 132 } 133 } 134 135 func (c *Ctx) expectNE(got, want interface{}) { 136 if reflect.DeepEqual(got, want) { 137 c.t.Helper() 138 c.t.Fatalf("equal: %#v", got) 139 } 140 } 141 142 func (c *Ctx) expectTrue(v bool) { 143 if !v { 144 c.t.Helper() 145 c.t.Fatal("failed") 146 } 147 } 148 149 func caller(skip int) string { 150 pcs := make([]uintptr, 10) 151 n := runtime.Callers(skip+3, pcs) 152 pcs = pcs[:n] 153 frames := runtime.CallersFrames(pcs) 154 stack := "" 155 for { 156 frame, more := frames.Next() 157 if strings.HasPrefix(frame.Function, "testing.") { 158 break 159 } 160 stack = fmt.Sprintf("%v:%v\n", filepath.Base(frame.File), frame.Line) + stack 161 if !more { 162 break 163 } 164 } 165 if stack != "" { 166 stack = stack[:len(stack)-1] 167 } 168 return stack 169 } 170 171 func (c *Ctx) Close() { 172 defer c.inst.Close() 173 if !c.t.Failed() { 174 // To avoid per-day reporting limits for left-over emails. 175 c.advanceTime(25 * time.Hour) 176 // Ensure that we can render main page and all bugs in the final test state. 177 _, err := c.GET("/test1") 178 c.expectOK(err) 179 _, err = c.GET("/test2") 180 c.expectOK(err) 181 _, err = c.GET("/test1/fixed") 182 c.expectOK(err) 183 _, err = c.GET("/test2/fixed") 184 c.expectOK(err) 185 _, err = c.GET("/admin") 186 c.expectOK(err) 187 var bugs []*Bug 188 keys, err := db.NewQuery("Bug").GetAll(c.ctx, &bugs) 189 if err != nil { 190 c.t.Errorf("ERROR: failed to query bugs: %v", err) 191 } 192 for _, key := range keys { 193 _, err = c.GET(fmt.Sprintf("/bug?id=%v", key.StringID())) 194 c.expectOK(err) 195 } 196 // No pending emails (tests need to consume them). 197 _, err = c.GET("/cron/email_poll") 198 c.expectOK(err) 199 for len(c.emailSink) != 0 { 200 c.t.Errorf("ERROR: leftover email: %v", (<-c.emailSink).Body) 201 } 202 // No pending external reports (tests need to consume them). 203 resp, _ := c.client.ReportingPollBugs("test") 204 for _, rep := range resp.Reports { 205 c.t.Errorf("ERROR: leftover external report:\n%#v", rep) 206 } 207 } 208 unregisterContext(c) 209 validateGlobalConfig() 210 } 211 212 func (c *Ctx) advanceTime(d time.Duration) { 213 c.mockedTime = c.mockedTime.Add(d) 214 } 215 216 func (c *Ctx) setSubsystems(ns string, list []*subsystem.Subsystem, rev int) { 217 c.transformContext = func(c context.Context) context.Context { 218 newConfig := replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 219 ret := *cfg 220 ret.Subsystems.Revision = rev 221 if list == nil { 222 ret.Subsystems.Service = nil 223 } else { 224 ret.Subsystems.Service = subsystem.MustMakeService(list) 225 } 226 return &ret 227 }) 228 return contextWithConfig(c, newConfig) 229 } 230 } 231 232 func (c *Ctx) setKernelRepos(ns string, list []KernelRepo) { 233 c.transformContext = func(c context.Context) context.Context { 234 newConfig := replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 235 ret := *cfg 236 ret.Repos = list 237 return &ret 238 }) 239 return contextWithConfig(c, newConfig) 240 } 241 } 242 243 func (c *Ctx) setNoObsoletions() { 244 c.transformContext = func(c context.Context) context.Context { 245 return contextWithNoObsoletions(c) 246 } 247 } 248 249 func (c *Ctx) updateReporting(ns, name string, f func(Reporting) Reporting) { 250 c.transformContext = func(c context.Context) context.Context { 251 return contextWithConfig(c, replaceReporting(c, ns, name, f)) 252 } 253 } 254 255 func (c *Ctx) decommissionManager(ns, oldManager, newManager string) { 256 c.transformContext = func(c context.Context) context.Context { 257 newConfig := replaceManagerConfig(c, ns, oldManager, func(cfg ConfigManager) ConfigManager { 258 cfg.Decommissioned = true 259 cfg.DelegatedTo = newManager 260 return cfg 261 }) 262 return contextWithConfig(c, newConfig) 263 } 264 } 265 266 func (c *Ctx) decommission(ns string) { 267 c.transformContext = func(c context.Context) context.Context { 268 newConfig := replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 269 ret := *cfg 270 ret.Decommissioned = true 271 return &ret 272 }) 273 return contextWithConfig(c, newConfig) 274 } 275 } 276 277 func (c *Ctx) setWaitForRepro(ns string, d time.Duration) { 278 c.transformContext = func(c context.Context) context.Context { 279 newConfig := replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 280 ret := *cfg 281 ret.WaitForRepro = d 282 return &ret 283 }) 284 return contextWithConfig(c, newConfig) 285 } 286 } 287 288 // GET sends admin-authorized HTTP GET request to the app. 289 func (c *Ctx) GET(url string) ([]byte, error) { 290 return c.AuthGET(AccessAdmin, url) 291 } 292 293 // AuthGET sends HTTP GET request to the app with the specified authorization. 294 func (c *Ctx) AuthGET(access AccessLevel, url string) ([]byte, error) { 295 w, err := c.httpRequest("GET", url, "", access) 296 if err != nil { 297 return nil, err 298 } 299 return w.Body.Bytes(), nil 300 } 301 302 // POST sends admin-authorized HTTP POST requestd to the app. 303 func (c *Ctx) POST(url, body string) ([]byte, error) { 304 w, err := c.httpRequest("POST", url, body, AccessAdmin) 305 if err != nil { 306 return nil, err 307 } 308 return w.Body.Bytes(), nil 309 } 310 311 // ContentType returns the response Content-Type header value. 312 func (c *Ctx) ContentType(url string) (string, error) { 313 w, err := c.httpRequest("HEAD", url, "", AccessAdmin) 314 if err != nil { 315 return "", err 316 } 317 values := w.Header()["Content-Type"] 318 if len(values) == 0 { 319 return "", fmt.Errorf("no Content-Type") 320 } 321 return values[0], nil 322 } 323 324 func (c *Ctx) httpRequest(method, url, body string, access AccessLevel) (*httptest.ResponseRecorder, error) { 325 c.t.Logf("%v: %v", method, url) 326 r, err := c.inst.NewRequest(method, url, strings.NewReader(body)) 327 if err != nil { 328 c.t.Fatal(err) 329 } 330 r.Header.Add("X-Appengine-User-IP", "127.0.0.1") 331 r = registerRequest(r, c) 332 r = r.WithContext(c.transformContext(r.Context())) 333 if access == AccessAdmin || access == AccessUser { 334 user := &user.User{ 335 Email: "user@syzkaller.com", 336 AuthDomain: "gmail.com", 337 } 338 if access == AccessAdmin { 339 user.Admin = true 340 } 341 aetest.Login(user, r) 342 } 343 w := httptest.NewRecorder() 344 http.DefaultServeMux.ServeHTTP(w, r) 345 c.t.Logf("REPLY: %v", w.Code) 346 if w.Code != http.StatusOK { 347 return nil, &HTTPError{w.Code, w.Body.String(), w.Result().Header} 348 } 349 return w, nil 350 } 351 352 type HTTPError struct { 353 Code int 354 Body string 355 Headers http.Header 356 } 357 358 func (err *HTTPError) Error() string { 359 return fmt.Sprintf("%v: %v", err.Code, err.Body) 360 } 361 362 func (c *Ctx) loadBug(extID string) (*Bug, *Crash, *Build) { 363 bug, _, err := findBugByReportingID(c.ctx, extID) 364 if err != nil { 365 c.t.Fatalf("failed to load bug: %v", err) 366 } 367 return c.loadBugInfo(bug) 368 } 369 370 func (c *Ctx) loadBugByHash(hash string) (*Bug, *Crash, *Build) { 371 bug := new(Bug) 372 bugKey := db.NewKey(c.ctx, "Bug", hash, 0, nil) 373 c.expectOK(db.Get(c.ctx, bugKey, bug)) 374 return c.loadBugInfo(bug) 375 } 376 377 func (c *Ctx) loadBugInfo(bug *Bug) (*Bug, *Crash, *Build) { 378 crash, _, err := findCrashForBug(c.ctx, bug) 379 if err != nil { 380 c.t.Fatalf("failed to load crash: %v", err) 381 } 382 build := c.loadBuild(bug.Namespace, crash.BuildID) 383 return bug, crash, build 384 } 385 386 func (c *Ctx) loadJob(extID string) (*Job, *Build, *Crash) { 387 jobKey, err := jobID2Key(c.ctx, extID) 388 if err != nil { 389 c.t.Fatalf("failed to create job key: %v", err) 390 } 391 job := new(Job) 392 if err := db.Get(c.ctx, jobKey, job); err != nil { 393 c.t.Fatalf("failed to get job %v: %v", extID, err) 394 } 395 build := c.loadBuild(job.Namespace, job.BuildID) 396 crash := new(Crash) 397 crashKey := db.NewKey(c.ctx, "Crash", "", job.CrashID, jobKey.Parent()) 398 if err := db.Get(c.ctx, crashKey, crash); err != nil { 399 c.t.Fatalf("failed to load crash for job: %v", err) 400 } 401 return job, build, crash 402 } 403 404 func (c *Ctx) loadBuild(ns, id string) *Build { 405 build, err := loadBuild(c.ctx, ns, id) 406 c.expectOK(err) 407 return build 408 } 409 410 func (c *Ctx) loadManager(ns, name string) (*Manager, *Build) { 411 mgr, err := loadManager(c.ctx, ns, name) 412 c.expectOK(err) 413 build := c.loadBuild(ns, mgr.CurrentBuild) 414 return mgr, build 415 } 416 417 func (c *Ctx) loadSingleBug() (*Bug, *db.Key) { 418 var bugs []*Bug 419 keys, err := db.NewQuery("Bug").GetAll(c.ctx, &bugs) 420 c.expectEQ(err, nil) 421 c.expectEQ(len(bugs), 1) 422 423 return bugs[0], keys[0] 424 } 425 426 func (c *Ctx) loadSingleJob() (*Job, *db.Key) { 427 var jobs []*Job 428 keys, err := db.NewQuery("Job").GetAll(c.ctx, &jobs) 429 c.expectEQ(err, nil) 430 c.expectEQ(len(jobs), 1) 431 432 return jobs[0], keys[0] 433 } 434 435 func (c *Ctx) checkURLContents(url string, want []byte) { 436 c.t.Helper() 437 got, err := c.AuthGET(AccessAdmin, url) 438 if err != nil { 439 c.t.Fatalf("%v request failed: %v", url, err) 440 } 441 if !bytes.Equal(got, want) { 442 c.t.Fatalf("url %v: got:\n%s\nwant:\n%s\n", url, got, want) 443 } 444 } 445 446 func (c *Ctx) pollEmailBug() *aemail.Message { 447 _, err := c.GET("/cron/email_poll") 448 c.expectOK(err) 449 if len(c.emailSink) == 0 { 450 c.t.Helper() 451 c.t.Fatal("got no emails") 452 } 453 return <-c.emailSink 454 } 455 456 func (c *Ctx) pollEmailExtID() string { 457 c.t.Helper() 458 _, extBugID := c.pollEmailAndExtID() 459 return extBugID 460 } 461 462 func (c *Ctx) pollEmailAndExtID() (string, string) { 463 c.t.Helper() 464 msg := c.pollEmailBug() 465 _, extBugID, err := email.RemoveAddrContext(msg.Sender) 466 if err != nil { 467 c.t.Fatalf("failed to remove addr context: %v", err) 468 } 469 return msg.Sender, extBugID 470 } 471 472 func (c *Ctx) expectNoEmail() { 473 _, err := c.GET("/cron/email_poll") 474 c.expectOK(err) 475 if len(c.emailSink) != 0 { 476 msg := <-c.emailSink 477 c.t.Helper() 478 c.t.Fatalf("got unexpected email: %v\n%s", msg.Subject, msg.Body) 479 } 480 } 481 482 type apiClient struct { 483 *Ctx 484 *dashapi.Dashboard 485 } 486 487 func (c *Ctx) makeClient(client, key string, failOnErrors bool) *apiClient { 488 doer := func(r *http.Request) (*http.Response, error) { 489 r = registerRequest(r, c) 490 r = r.WithContext(c.transformContext(r.Context())) 491 w := httptest.NewRecorder() 492 http.DefaultServeMux.ServeHTTP(w, r) 493 res := &http.Response{ 494 StatusCode: w.Code, 495 Status: http.StatusText(w.Code), 496 Body: io.NopCloser(w.Result().Body), 497 } 498 return res, nil 499 } 500 logger := func(msg string, args ...interface{}) { 501 c.t.Logf("%v: "+msg, append([]interface{}{caller(3)}, args...)...) 502 } 503 errorHandler := func(err error) { 504 if failOnErrors { 505 c.t.Fatalf("\n%v: %v", caller(2), err) 506 } 507 } 508 dash, err := dashapi.NewCustom(client, "", key, c.inst.NewRequest, doer, logger, errorHandler) 509 if err != nil { 510 panic(fmt.Sprintf("Impossible error: %v", err)) 511 } 512 return &apiClient{ 513 Ctx: c, 514 Dashboard: dash, 515 } 516 } 517 518 func (client *apiClient) pollBugs(expect int) []*dashapi.BugReport { 519 resp, _ := client.ReportingPollBugs("test") 520 if len(resp.Reports) != expect { 521 client.t.Helper() 522 client.t.Fatalf("want %v reports, got %v", expect, len(resp.Reports)) 523 } 524 for _, rep := range resp.Reports { 525 reproLevel := dashapi.ReproLevelNone 526 if len(rep.ReproC) != 0 { 527 reproLevel = dashapi.ReproLevelC 528 } else if len(rep.ReproSyz) != 0 { 529 reproLevel = dashapi.ReproLevelSyz 530 } 531 reply, _ := client.ReportingUpdate(&dashapi.BugUpdate{ 532 ID: rep.ID, 533 JobID: rep.JobID, 534 Status: dashapi.BugStatusOpen, 535 ReproLevel: reproLevel, 536 CrashID: rep.CrashID, 537 }) 538 client.expectEQ(reply.Error, false) 539 client.expectEQ(reply.OK, true) 540 } 541 return resp.Reports 542 } 543 544 func (client *apiClient) pollBug() *dashapi.BugReport { 545 return client.pollBugs(1)[0] 546 } 547 548 func (client *apiClient) pollNotifs(expect int) []*dashapi.BugNotification { 549 resp, _ := client.ReportingPollNotifications("test") 550 if len(resp.Notifications) != expect { 551 client.t.Helper() 552 client.t.Fatalf("want %v notifs, got %v", expect, len(resp.Notifications)) 553 } 554 return resp.Notifications 555 } 556 557 func (client *apiClient) updateBug(extID string, status dashapi.BugStatus, dup string) { 558 reply, _ := client.ReportingUpdate(&dashapi.BugUpdate{ 559 ID: extID, 560 Status: status, 561 DupOf: dup, 562 }) 563 client.expectTrue(reply.OK) 564 } 565 566 func (client *apiClient) pollSpecificJobs(manager string, jobs dashapi.ManagerJobs) *dashapi.JobPollResp { 567 req := &dashapi.JobPollReq{ 568 Managers: map[string]dashapi.ManagerJobs{ 569 manager: jobs, 570 }, 571 } 572 resp, err := client.JobPoll(req) 573 client.expectOK(err) 574 return resp 575 } 576 577 func (client *apiClient) pollJobs(manager string) *dashapi.JobPollResp { 578 return client.pollSpecificJobs(manager, dashapi.ManagerJobs{ 579 TestPatches: true, 580 BisectCause: true, 581 BisectFix: true, 582 }) 583 } 584 585 func (client *apiClient) pollAndFailBisectJob(manager string) { 586 resp := client.pollJobs(manager) 587 client.expectNE(resp.ID, "") 588 client.expectEQ(resp.Type, dashapi.JobBisectCause) 589 done := &dashapi.JobDoneReq{ 590 ID: resp.ID, 591 Error: []byte("pollAndFailBisectJob"), 592 } 593 client.expectOK(client.JobDone(done)) 594 } 595 596 type ( 597 EmailOptMessageID int 598 EmailOptSubject string 599 EmailOptFrom string 600 EmailOptOrigFrom string 601 EmailOptCC []string 602 EmailOptSender string 603 ) 604 605 func (c *Ctx) incomingEmail(to, body string, opts ...interface{}) { 606 id := 0 607 subject := "crash1" 608 from := "default@sender.com" 609 cc := []string{"test@syzkaller.com", "bugs@syzkaller.com", "bugs2@syzkaller.com"} 610 sender := "" 611 origFrom := "" 612 for _, o := range opts { 613 switch opt := o.(type) { 614 case EmailOptMessageID: 615 id = int(opt) 616 case EmailOptSubject: 617 subject = string(opt) 618 case EmailOptFrom: 619 from = string(opt) 620 case EmailOptSender: 621 sender = string(opt) 622 case EmailOptCC: 623 cc = []string(opt) 624 case EmailOptOrigFrom: 625 origFrom = fmt.Sprintf("\nX-Original-From: %v", string(opt)) 626 } 627 } 628 if sender == "" { 629 sender = from 630 } 631 email := fmt.Sprintf(`Sender: %v 632 Date: Tue, 15 Aug 2017 14:59:00 -0700 633 Message-ID: <%v> 634 Subject: %v 635 From: %v 636 Cc: %v 637 To: %v%v 638 Content-Type: text/plain 639 640 %v 641 `, sender, id, subject, from, strings.Join(cc, ","), to, origFrom, body) 642 log.Infof(c.ctx, "sending %s", email) 643 _, err := c.POST("/_ah/mail/email@server.com", email) 644 c.expectOK(err) 645 } 646 647 func initMocks() { 648 // Mock time as some functionality relies on real time. 649 timeNow = func(c context.Context) time.Time { 650 return getRequestContext(c).mockedTime 651 } 652 sendEmail = func(c context.Context, msg *aemail.Message) error { 653 getRequestContext(c).emailSink <- msg 654 return nil 655 } 656 maxCrashes = func() int { 657 // dev_appserver is very slow, so let's make tests smaller. 658 const maxCrashesDuringTest = 20 659 return maxCrashesDuringTest 660 } 661 } 662 663 // Machinery to associate mocked time with requests. 664 type RequestMapping struct { 665 id int 666 ctx *Ctx 667 } 668 669 var ( 670 requestMu sync.Mutex 671 requestNum int 672 requestContexts []RequestMapping 673 ) 674 675 func registerRequest(r *http.Request, c *Ctx) *http.Request { 676 requestMu.Lock() 677 defer requestMu.Unlock() 678 679 requestNum++ 680 newContext := context.WithValue(r.Context(), requestIDKey{}, requestNum) 681 newRequest := r.WithContext(newContext) 682 requestContexts = append(requestContexts, RequestMapping{requestNum, c}) 683 return newRequest 684 } 685 686 func getRequestContext(c context.Context) *Ctx { 687 requestMu.Lock() 688 defer requestMu.Unlock() 689 reqID := getRequestID(c) 690 for _, m := range requestContexts { 691 if m.id == reqID { 692 return m.ctx 693 } 694 } 695 panic(fmt.Sprintf("no context for: %#v", c)) 696 } 697 698 func unregisterContext(c *Ctx) { 699 requestMu.Lock() 700 defer requestMu.Unlock() 701 n := 0 702 for _, m := range requestContexts { 703 if m.ctx == c { 704 continue 705 } 706 requestContexts[n] = m 707 n++ 708 } 709 requestContexts = requestContexts[:n] 710 } 711 712 type requestIDKey struct{} 713 714 func getRequestID(c context.Context) int { 715 val, ok := c.Value(requestIDKey{}).(int) 716 if !ok { 717 panic("the context did not come from a test") 718 } 719 return val 720 } 721 722 // Create a shallow copy of GlobalConfig with a replaced namespace config. 723 func replaceNamespaceConfig(c context.Context, ns string, f func(*Config) *Config) *GlobalConfig { 724 ret := *getConfig(c) 725 newNsMap := map[string]*Config{} 726 for name, nsCfg := range ret.Namespaces { 727 if name == ns { 728 nsCfg = f(nsCfg) 729 } 730 newNsMap[name] = nsCfg 731 } 732 ret.Namespaces = newNsMap 733 return &ret 734 } 735 736 func replaceManagerConfig(c context.Context, ns, mgr string, f func(ConfigManager) ConfigManager) *GlobalConfig { 737 return replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 738 ret := *cfg 739 newMgrMap := map[string]ConfigManager{} 740 for name, mgrCfg := range ret.Managers { 741 if name == mgr { 742 mgrCfg = f(mgrCfg) 743 } 744 newMgrMap[name] = mgrCfg 745 } 746 ret.Managers = newMgrMap 747 return &ret 748 }) 749 } 750 751 func replaceReporting(c context.Context, ns, name string, f func(Reporting) Reporting) *GlobalConfig { 752 return replaceNamespaceConfig(c, ns, func(cfg *Config) *Config { 753 ret := *cfg 754 var newReporting []Reporting 755 for _, cfg := range ret.Reporting { 756 if cfg.Name == name { 757 cfg = f(cfg) 758 } 759 newReporting = append(newReporting, cfg) 760 } 761 ret.Reporting = newReporting 762 return &ret 763 }) 764 }