github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/mux_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safehttp_test
    16  
    17  import (
    18  	"io"
    19  	"net/http"
    20  	"net/http/httptest"
    21  	"testing"
    22  
    23  	"github.com/google/go-cmp/cmp"
    24  	"github.com/google/go-safeweb/safehttp"
    25  	"github.com/google/safehtml"
    26  )
    27  
    28  func TestMuxOneHandlerOneRequest(t *testing.T) {
    29  	var test = []struct {
    30  		name       string
    31  		req        *http.Request
    32  		wantStatus safehttp.StatusCode
    33  		wantHeader map[string][]string
    34  		wantBody   string
    35  	}{
    36  		{
    37  			name:       "Valid Request",
    38  			req:        httptest.NewRequest(safehttp.MethodGet, "http://foo.com/", nil),
    39  			wantStatus: safehttp.StatusOK,
    40  			wantHeader: map[string][]string{
    41  				"Content-Type": {"text/html; charset=utf-8"},
    42  			},
    43  			wantBody: "<h1>Hello World!</h1>",
    44  		},
    45  		{
    46  			name:       "Invalid Method",
    47  			req:        httptest.NewRequest(safehttp.MethodPost, "http://foo.com/", nil),
    48  			wantStatus: safehttp.StatusMethodNotAllowed,
    49  			wantHeader: map[string][]string{
    50  				"Content-Type":           {"text/plain; charset=utf-8"},
    51  				"X-Content-Type-Options": {"nosniff"},
    52  			},
    53  			wantBody: "Method Not Allowed\n",
    54  		},
    55  	}
    56  
    57  	for _, tt := range test {
    58  		t.Run(tt.name, func(t *testing.T) {
    59  			mb := safehttp.NewServeMuxConfig(nil)
    60  			mux := mb.Mux()
    61  
    62  			h := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
    63  				return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
    64  			})
    65  			mux.Handle("/", safehttp.MethodGet, h)
    66  
    67  			rw := httptest.NewRecorder()
    68  
    69  			mux.ServeHTTP(rw, tt.req)
    70  
    71  			if rw.Code != int(tt.wantStatus) {
    72  				t.Errorf("rw.Code: got %v want %v", rw.Code, tt.wantStatus)
    73  			}
    74  
    75  			if diff := cmp.Diff(tt.wantHeader, map[string][]string(rw.Header())); diff != "" {
    76  				t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff)
    77  			}
    78  
    79  			if got := rw.Body.String(); got != tt.wantBody {
    80  				t.Errorf("response body: got %q want %q", got, tt.wantBody)
    81  			}
    82  		})
    83  	}
    84  }
    85  
    86  func TestMuxServeTwoHandlers(t *testing.T) {
    87  	var tests = []struct {
    88  		name        string
    89  		req         *http.Request
    90  		hf          safehttp.Handler
    91  		wantStatus  safehttp.StatusCode
    92  		wantHeaders map[string][]string
    93  		wantBody    string
    94  	}{
    95  		{
    96  			name: "GET Handler",
    97  			req:  httptest.NewRequest(safehttp.MethodGet, "http://foo.com/bar", nil),
    98  			hf: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
    99  				return w.Write(safehtml.HTMLEscaped("<h1>Hello World! GET</h1>"))
   100  			}),
   101  			wantStatus: safehttp.StatusOK,
   102  			wantHeaders: map[string][]string{
   103  				"Content-Type": {"text/html; charset=utf-8"},
   104  			},
   105  			wantBody: "&lt;h1&gt;Hello World! GET&lt;/h1&gt;",
   106  		},
   107  		{
   108  			name: "POST Handler",
   109  			req:  httptest.NewRequest(safehttp.MethodPost, "http://foo.com/bar", nil),
   110  			hf: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   111  				return w.Write(safehtml.HTMLEscaped("<h1>Hello World! POST</h1>"))
   112  			}),
   113  			wantStatus: safehttp.StatusOK,
   114  			wantHeaders: map[string][]string{
   115  				"Content-Type": {"text/html; charset=utf-8"},
   116  			},
   117  			wantBody: "&lt;h1&gt;Hello World! POST&lt;/h1&gt;",
   118  		},
   119  	}
   120  
   121  	mb := safehttp.NewServeMuxConfig(nil)
   122  	mux := mb.Mux()
   123  
   124  	mux.Handle("/bar", safehttp.MethodGet, tests[0].hf)
   125  	mux.Handle("/bar", safehttp.MethodPost, tests[1].hf)
   126  
   127  	for _, test := range tests {
   128  		rw := httptest.NewRecorder()
   129  		mux.ServeHTTP(rw, test.req)
   130  		if want := int(test.wantStatus); rw.Code != want {
   131  			t.Errorf("rw.Code: got %v want %v", rw.Code, want)
   132  		}
   133  
   134  		if diff := cmp.Diff(test.wantHeaders, map[string][]string(rw.Header())); diff != "" {
   135  			t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff)
   136  		}
   137  
   138  		if got, want := rw.Body.String(), test.wantBody; got != want {
   139  			t.Errorf("response body: got %q want %q", got, want)
   140  		}
   141  	}
   142  }
   143  
   144  func TestMuxRegisterCorrectHandlerAllPaths(t *testing.T) {
   145  	var tests = []struct {
   146  		name     string
   147  		req      *http.Request
   148  		hf       safehttp.Handler
   149  		wantBody string
   150  	}{
   151  		{
   152  			name: "GET Handler",
   153  			req:  httptest.NewRequest(safehttp.MethodGet, "http://foo.com/get", nil),
   154  			hf: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   155  				return w.Write(safehtml.HTMLEscaped("GET handler for /get"))
   156  			}),
   157  			wantBody: "GET handler for /get",
   158  		},
   159  		{
   160  			name: "GET Handler #2",
   161  			req:  httptest.NewRequest(safehttp.MethodGet, "http://foo.com/get2", nil),
   162  			hf: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   163  				return w.Write(safehtml.HTMLEscaped("GET handler for /get2"))
   164  			}),
   165  			wantBody: "GET handler for /get2",
   166  		},
   167  	}
   168  
   169  	mb := safehttp.NewServeMuxConfig(nil)
   170  	mux := mb.Mux()
   171  	mux.Handle("/get", safehttp.MethodGet, tests[0].hf)
   172  	mux.Handle("/get2", safehttp.MethodGet, tests[1].hf)
   173  
   174  	for _, test := range tests {
   175  		rw := httptest.NewRecorder()
   176  		mux.ServeHTTP(rw, test.req)
   177  
   178  		if got, want := rw.Body.String(), test.wantBody; got != want {
   179  			t.Errorf("response body: got %q want %q", got, want)
   180  		}
   181  	}
   182  }
   183  
   184  func TestMuxHandleSameMethodTwice(t *testing.T) {
   185  	mb := safehttp.NewServeMuxConfig(nil)
   186  	mux := mb.Mux()
   187  
   188  	registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   189  		return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
   190  	})
   191  	mux.Handle("/bar", safehttp.MethodGet, registeredHandler)
   192  
   193  	defer func() {
   194  		if r := recover(); r != nil {
   195  			return
   196  		}
   197  		t.Errorf(`mux.Handle("/bar", MethodGet, registeredHandler) expected panic`)
   198  	}()
   199  
   200  	mux.Handle("/bar", safehttp.MethodGet, registeredHandler)
   201  }
   202  
   203  type setHeaderInterceptor struct {
   204  	name  string
   205  	value string
   206  }
   207  
   208  func (p setHeaderInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
   209  	w.Header().Set(p.name, p.value)
   210  	return safehttp.NotWritten()
   211  }
   212  
   213  func (p setHeaderInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
   214  }
   215  
   216  func (setHeaderInterceptor) Match(safehttp.InterceptorConfig) bool {
   217  	return false
   218  }
   219  
   220  type internalErrorInterceptor struct{}
   221  
   222  func (internalErrorInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
   223  	return w.WriteError(safehttp.StatusInternalServerError)
   224  }
   225  
   226  func (internalErrorInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
   227  }
   228  
   229  func (internalErrorInterceptor) Match(safehttp.InterceptorConfig) bool {
   230  	return false
   231  }
   232  
   233  type claimHeaderInterceptor struct {
   234  	headerToClaim string
   235  }
   236  
   237  type claimKey struct{}
   238  
   239  func (p *claimHeaderInterceptor) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
   240  	f := w.Header().Claim(p.headerToClaim)
   241  	safehttp.FlightValues(r.Context()).Put(claimKey{}, f)
   242  	return safehttp.NotWritten()
   243  }
   244  
   245  func (p *claimHeaderInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
   246  }
   247  
   248  func (claimHeaderInterceptor) Match(safehttp.InterceptorConfig) bool {
   249  	return false
   250  }
   251  
   252  func claimInterceptorSetHeader(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, value string) {
   253  	f := safehttp.FlightValues(r.Context()).Get(claimKey{}).(func([]string))
   254  	f([]string{value})
   255  }
   256  
   257  type committerInterceptor struct{}
   258  
   259  func (committerInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
   260  	return safehttp.NotWritten()
   261  }
   262  
   263  func (committerInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
   264  	w.Header().Set("foo", "bar")
   265  }
   266  
   267  func (committerInterceptor) Match(safehttp.InterceptorConfig) bool {
   268  	return false
   269  }
   270  
   271  type setHeaderErroringInterceptor struct{}
   272  
   273  func (setHeaderErroringInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
   274  	return w.WriteError(safehttp.StatusForbidden)
   275  }
   276  
   277  func (setHeaderErroringInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
   278  	w.Header().Set("name", "foo")
   279  }
   280  
   281  func TestMuxInterceptors(t *testing.T) {
   282  	tests := []struct {
   283  		name        string
   284  		mux         *safehttp.ServeMux
   285  		wantStatus  safehttp.StatusCode
   286  		wantHeaders map[string][]string
   287  		wantBody    string
   288  	}{
   289  		{
   290  			name: "Install ServeMux Interceptor before handler registration",
   291  			mux: func() *safehttp.ServeMux {
   292  				mb := safehttp.NewServeMuxConfig(nil)
   293  				mb.Intercept(setHeaderInterceptor{
   294  					name:  "Foo",
   295  					value: "bar",
   296  				})
   297  				mux := mb.Mux()
   298  
   299  				registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   300  					return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
   301  				})
   302  				mux.Handle("/bar", safehttp.MethodGet, registeredHandler)
   303  				return mux
   304  			}(),
   305  			wantStatus: safehttp.StatusOK,
   306  			wantHeaders: map[string][]string{
   307  				"Content-Type": {"text/html; charset=utf-8"},
   308  				"Foo":          {"bar"},
   309  			},
   310  			wantBody: "&lt;h1&gt;Hello World!&lt;/h1&gt;",
   311  		},
   312  		{
   313  			name: "Install Interrupting Interceptor",
   314  			mux: func() *safehttp.ServeMux {
   315  				mb := safehttp.NewServeMuxConfig(nil)
   316  				mb.Intercept(internalErrorInterceptor{})
   317  				mux := mb.Mux()
   318  
   319  				registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   320  					return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
   321  				})
   322  				mux.Handle("/bar", safehttp.MethodGet, registeredHandler)
   323  
   324  				return mux
   325  			}(),
   326  			wantStatus: safehttp.StatusInternalServerError,
   327  			wantHeaders: map[string][]string{
   328  				"Content-Type":           {"text/plain; charset=utf-8"},
   329  				"X-Content-Type-Options": {"nosniff"},
   330  			},
   331  			wantBody: "Internal Server Error\n",
   332  		},
   333  		{
   334  			name: "Handler Communication With ServeMux Interceptor",
   335  			mux: func() *safehttp.ServeMux {
   336  				mb := safehttp.NewServeMuxConfig(nil)
   337  				mb.Intercept(&claimHeaderInterceptor{headerToClaim: "Foo"})
   338  				mux := mb.Mux()
   339  
   340  				registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   341  					claimInterceptorSetHeader(w, r, "bar")
   342  					return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
   343  				})
   344  				mux.Handle("/bar", safehttp.MethodGet, registeredHandler)
   345  
   346  				return mux
   347  			}(),
   348  			wantStatus: safehttp.StatusOK,
   349  			wantHeaders: map[string][]string{
   350  				"Content-Type": {"text/html; charset=utf-8"},
   351  				"Foo":          {"bar"},
   352  			},
   353  			wantBody: "&lt;h1&gt;Hello World!&lt;/h1&gt;",
   354  		},
   355  		{
   356  			name: "Commit phase sets header",
   357  			mux: func() *safehttp.ServeMux {
   358  				mb := safehttp.NewServeMuxConfig(nil)
   359  				mb.Intercept(committerInterceptor{})
   360  				mux := mb.Mux()
   361  
   362  				registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   363  					return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
   364  				})
   365  				mux.Handle("/bar", safehttp.MethodGet, registeredHandler)
   366  
   367  				return mux
   368  			}(),
   369  			wantStatus: safehttp.StatusOK,
   370  			wantHeaders: map[string][]string{
   371  				"Foo":          {"bar"},
   372  				"Content-Type": {"text/html; charset=utf-8"},
   373  			},
   374  			wantBody: "&lt;h1&gt;Hello World!&lt;/h1&gt;",
   375  		},
   376  	}
   377  
   378  	for _, tt := range tests {
   379  		t.Run(tt.name, func(t *testing.T) {
   380  			rw := httptest.NewRecorder()
   381  			req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/bar", nil)
   382  
   383  			tt.mux.ServeHTTP(rw, req)
   384  
   385  			if rw.Code != int(tt.wantStatus) {
   386  				t.Errorf("rw.Code: got %v want %v", rw.Code, tt.wantStatus)
   387  			}
   388  
   389  			if diff := cmp.Diff(tt.wantHeaders, map[string][]string(rw.Header())); diff != "" {
   390  				t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff)
   391  			}
   392  
   393  			if got := rw.Body.String(); got != tt.wantBody {
   394  				t.Errorf("response body: got %q want %q", got, tt.wantBody)
   395  			}
   396  		})
   397  	}
   398  }
   399  
   400  type setHeaderConfig struct {
   401  	name  string
   402  	value string
   403  }
   404  
   405  type setHeaderConfigInterceptor struct{}
   406  
   407  func (p setHeaderConfigInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
   408  	name := "Pizza"
   409  	value := "Hawaii"
   410  	if c, ok := cfg.(setHeaderConfig); ok {
   411  		name = c.name
   412  		value = c.value
   413  	}
   414  	w.Header().Set(name, value)
   415  	return safehttp.NotWritten()
   416  }
   417  
   418  func (p setHeaderConfigInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
   419  	name := "Commit-Pizza"
   420  	value := "Hawaii"
   421  	if c, ok := cfg.(setHeaderConfig); ok {
   422  		name = "Commit-" + c.name
   423  		value = c.value
   424  	}
   425  	w.Header().Set(name, value)
   426  }
   427  
   428  func (setHeaderConfigInterceptor) Match(cfg safehttp.InterceptorConfig) bool {
   429  	_, ok := cfg.(setHeaderConfig)
   430  	return ok
   431  }
   432  
   433  type noInterceptorConfig struct{}
   434  
   435  type wrappedInterceptor struct {
   436  	w safehttp.Interceptor
   437  }
   438  
   439  func (wi wrappedInterceptor) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
   440  	return wi.w.Before(w, r, cfg)
   441  }
   442  
   443  func (wi wrappedInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
   444  	wi.w.Commit(w, r, resp, cfg)
   445  }
   446  
   447  func (wi wrappedInterceptor) Match(cfg safehttp.InterceptorConfig) bool {
   448  	return wi.w.Match(cfg)
   449  }
   450  
   451  func (noInterceptorConfig) Match(i safehttp.Interceptor) bool {
   452  	return false
   453  }
   454  
   455  func TestMuxInterceptorConfigs(t *testing.T) {
   456  	tests := []struct {
   457  		name        string
   458  		interceptor safehttp.Interceptor
   459  		config      safehttp.InterceptorConfig
   460  		wantStatus  safehttp.StatusCode
   461  		wantHeaders map[string][]string
   462  		wantBody    string
   463  	}{
   464  		{
   465  			name:        "SetHeaderInterceptor with config",
   466  			interceptor: setHeaderConfigInterceptor{},
   467  			config:      setHeaderConfig{name: "Foo", value: "Bar"},
   468  			wantStatus:  safehttp.StatusOK,
   469  			wantHeaders: map[string][]string{
   470  				"Content-Type": {"text/html; charset=utf-8"},
   471  				"Commit-Foo":   {"Bar"},
   472  				"Foo":          {"Bar"},
   473  			},
   474  			wantBody: "&lt;h1&gt;Hello World!&lt;/h1&gt;",
   475  		},
   476  		{
   477  			name:        "Wrapped SetHeaderInterceptor with config",
   478  			interceptor: wrappedInterceptor{w: setHeaderConfigInterceptor{}},
   479  			config:      setHeaderConfig{name: "Foo", value: "Bar"},
   480  			wantStatus:  safehttp.StatusOK,
   481  			wantHeaders: map[string][]string{
   482  				"Content-Type": {"text/html; charset=utf-8"},
   483  				"Commit-Foo":   {"Bar"},
   484  				"Foo":          {"Bar"},
   485  			},
   486  			wantBody: "&lt;h1&gt;Hello World!&lt;/h1&gt;",
   487  		},
   488  		{
   489  			name:        "SetHeaderInterceptor with mismatching config",
   490  			interceptor: setHeaderConfigInterceptor{},
   491  			config:      noInterceptorConfig{},
   492  			wantStatus:  safehttp.StatusOK,
   493  			wantHeaders: map[string][]string{
   494  				"Content-Type": {"text/html; charset=utf-8"},
   495  				"Pizza":        {"Hawaii"},
   496  				"Commit-Pizza": {"Hawaii"},
   497  			},
   498  			wantBody: "&lt;h1&gt;Hello World!&lt;/h1&gt;",
   499  		},
   500  	}
   501  
   502  	for _, tt := range tests {
   503  		t.Run(tt.name, func(t *testing.T) {
   504  			mb := safehttp.NewServeMuxConfig(nil)
   505  			mb.Intercept(tt.interceptor)
   506  			mux := mb.Mux()
   507  
   508  			registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   509  				return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
   510  			})
   511  			mux.Handle("/bar", safehttp.MethodGet, registeredHandler, tt.config)
   512  
   513  			rw := httptest.NewRecorder()
   514  			req := httptest.NewRequest("GET", "http://foo.com/bar", nil)
   515  
   516  			mux.ServeHTTP(rw, req)
   517  
   518  			if rw.Code != int(tt.wantStatus) {
   519  				t.Errorf("rw.Code: got %v want %v", rw.Code, tt.wantStatus)
   520  			}
   521  
   522  			if diff := cmp.Diff(tt.wantHeaders, map[string][]string(rw.Header())); diff != "" {
   523  				t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff)
   524  			}
   525  
   526  			if got := rw.Body.String(); got != tt.wantBody {
   527  				t.Errorf("response body: got %q want %q", got, tt.wantBody)
   528  			}
   529  		})
   530  	}
   531  }
   532  
   533  type interceptorOne struct{}
   534  
   535  func (interceptorOne) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
   536  	w.Header().Set("pizza", "diavola")
   537  	return safehttp.NotWritten()
   538  }
   539  
   540  func (interceptorOne) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
   541  	if w.Header().Get("Commit2") != "b" {
   542  		panic("server bug")
   543  	}
   544  	w.Header().Set("Commit1", "a")
   545  }
   546  
   547  func (interceptorOne) Match(safehttp.InterceptorConfig) bool {
   548  	return false
   549  }
   550  
   551  type interceptorTwo struct{}
   552  
   553  func (interceptorTwo) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
   554  	if w.Header().Get("pizza") != "diavola" {
   555  		panic("server bug")
   556  	}
   557  	w.Header().Set("spaghetti", "bolognese")
   558  	return safehttp.NotWritten()
   559  }
   560  
   561  func (interceptorTwo) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
   562  	if w.Header().Get("Commit3") != "c" {
   563  		panic("server bug")
   564  	}
   565  	w.Header().Set("Commit2", "b")
   566  }
   567  
   568  func (interceptorTwo) Match(safehttp.InterceptorConfig) bool {
   569  	return false
   570  }
   571  
   572  type interceptorThree struct{}
   573  
   574  func (interceptorThree) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
   575  	if w.Header().Get("spaghetti") != "bolognese" {
   576  		panic("server bug")
   577  	}
   578  	w.Header().Set("dessert", "tiramisu")
   579  	return safehttp.NotWritten()
   580  }
   581  
   582  func (interceptorThree) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
   583  	if w.Header().Get("Dessert") != "tiramisu" {
   584  		panic("server bug")
   585  	}
   586  	w.Header().Set("Commit3", "c")
   587  }
   588  
   589  func (interceptorThree) Match(safehttp.InterceptorConfig) bool {
   590  	return false
   591  }
   592  
   593  func TestMuxDeterministicInterceptorOrder(t *testing.T) {
   594  	mb := safehttp.NewServeMuxConfig(nil)
   595  	mb.Intercept(interceptorOne{})
   596  	mb.Intercept(interceptorTwo{})
   597  	mb.Intercept(interceptorThree{})
   598  	mux := mb.Mux()
   599  
   600  	registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   601  		return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
   602  	})
   603  	mux.Handle("/bar", safehttp.MethodGet, registeredHandler)
   604  
   605  	rw := httptest.NewRecorder()
   606  	req := httptest.NewRequest("GET", "http://foo.com/bar", nil)
   607  
   608  	mux.ServeHTTP(rw, req)
   609  
   610  	if want := safehttp.StatusOK; rw.Code != int(want) {
   611  		t.Errorf("rw.Code: got %v want %v", rw.Code, want)
   612  	}
   613  	wantHeaders := map[string][]string{
   614  		"Dessert":      {"tiramisu"},
   615  		"Pizza":        {"diavola"},
   616  		"Spaghetti":    {"bolognese"},
   617  		"Commit1":      {"a"},
   618  		"Commit2":      {"b"},
   619  		"Commit3":      {"c"},
   620  		"Content-Type": {"text/html; charset=utf-8"},
   621  	}
   622  	if diff := cmp.Diff(wantHeaders, map[string][]string(rw.Header())); diff != "" {
   623  		t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff)
   624  	}
   625  	if got, want := rw.Body.String(), "&lt;h1&gt;Hello World!&lt;/h1&gt;"; got != want {
   626  		t.Errorf(`response body: got %q want %q`, got, want)
   627  	}
   628  }
   629  
   630  func TestMuxHandlerReturnsNotWritten(t *testing.T) {
   631  	mb := safehttp.NewServeMuxConfig(nil)
   632  	h := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   633  		return safehttp.NotWritten()
   634  	})
   635  	mux := mb.Mux()
   636  	mux.Handle("/bar", safehttp.MethodGet, h)
   637  
   638  	req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/bar", nil)
   639  	rw := httptest.NewRecorder()
   640  
   641  	mux.ServeHTTP(rw, req)
   642  
   643  	if want := safehttp.StatusNoContent; rw.Code != int(want) {
   644  		t.Errorf("rw.Code: got %v want %v", rw.Code, want)
   645  	}
   646  	if diff := cmp.Diff(map[string][]string{}, map[string][]string(rw.Header())); diff != "" {
   647  		t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff)
   648  	}
   649  	if got := rw.Body.String(); got != "" {
   650  		t.Errorf(`response body got: %q want: ""`, got)
   651  	}
   652  }
   653  
   654  func TestMuxMethodNotAllowedDefaults(t *testing.T) {
   655  	mb := safehttp.NewServeMuxConfig(nil)
   656  	mux := mb.Mux()
   657  
   658  	h := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   659  		panic("not tested")
   660  	})
   661  	mux.Handle("/", safehttp.MethodGet, h)
   662  
   663  	rw := httptest.NewRecorder()
   664  
   665  	mux.ServeHTTP(rw, httptest.NewRequest(safehttp.MethodPost, "http://foo.com/", nil))
   666  
   667  	if got, want := rw.Code, int(safehttp.StatusMethodNotAllowed); got != want {
   668  		t.Errorf("rw.Code: got %v want %v", got, want)
   669  	}
   670  
   671  	wantHeader := map[string][]string{
   672  		"Content-Type":           {"text/plain; charset=utf-8"},
   673  		"X-Content-Type-Options": {"nosniff"},
   674  	}
   675  	if diff := cmp.Diff(wantHeader, map[string][]string(rw.Header())); diff != "" {
   676  		t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff)
   677  	}
   678  
   679  	wantBody := "Method Not Allowed\n"
   680  	if got := rw.Body.String(); got != wantBody {
   681  		t.Errorf("response body: got %q want %q", got, wantBody)
   682  	}
   683  }
   684  
   685  type methodNotAllowedError struct {
   686  	message string
   687  }
   688  
   689  func (err *methodNotAllowedError) Code() safehttp.StatusCode {
   690  	return safehttp.StatusMethodNotAllowed
   691  }
   692  
   693  type methodNotAllowedDispatcher struct {
   694  	safehttp.DefaultDispatcher
   695  }
   696  
   697  func (d *methodNotAllowedDispatcher) Error(rw http.ResponseWriter, resp safehttp.ErrorResponse) error {
   698  	x := resp.(*methodNotAllowedError)
   699  	rw.Header().Set("Content-Type", "text/html; charset=utf-8")
   700  	rw.WriteHeader(int(resp.Code()))
   701  	_, err := io.WriteString(rw, "<h1>"+http.StatusText(int(resp.Code()))+"</h1>"+"<p>"+x.message+"</p>")
   702  	return err
   703  }
   704  
   705  type methodNotAllowedInterceptor struct{}
   706  
   707  func (ip *methodNotAllowedInterceptor) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, ipcfg safehttp.InterceptorConfig) safehttp.Result {
   708  	cfg := ipcfg.(methodNotAllowedInterceptorConfig)
   709  	w.Header().Set("Before-Interceptor", cfg.before)
   710  	return safehttp.NotWritten()
   711  }
   712  
   713  // Commit runs before the response is written by the Dispatcher. If an error
   714  // is written to the ResponseWriter, then the Commit phases from the
   715  // remaining interceptors won't execute.
   716  func (ip *methodNotAllowedInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, ipcfg safehttp.InterceptorConfig) {
   717  	cfg := ipcfg.(methodNotAllowedInterceptorConfig)
   718  	w.Header().Set("Commit-Interceptor", cfg.commit)
   719  }
   720  
   721  func (*methodNotAllowedInterceptor) Match(cfg safehttp.InterceptorConfig) bool {
   722  	_, ok := cfg.(methodNotAllowedInterceptorConfig)
   723  	return ok
   724  }
   725  
   726  type methodNotAllowedInterceptorConfig struct {
   727  	before, commit string
   728  }
   729  
   730  func TestMuxMethodNotAllowedCustom(t *testing.T) {
   731  	mb := safehttp.NewServeMuxConfig(&methodNotAllowedDispatcher{})
   732  	mb.Intercept(&methodNotAllowedInterceptor{})
   733  	mb.HandleMethodNotAllowed(safehttp.HandlerFunc(func(rw safehttp.ResponseWriter, ir *safehttp.IncomingRequest) safehttp.Result {
   734  		return rw.WriteError(&methodNotAllowedError{"custom message"})
   735  	}), methodNotAllowedInterceptorConfig{before: "foo", commit: "bar"})
   736  	mux := mb.Mux()
   737  
   738  	mux.Handle("/", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   739  		panic("not tested")
   740  	}))
   741  
   742  	rw := httptest.NewRecorder()
   743  
   744  	mux.ServeHTTP(rw, httptest.NewRequest(safehttp.MethodPost, "http://foo.com/", nil))
   745  
   746  	if got, want := rw.Code, int(safehttp.StatusMethodNotAllowed); got != want {
   747  		t.Errorf("rw.Code: got %v want %v", got, want)
   748  	}
   749  
   750  	wantHeader := map[string][]string{
   751  		"Content-Type":       {"text/html; charset=utf-8"},
   752  		"Before-Interceptor": {"foo"},
   753  		"Commit-Interceptor": {"bar"},
   754  	}
   755  	if diff := cmp.Diff(wantHeader, map[string][]string(rw.Header())); diff != "" {
   756  		t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff)
   757  	}
   758  
   759  	wantBody := "<h1>Method Not Allowed</h1><p>custom message</p>"
   760  	if got := rw.Body.String(); got != wantBody {
   761  		t.Errorf("response body: got %q want %q", got, wantBody)
   762  	}
   763  }