github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/flightvalues_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safehttp_test
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"net/http/httptest"
    21  	"testing"
    22  
    23  	"github.com/google/go-cmp/cmp"
    24  	"github.com/google/go-safeweb/safehttp"
    25  	"github.com/google/safehtml"
    26  )
    27  
    28  type safeHeadersInterceptor struct{}
    29  
    30  func (ip *safeHeadersInterceptor) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
    31  	// We claim the header here in order to protect it from being tampered. The
    32  	// only way to set it is through a helper method exposed by this package. It
    33  	// only allows for setting safe values.
    34  	setter := w.Header().Claim("Super-Safe-Header")
    35  	safehttp.FlightValues(r.Context()).Put(safeHeaderKey{}, setter)
    36  	return safehttp.NotWritten()
    37  }
    38  
    39  func (ip *safeHeadersInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
    40  }
    41  
    42  func (ip *safeHeadersInterceptor) Match(_ safehttp.InterceptorConfig) bool {
    43  	// This interceptor does not offer any configuration options.
    44  	return false
    45  }
    46  
    47  type safeHeaderKey struct{}
    48  
    49  func SetHeaderSafely(ctx context.Context, level int) {
    50  	var value string
    51  	switch level {
    52  	case 0:
    53  		value = "Safe"
    54  	case 1:
    55  		value = "VerySafe"
    56  	case 2:
    57  		value = "VeryVerySafe"
    58  	default:
    59  		value = "Safe"
    60  	}
    61  	setter := safehttp.FlightValues(ctx).Get(safeHeaderKey{}).(func([]string))
    62  	setter([]string{value})
    63  }
    64  
    65  func handlerInteractingWithTheInterceptor(w safehttp.ResponseWriter, req *safehttp.IncomingRequest) safehttp.Result {
    66  	f, err := req.URL().Query()
    67  	if err != nil {
    68  		panic(err)
    69  	}
    70  	safety := f.Int64("level", 0)
    71  	SetHeaderSafely(req.Context(), int(safety))
    72  
    73  	return w.Write(safehtml.HTMLEscaped(fmt.Sprintf("Safety header set to %v", safety)))
    74  }
    75  
    76  func TestHandlerInteractingWithInterceptor(t *testing.T) {
    77  	mb := safehttp.NewServeMuxConfig(nil)
    78  	mb.Intercept(&safeHeadersInterceptor{})
    79  	m := mb.Mux()
    80  
    81  	m.Handle("/safety", safehttp.MethodGet, safehttp.HandlerFunc(handlerInteractingWithTheInterceptor))
    82  
    83  	rr := httptest.NewRecorder()
    84  
    85  	req := httptest.NewRequest(safehttp.MethodGet, "https://foo.com/safety?level=2", nil)
    86  	m.ServeHTTP(rr, req)
    87  
    88  	if got, want := rr.Code, safehttp.StatusOK; got != int(want) {
    89  		t.Errorf("rr.Code got: %v want: %v", got, want)
    90  	}
    91  
    92  	want := `Safety header set to 2`
    93  	if diff := cmp.Diff(want, rr.Body.String()); diff != "" {
    94  		t.Errorf("response body diff (-want,+got): \n%s\ngot %q, want %q", diff, rr.Body.String(), want)
    95  	}
    96  
    97  	wantHeaders := map[string][]string{
    98  		"Content-Type":      {"text/html; charset=utf-8"},
    99  		"Super-Safe-Header": {"VeryVerySafe"},
   100  	}
   101  	if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" {
   102  		t.Errorf("rr.Header mismatch (-want +got):\n%s", diff)
   103  	}
   104  }