github.com/letsencrypt/boulder@v0.20251208.0/web/context_test.go (about) 1 package web 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "fmt" 8 "net/http" 9 "net/http/httptest" 10 "strings" 11 "testing" 12 "time" 13 14 "github.com/letsencrypt/boulder/features" 15 blog "github.com/letsencrypt/boulder/log" 16 "github.com/letsencrypt/boulder/test" 17 ) 18 19 type myHandler struct{} 20 21 func (m myHandler) ServeHTTP(e *RequestEvent, w http.ResponseWriter, r *http.Request) { 22 w.WriteHeader(201) 23 e.Endpoint = "/endpoint" 24 _, _ = w.Write([]byte("hi")) 25 } 26 27 func TestLogCode(t *testing.T) { 28 mockLog := blog.UseMock() 29 th := NewTopHandler(mockLog, myHandler{}) 30 req, err := http.NewRequest("GET", "/thisisignored", &bytes.Reader{}) 31 if err != nil { 32 t.Fatal(err) 33 } 34 th.ServeHTTP(httptest.NewRecorder(), req) 35 expected := `INFO: GET /endpoint 0 201 0 0.0.0.0 JSON={}` 36 if len(mockLog.GetAllMatching(expected)) != 1 { 37 t.Errorf("Expected exactly one log line matching %q. Got \n%s", 38 expected, strings.Join(mockLog.GetAllMatching(".*"), "\n")) 39 } 40 } 41 42 type codeHandler struct{} 43 44 func (ch codeHandler) ServeHTTP(e *RequestEvent, w http.ResponseWriter, r *http.Request) { 45 e.Endpoint = "/endpoint" 46 _, _ = w.Write([]byte("hi")) 47 } 48 49 func TestStatusCodeLogging(t *testing.T) { 50 mockLog := blog.UseMock() 51 th := NewTopHandler(mockLog, codeHandler{}) 52 req, err := http.NewRequest("GET", "/thisisignored", &bytes.Reader{}) 53 if err != nil { 54 t.Fatal(err) 55 } 56 th.ServeHTTP(httptest.NewRecorder(), req) 57 expected := `INFO: GET /endpoint 0 200 0 0.0.0.0 JSON={}` 58 if len(mockLog.GetAllMatching(expected)) != 1 { 59 t.Errorf("Expected exactly one log line matching %q. Got \n%s", 60 expected, strings.Join(mockLog.GetAllMatching(".*"), "\n")) 61 } 62 } 63 64 func TestOrigin(t *testing.T) { 65 mockLog := blog.UseMock() 66 th := NewTopHandler(mockLog, myHandler{}) 67 req, err := http.NewRequest("GET", "/thisisignored", &bytes.Reader{}) 68 if err != nil { 69 t.Fatal(err) 70 } 71 req.Header.Add("Origin", "https://example.com") 72 th.ServeHTTP(httptest.NewRecorder(), req) 73 expected := `INFO: GET /endpoint 0 201 0 0.0.0.0 JSON={.*"Origin":"https://example.com"}` 74 if len(mockLog.GetAllMatching(expected)) != 1 { 75 t.Errorf("Expected exactly one log line matching %q. Got \n%s", 76 expected, strings.Join(mockLog.GetAllMatching(".*"), "\n")) 77 } 78 } 79 80 type hostHeaderHandler struct { 81 f func(*RequestEvent, http.ResponseWriter, *http.Request) 82 } 83 84 func (hhh hostHeaderHandler) ServeHTTP(e *RequestEvent, w http.ResponseWriter, r *http.Request) { 85 hhh.f(e, w, r) 86 } 87 88 func TestHostHeaderRewrite(t *testing.T) { 89 mockLog := blog.UseMock() 90 hhh := hostHeaderHandler{f: func(_ *RequestEvent, _ http.ResponseWriter, r *http.Request) { 91 t.Helper() 92 test.AssertEquals(t, r.Host, "localhost") 93 }} 94 th := NewTopHandler(mockLog, &hhh) 95 96 req, err := http.NewRequest("GET", "/", &bytes.Reader{}) 97 test.AssertNotError(t, err, "http.NewRequest failed") 98 req.Host = "localhost:80" 99 fmt.Println("here") 100 th.ServeHTTP(httptest.NewRecorder(), req) 101 102 req, err = http.NewRequest("GET", "/", &bytes.Reader{}) 103 test.AssertNotError(t, err, "http.NewRequest failed") 104 req.Host = "localhost:443" 105 req.TLS = &tls.ConnectionState{} 106 th.ServeHTTP(httptest.NewRecorder(), req) 107 108 req, err = http.NewRequest("GET", "/", &bytes.Reader{}) 109 test.AssertNotError(t, err, "http.NewRequest failed") 110 req.Host = "localhost:443" 111 req.TLS = nil 112 th.ServeHTTP(httptest.NewRecorder(), req) 113 114 hhh.f = func(_ *RequestEvent, _ http.ResponseWriter, r *http.Request) { 115 t.Helper() 116 test.AssertEquals(t, r.Host, "localhost:123") 117 } 118 req, err = http.NewRequest("GET", "/", &bytes.Reader{}) 119 test.AssertNotError(t, err, "http.NewRequest failed") 120 req.Host = "localhost:123" 121 th.ServeHTTP(httptest.NewRecorder(), req) 122 } 123 124 type cancelHandler struct { 125 res chan string 126 } 127 128 func (ch cancelHandler) ServeHTTP(e *RequestEvent, w http.ResponseWriter, r *http.Request) { 129 select { 130 case <-r.Context().Done(): 131 ch.res <- r.Context().Err().Error() 132 case <-time.After(300 * time.Millisecond): 133 ch.res <- "300 ms passed" 134 } 135 } 136 137 func TestPropagateCancel(t *testing.T) { 138 mockLog := blog.UseMock() 139 res := make(chan string) 140 features.Set(features.Config{PropagateCancels: true}) 141 th := NewTopHandler(mockLog, cancelHandler{res}) 142 ctx, cancel := context.WithCancel(context.Background()) 143 go func() { 144 req, err := http.NewRequestWithContext(ctx, "GET", "/thisisignored", &bytes.Reader{}) 145 if err != nil { 146 t.Error(err) 147 } 148 th.ServeHTTP(httptest.NewRecorder(), req) 149 }() 150 cancel() 151 result := <-res 152 if result != "context canceled" { 153 t.Errorf("expected 'context canceled', got %q", result) 154 } 155 }