github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/flightvalues_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package safehttp_test 16 17 import ( 18 "context" 19 "fmt" 20 "net/http/httptest" 21 "testing" 22 23 "github.com/google/go-cmp/cmp" 24 "github.com/google/go-safeweb/safehttp" 25 "github.com/google/safehtml" 26 ) 27 28 type safeHeadersInterceptor struct{} 29 30 func (ip *safeHeadersInterceptor) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 31 // We claim the header here in order to protect it from being tampered. The 32 // only way to set it is through a helper method exposed by this package. It 33 // only allows for setting safe values. 34 setter := w.Header().Claim("Super-Safe-Header") 35 safehttp.FlightValues(r.Context()).Put(safeHeaderKey{}, setter) 36 return safehttp.NotWritten() 37 } 38 39 func (ip *safeHeadersInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 40 } 41 42 func (ip *safeHeadersInterceptor) Match(_ safehttp.InterceptorConfig) bool { 43 // This interceptor does not offer any configuration options. 44 return false 45 } 46 47 type safeHeaderKey struct{} 48 49 func SetHeaderSafely(ctx context.Context, level int) { 50 var value string 51 switch level { 52 case 0: 53 value = "Safe" 54 case 1: 55 value = "VerySafe" 56 case 2: 57 value = "VeryVerySafe" 58 default: 59 value = "Safe" 60 } 61 setter := safehttp.FlightValues(ctx).Get(safeHeaderKey{}).(func([]string)) 62 setter([]string{value}) 63 } 64 65 func handlerInteractingWithTheInterceptor(w safehttp.ResponseWriter, req *safehttp.IncomingRequest) safehttp.Result { 66 f, err := req.URL().Query() 67 if err != nil { 68 panic(err) 69 } 70 safety := f.Int64("level", 0) 71 SetHeaderSafely(req.Context(), int(safety)) 72 73 return w.Write(safehtml.HTMLEscaped(fmt.Sprintf("Safety header set to %v", safety))) 74 } 75 76 func TestHandlerInteractingWithInterceptor(t *testing.T) { 77 mb := safehttp.NewServeMuxConfig(nil) 78 mb.Intercept(&safeHeadersInterceptor{}) 79 m := mb.Mux() 80 81 m.Handle("/safety", safehttp.MethodGet, safehttp.HandlerFunc(handlerInteractingWithTheInterceptor)) 82 83 rr := httptest.NewRecorder() 84 85 req := httptest.NewRequest(safehttp.MethodGet, "https://foo.com/safety?level=2", nil) 86 m.ServeHTTP(rr, req) 87 88 if got, want := rr.Code, safehttp.StatusOK; got != int(want) { 89 t.Errorf("rr.Code got: %v want: %v", got, want) 90 } 91 92 want := `Safety header set to 2` 93 if diff := cmp.Diff(want, rr.Body.String()); diff != "" { 94 t.Errorf("response body diff (-want,+got): \n%s\ngot %q, want %q", diff, rr.Body.String(), want) 95 } 96 97 wantHeaders := map[string][]string{ 98 "Content-Type": {"text/html; charset=utf-8"}, 99 "Super-Safe-Header": {"VeryVerySafe"}, 100 } 101 if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" { 102 t.Errorf("rr.Header mismatch (-want +got):\n%s", diff) 103 } 104 }