github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/flight_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package safehttp_test 16 17 import ( 18 "fmt" 19 "net/http/httptest" 20 "testing" 21 22 "github.com/google/go-safeweb/safehttp" 23 "github.com/google/safehtml" 24 ) 25 26 type panickingInterceptor struct { 27 before, commit, onError bool 28 } 29 30 func (p panickingInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 31 if p.before { 32 panic("before") 33 } 34 return safehttp.NotWritten() 35 } 36 37 func (p panickingInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 38 if p.commit { 39 panic("commit") 40 } 41 } 42 43 func (panickingInterceptor) Match(safehttp.InterceptorConfig) bool { 44 return false 45 } 46 47 func TestFlightInterceptorPanic(t *testing.T) { 48 tests := []struct { 49 desc string 50 interceptor panickingInterceptor 51 wantPanic bool 52 }{ 53 { 54 desc: "panic in Before", 55 interceptor: panickingInterceptor{before: true}, 56 wantPanic: true, 57 }, 58 { 59 desc: "panic in Commit", 60 interceptor: panickingInterceptor{commit: true}, 61 wantPanic: true, 62 }, 63 } 64 for _, tc := range tests { 65 t.Run(tc.desc, func(t *testing.T) { 66 mb := safehttp.NewServeMuxConfig(nil) 67 mb.Intercept(tc.interceptor) 68 mux := mb.Mux() 69 70 mux.Handle("/search", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 71 // IMPORTANT: We are setting the header here and expecting to be 72 // cleared if a panic occurs. 73 w.Header().Set("foo", "bar") 74 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 75 })) 76 77 req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/search", nil) 78 rw := httptest.NewRecorder() 79 80 defer func() { 81 r := recover() 82 if !tc.wantPanic { 83 if r != nil { 84 t.Fatalf("unexpected panic %v", r) 85 } 86 return 87 } 88 if r == nil { 89 t.Fatal("expected panic") 90 } 91 // Good, the panic got propagated. 92 if len(rw.Header()) > 0 { 93 t.Errorf("ResponseWriter.Header() got %v, want empty", rw.Header()) 94 } 95 }() 96 mux.ServeHTTP(rw, req) 97 }) 98 } 99 } 100 101 func TestFlightHandlerPanic(t *testing.T) { 102 mb := safehttp.NewServeMuxConfig(nil) 103 mux := mb.Mux() 104 105 mux.Handle("/search", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 106 // IMPORTANT: We are setting the header here and expecting to be 107 // cleared if a panic occurs. 108 w.Header().Set("foo", "bar") 109 panic("handler") 110 })) 111 112 req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/search", nil) 113 rw := httptest.NewRecorder() 114 115 defer func() { 116 r := recover() 117 if r == nil { 118 t.Fatalf("expected panic") 119 } 120 // Good, the panic got propagated. 121 if len(rw.Header()) > 0 { 122 t.Errorf("ResponseWriter.Header() got %v, want empty", rw.Header()) 123 } 124 }() 125 mux.ServeHTTP(rw, req) 126 } 127 128 func TestFlightDoubleWritePanics(t *testing.T) { 129 writeFuncs := map[string]func(safehttp.ResponseWriter, *safehttp.IncomingRequest) safehttp.Result{ 130 "Write": func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 131 return w.Write(safehtml.HTMLEscaped("Hello")) 132 }, 133 "WriteError": func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 134 return w.WriteError(safehttp.StatusPreconditionFailed) 135 }, 136 } 137 138 for firstWriteName, firstWrite := range writeFuncs { 139 for secondWriteName, secondWrite := range writeFuncs { 140 t.Run(fmt.Sprintf("%s->%s", firstWriteName, secondWriteName), func(t *testing.T) { 141 mb := safehttp.NewServeMuxConfig(nil) 142 mux := mb.Mux() 143 mux.Handle("/search", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 144 firstWrite(w, r) 145 secondWrite(w, r) // this should panic 146 t.Fatal("should never reach this point") 147 return safehttp.Result{} 148 })) 149 150 req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/search", nil) 151 rw := httptest.NewRecorder() 152 defer func() { 153 if r := recover(); r == nil { 154 t.Fatalf("expected panic") 155 } 156 // Good, the panic got propagated. 157 // Note: we are not testing the response headers here, as the first write might have already succeeded. 158 }() 159 mux.ServeHTTP(rw, req) 160 }) 161 162 } 163 } 164 165 }