github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/tests/integration/errors/errors_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package errors_test 16 17 import ( 18 "fmt" 19 "math" 20 "net/http" 21 "net/http/httptest" 22 "testing" 23 24 "github.com/google/go-cmp/cmp" 25 "github.com/google/safehtml" 26 27 "github.com/google/go-safeweb/safehttp" 28 "github.com/google/safehtml/template" 29 ) 30 31 var myErrorTmpl = template.Must(template.New("not found").Parse(`<h1>Error: {{ .Code }}</h1> 32 <p>{{ .Message }}</p> 33 `)) 34 35 type myError struct { 36 safehttp.StatusCode 37 Message string 38 } 39 40 type myDispatcher struct { 41 safehttp.DefaultDispatcher 42 } 43 44 func (m myDispatcher) Error(rw http.ResponseWriter, resp safehttp.ErrorResponse) error { 45 if x, ok := resp.(myError); ok { 46 rw.Header().Set("Content-Type", "text/html; charset=utf-8") 47 rw.WriteHeader(int(x.StatusCode)) 48 return myErrorTmpl.Execute(rw, x) 49 } 50 return m.DefaultDispatcher.Error(rw, resp) 51 } 52 53 // TestCustomErrors tests a scenario where custom errors are implemented using a 54 // custom dispatcher implementation. XSRF interceptor is added to check that 55 // error responses go through the commit phase. 56 func TestCustomErrors(t *testing.T) { 57 mb := safehttp.NewServeMuxConfig(myDispatcher{}) 58 mux := mb.Mux() 59 60 mux.Handle("/compute", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 61 qs, err := r.URL().Query() 62 if err != nil { 63 return w.WriteError(safehttp.StatusBadRequest) 64 } 65 a := qs.Int64("a", math.MaxInt64) 66 if qs.Err() != nil || a == math.MaxInt64 { 67 return w.WriteError(myError{StatusCode: safehttp.StatusBadRequest, Message: "missing parameter 'a'"}) 68 } 69 if a > 10 { 70 return w.WriteError(myError{StatusCode: safehttp.StatusNotImplemented, Message: "we can't process queries with large numbers yet"}) 71 } 72 return w.Write(safehtml.HTMLEscaped(fmt.Sprintf("Result: %d", a*a))) 73 })) 74 75 t.Run("correct request", func(t *testing.T) { 76 rr := httptest.NewRecorder() 77 78 req := httptest.NewRequest(safehttp.MethodGet, "https://foo.com/compute?a=3", nil) 79 mux.ServeHTTP(rr, req) 80 81 if got, want := rr.Code, safehttp.StatusOK; got != int(want) { 82 t.Errorf("rr.Code got: %v want: %v", got, want) 83 } 84 want := "Result: 9" 85 if diff := cmp.Diff(want, rr.Body.String()); diff != "" { 86 t.Errorf("response body diff (-want,+got): \n%s\ngot %q, want %q", diff, rr.Body.String(), want) 87 } 88 }) 89 90 t.Run("missing parameter", func(t *testing.T) { 91 rr := httptest.NewRecorder() 92 93 req := httptest.NewRequest(safehttp.MethodGet, "https://foo.com/compute?foo=3", nil) 94 mux.ServeHTTP(rr, req) 95 96 if got, want := rr.Code, safehttp.StatusBadRequest; got != int(want) { 97 t.Errorf("rr.Code got: %v want: %v", got, want) 98 } 99 100 want := `<h1>Error: Bad Request</h1> 101 <p>missing parameter 'a'</p> 102 ` 103 if diff := cmp.Diff(want, rr.Body.String()); diff != "" { 104 t.Errorf("response body diff (-want,+got): \n%s\ngot %q, want %q", diff, rr.Body.String(), want) 105 } 106 }) 107 108 } 109 110 type interceptor struct { 111 errBefore bool 112 } 113 114 func (it interceptor) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, _ safehttp.InterceptorConfig) safehttp.Result { 115 if it.errBefore { 116 return w.WriteError(myError{StatusCode: safehttp.StatusForbidden, Message: "forbidden in Before"}) 117 } 118 return safehttp.NotWritten() 119 } 120 121 func (interceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, _ safehttp.InterceptorConfig) { 122 } 123 124 func (interceptor) Match(safehttp.InterceptorConfig) bool { 125 return false 126 } 127 128 func TestCustomErrorsInBefore(t *testing.T) { 129 mb := safehttp.NewServeMuxConfig(myDispatcher{}) 130 mb.Intercept(interceptor{errBefore: true}) 131 mux := mb.Mux() 132 133 mux.Handle("/compute", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 134 return w.Write(safehtml.HTMLEscaped("the handler code doesn't matter in this test case")) 135 })) 136 137 t.Run("error in Before", func(t *testing.T) { 138 rr := httptest.NewRecorder() 139 140 req := httptest.NewRequest(safehttp.MethodGet, "https://foo.com/compute?a=3", nil) 141 mux.ServeHTTP(rr, req) 142 143 if got, want := rr.Code, safehttp.StatusForbidden; got != int(want) { 144 t.Errorf("rr.Code got: %v want: %v", got, want) 145 } 146 want := `<h1>Error: Forbidden</h1> 147 <p>forbidden in Before</p> 148 ` 149 if diff := cmp.Diff(want, rr.Body.String()); diff != "" { 150 t.Errorf("response body diff (-want,+got): \n%s\ngot %q, want %q", diff, rr.Body.String(), want) 151 } 152 }) 153 }