github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/default_dispatcher_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safehttp_test
    16  
    17  import (
    18  	"html/template"
    19  	"math"
    20  	"net/http"
    21  	"net/http/httptest"
    22  	"testing"
    23  
    24  	"github.com/google/go-cmp/cmp"
    25  	"github.com/google/go-safeweb/safehttp"
    26  	"github.com/google/safehtml"
    27  	safetemplate "github.com/google/safehtml/template"
    28  )
    29  
    30  func TestDefaultDispatcherValidResponse(t *testing.T) {
    31  	tests := []struct {
    32  		name        string
    33  		write       func(w http.ResponseWriter) error
    34  		wantStatus  safehttp.StatusCode
    35  		wantHeaders map[string][]string
    36  		wantBody    string
    37  	}{
    38  		{
    39  			name: "Safe HTML Response",
    40  			write: func(w http.ResponseWriter) error {
    41  				d := &safehttp.DefaultDispatcher{}
    42  				return d.Write(w, safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
    43  			},
    44  			wantBody: "&lt;h1&gt;Hello World!&lt;/h1&gt;",
    45  		},
    46  		{
    47  			name: "Safe HTML Template Response",
    48  			write: func(w http.ResponseWriter) error {
    49  				d := &safehttp.DefaultDispatcher{}
    50  				t := safehttp.Template(safetemplate.
    51  					Must(safetemplate.New("name").
    52  						Parse("<h1>{{ . }}</h1>")))
    53  				var data interface{}
    54  				data = "This is an actual heading, though."
    55  				return d.Write(w, &safehttp.TemplateResponse{Template: t, Data: data})
    56  			},
    57  			wantBody: "<h1>This is an actual heading, though.</h1>",
    58  		},
    59  		{
    60  			name: "Named Safe HTML Template Response",
    61  			write: func(w http.ResponseWriter) error {
    62  				d := &safehttp.DefaultDispatcher{}
    63  				t := safehttp.Template(
    64  					safetemplate.Must(
    65  						safetemplate.Must(safetemplate.New("name").Parse("<h1>{{ . }}</h1>")).
    66  							New("associated").Parse("<h2>{{.}}</h2>")))
    67  				var data interface{}
    68  				data = "This is an actual heading, though."
    69  				return d.Write(w, &safehttp.TemplateResponse{t, "associated", data, nil})
    70  			},
    71  			wantBody: "<h2>This is an actual heading, though.</h2>",
    72  		},
    73  		{
    74  			name: "Safe HTML Template Response with Token",
    75  			write: func(w http.ResponseWriter) error {
    76  				d := &safehttp.DefaultDispatcher{}
    77  				defaultFunc := func() string { panic("this function should never be called") }
    78  				t := safehttp.Template(safetemplate.
    79  					Must(safetemplate.New("name").
    80  						Funcs(map[string]interface{}{"Token": defaultFunc}).
    81  						Parse(`<form><input type="hidden" name="token" value="{{Token}}">{{.}}</form>`)))
    82  				var data interface{}
    83  				data = "Content"
    84  				fm := map[string]interface{}{
    85  					"Token": func() string { return "Token-secret" },
    86  				}
    87  				return d.Write(w, &safehttp.TemplateResponse{Template: t, Data: data, FuncMap: fm})
    88  			},
    89  			wantBody: `<form><input type="hidden" name="token" value="Token-secret">Content</form>`,
    90  		},
    91  		{
    92  			name: "Safe HTML Template Response with  Nonce",
    93  			write: func(w http.ResponseWriter) error {
    94  				d := &safehttp.DefaultDispatcher{}
    95  				defaultFunc := func() string { panic("this function should never be called") }
    96  				t := safehttp.Template(safetemplate.
    97  					Must(safetemplate.New("name").
    98  						Funcs(map[string]interface{}{"Nonce": defaultFunc}).
    99  						Parse(`<script nonce="{{Nonce}}" type="application/javascript">alert("script")</script><h1>{{.}}</h1>`)))
   100  				var data interface{}
   101  				data = "Content"
   102  				fm := map[string]interface{}{
   103  					"Nonce": func() string { return "Nonce-secret" },
   104  				}
   105  				return d.Write(w, &safehttp.TemplateResponse{Template: t, Data: data, FuncMap: fm})
   106  			},
   107  			wantBody: `<script nonce="Nonce-secret" type="application/javascript">alert("script")</script><h1>Content</h1>`,
   108  		},
   109  		{
   110  			name: "Valid JSON Response",
   111  			write: func(w http.ResponseWriter) error {
   112  				d := &safehttp.DefaultDispatcher{}
   113  				data := struct {
   114  					Field string `json:"field"`
   115  				}{Field: "myField"}
   116  				return d.Write(w, safehttp.JSONResponse{data})
   117  			},
   118  			wantBody: ")]}',\n{\"field\":\"myField\"}\n",
   119  		},
   120  		{
   121  			name: "Redirect Response",
   122  			write: func(w http.ResponseWriter) error {
   123  				d := &safehttp.DefaultDispatcher{}
   124  				req := httptest.NewRequest("GET", "/path", nil)
   125  				r := safehttp.NewIncomingRequest(req)
   126  				return d.Write(w, safehttp.RedirectResponse{Request: r, Location: "/anotherpath", Code: safehttp.StatusFound})
   127  			},
   128  			wantHeaders: map[string][]string{"Location": {"/anotherpath"}},
   129  			wantStatus:  safehttp.StatusFound,
   130  			wantBody:    "<a href=\"/anotherpath\">Found</a>.\n\n",
   131  		},
   132  		{
   133  			name: "No Content Response",
   134  			write: func(w http.ResponseWriter) error {
   135  				d := &safehttp.DefaultDispatcher{}
   136  				return d.Write(w, safehttp.NoContentResponse{})
   137  			},
   138  			wantBody:   "",
   139  			wantStatus: safehttp.StatusNoContent,
   140  		},
   141  	}
   142  	for _, tt := range tests {
   143  		t.Run(tt.name, func(t *testing.T) {
   144  			rw := httptest.NewRecorder()
   145  			err := tt.write(rw)
   146  
   147  			if err != nil {
   148  				t.Errorf("tt.write(rw): got error %v, want nil", err)
   149  			}
   150  
   151  			if gotBody := rw.Body.String(); tt.wantBody != gotBody {
   152  				t.Errorf("response body: got %q, want %q", gotBody, tt.wantBody)
   153  			}
   154  
   155  			for k, want := range tt.wantHeaders {
   156  				got := rw.Header().Values(k)
   157  				if diff := cmp.Diff(want, got); diff != "" {
   158  					t.Errorf("response header %q: -want +got %s", k, diff)
   159  				}
   160  			}
   161  
   162  			wantStatus := tt.wantStatus
   163  			if wantStatus == 0 {
   164  				wantStatus = 200
   165  			}
   166  
   167  			if got := rw.Code; got != int(wantStatus) {
   168  				t.Errorf("Status: got %d, want %d", got, wantStatus)
   169  			}
   170  		})
   171  	}
   172  }
   173  
   174  func TestDefaultDispatcherInvalidResponse(t *testing.T) {
   175  	tests := []struct {
   176  		name  string
   177  		write func(w http.ResponseWriter) error
   178  		want  string
   179  	}{
   180  		{
   181  			name: "Unsafe HTML Response",
   182  			write: func(w http.ResponseWriter) error {
   183  				d := &safehttp.DefaultDispatcher{}
   184  				return d.Write(w, "<h1>Hello World!</h1>")
   185  			},
   186  			want: "",
   187  		},
   188  		{
   189  			name: "Unsafe Template Response",
   190  			write: func(w http.ResponseWriter) error {
   191  				d := &safehttp.DefaultDispatcher{}
   192  				t := safehttp.Template(template.
   193  					Must(template.New("name").
   194  						Parse("<h1>{{ . }}</h1>")))
   195  				var data interface{}
   196  				data = "This is an actual heading, though."
   197  				return d.Write(w, safehttp.TemplateResponse{Template: t, Data: data})
   198  			},
   199  			want: "",
   200  		},
   201  		{
   202  			name: "Invalid JSON Response",
   203  			write: func(w http.ResponseWriter) error {
   204  				d := &safehttp.DefaultDispatcher{}
   205  				return d.Write(w, safehttp.JSONResponse{math.Inf(1)})
   206  			},
   207  			want: ")]}',\n",
   208  		},
   209  	}
   210  	for _, tt := range tests {
   211  		t.Run(tt.name, func(t *testing.T) {
   212  			rw := httptest.NewRecorder()
   213  
   214  			if err := tt.write(rw); err == nil {
   215  				t.Error("tt.write(rw): got nil, want error")
   216  			}
   217  
   218  			if want, got := tt.want, rw.Body.String(); want != got {
   219  				t.Errorf("response body: got %q, want %q", got, want)
   220  			}
   221  		})
   222  	}
   223  }