github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/tests/integration/mux/mux_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mux_test
    16  
    17  import (
    18  	"html/template"
    19  	"math"
    20  	"net/http/httptest"
    21  	"testing"
    22  
    23  	"github.com/google/go-cmp/cmp"
    24  	"github.com/google/go-safeweb/safehttp"
    25  	"github.com/google/safehtml"
    26  	safetemplate "github.com/google/safehtml/template"
    27  )
    28  
    29  func TestMuxDefaultDispatcher(t *testing.T) {
    30  	tests := []struct {
    31  		name        string
    32  		handler     safehttp.Handler
    33  		wantHeaders map[string][]string
    34  		wantBody    string
    35  	}{
    36  		{
    37  			name: "Safe HTML Response",
    38  			handler: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
    39  				return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
    40  			}),
    41  			wantHeaders: map[string][]string{
    42  				"Content-Type": {"text/html; charset=utf-8"},
    43  			},
    44  			wantBody: "&lt;h1&gt;Hello World!&lt;/h1&gt;",
    45  		},
    46  		{
    47  			name: "Safe HTML Template Response",
    48  			handler: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
    49  				return safehttp.ExecuteTemplate(w, safetemplate.
    50  					Must(safetemplate.New("name").
    51  						Parse("<h1>{{ . }}</h1>")), "This is an actual heading, though.")
    52  			}),
    53  			wantHeaders: map[string][]string{
    54  				"Content-Type": {"text/html; charset=utf-8"},
    55  			},
    56  			wantBody: "<h1>This is an actual heading, though.</h1>",
    57  		},
    58  		{
    59  			name: "Valid JSON Response",
    60  			handler: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
    61  				data := struct {
    62  					Field string `json:"field"`
    63  				}{Field: "myField"}
    64  				return safehttp.WriteJSON(w, data)
    65  			}),
    66  			wantHeaders: map[string][]string{
    67  				"Content-Type": {"application/json; charset=utf-8"},
    68  			},
    69  			wantBody: ")]}',\n{\"field\":\"myField\"}\n",
    70  		},
    71  	}
    72  	for _, tt := range tests {
    73  		t.Run(tt.name, func(t *testing.T) {
    74  			mb := safehttp.NewServeMuxConfig(nil)
    75  			mux := mb.Mux()
    76  
    77  			mux.Handle("/pizza", safehttp.MethodGet, tt.handler)
    78  
    79  			rw := httptest.NewRecorder()
    80  			req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/pizza", nil)
    81  
    82  			mux.ServeHTTP(rw, req)
    83  
    84  			if wantStatus := safehttp.StatusOK; rw.Code != int(wantStatus) {
    85  				t.Errorf("rw.Code: got %v want %v", rw.Code, wantStatus)
    86  			}
    87  
    88  			if diff := cmp.Diff(tt.wantHeaders, map[string][]string(rw.Header())); diff != "" {
    89  				t.Errorf("rw.Header mismatch (-want +got):\n%s", diff)
    90  			}
    91  
    92  			if gotBody := rw.Body.String(); tt.wantBody != gotBody {
    93  				t.Errorf("response body: got %v, want %v", gotBody, tt.wantBody)
    94  			}
    95  		})
    96  	}
    97  }
    98  
    99  func TestMuxDefaultDispatcherUnsafeResponses(t *testing.T) {
   100  	tests := []struct {
   101  		name    string
   102  		handler safehttp.Handler
   103  	}{
   104  		{
   105  			name: "Unsafe HTML Response",
   106  			handler: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   107  				return w.Write("<h1>Hello World!</h1>")
   108  			}),
   109  		},
   110  		{
   111  			name: "Unsafe Template Response",
   112  			handler: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   113  				return safehttp.ExecuteTemplate(w, template.
   114  					Must(template.New("name").
   115  						Parse("<h1>{{ . }}</h1>")), "This is an actual heading, though.")
   116  			}),
   117  		},
   118  		{
   119  			name: "Invalid JSON Response",
   120  			handler: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   121  				return safehttp.WriteJSON(w, math.Inf(1))
   122  			}),
   123  		},
   124  	}
   125  	for _, tt := range tests {
   126  		t.Run(tt.name, func(t *testing.T) {
   127  			// TODO: Unskip these test cases and combine them with the test
   128  			// cases from the previous test into a single table test after
   129  			// error-handling in the ResponseWriter has been fixed.
   130  			t.Skip()
   131  
   132  			mb := safehttp.NewServeMuxConfig(nil)
   133  			mux := mb.Mux()
   134  
   135  			mux.Handle("/pizza", safehttp.MethodGet, tt.handler)
   136  
   137  			rw := httptest.NewRecorder()
   138  			req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/pizza", nil)
   139  
   140  			mux.ServeHTTP(rw, req)
   141  
   142  			if wantStatus := safehttp.StatusInternalServerError; rw.Code != int(wantStatus) {
   143  				t.Errorf("rw.Code: got %v want %v", rw.Code, wantStatus)
   144  			}
   145  
   146  			wantHeaders := map[string][]string{
   147  				"Content-Type":           {"text/plain; charset=utf-8"},
   148  				"X-Content-Type-Options": {"nosniff"},
   149  			}
   150  			if diff := cmp.Diff(wantHeaders, map[string][]string(rw.Header())); diff != "" {
   151  				t.Errorf("rw.Header(): mismatch (-want +got):\n%s", diff)
   152  			}
   153  
   154  			if wantBody, gotBody := "Internal Server Error\n", rw.Body.String(); wantBody != gotBody {
   155  				t.Errorf("response body: got %v, want %v", gotBody, wantBody)
   156  			}
   157  		})
   158  	}
   159  }