github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/flight_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safehttp_test
    16  
    17  import (
    18  	"fmt"
    19  	"net/http/httptest"
    20  	"testing"
    21  
    22  	"github.com/google/go-safeweb/safehttp"
    23  	"github.com/google/safehtml"
    24  )
    25  
    26  type panickingInterceptor struct {
    27  	before, commit, onError bool
    28  }
    29  
    30  func (p panickingInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result {
    31  	if p.before {
    32  		panic("before")
    33  	}
    34  	return safehttp.NotWritten()
    35  }
    36  
    37  func (p panickingInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) {
    38  	if p.commit {
    39  		panic("commit")
    40  	}
    41  }
    42  
    43  func (panickingInterceptor) Match(safehttp.InterceptorConfig) bool {
    44  	return false
    45  }
    46  
    47  func TestFlightInterceptorPanic(t *testing.T) {
    48  	tests := []struct {
    49  		desc        string
    50  		interceptor panickingInterceptor
    51  		wantPanic   bool
    52  	}{
    53  		{
    54  			desc:        "panic in Before",
    55  			interceptor: panickingInterceptor{before: true},
    56  			wantPanic:   true,
    57  		},
    58  		{
    59  			desc:        "panic in Commit",
    60  			interceptor: panickingInterceptor{commit: true},
    61  			wantPanic:   true,
    62  		},
    63  	}
    64  	for _, tc := range tests {
    65  		t.Run(tc.desc, func(t *testing.T) {
    66  			mb := safehttp.NewServeMuxConfig(nil)
    67  			mb.Intercept(tc.interceptor)
    68  			mux := mb.Mux()
    69  
    70  			mux.Handle("/search", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
    71  				// IMPORTANT: We are setting the header here and expecting to be
    72  				// cleared if a panic occurs.
    73  				w.Header().Set("foo", "bar")
    74  				return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
    75  			}))
    76  
    77  			req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/search", nil)
    78  			rw := httptest.NewRecorder()
    79  
    80  			defer func() {
    81  				r := recover()
    82  				if !tc.wantPanic {
    83  					if r != nil {
    84  						t.Fatalf("unexpected panic %v", r)
    85  					}
    86  					return
    87  				}
    88  				if r == nil {
    89  					t.Fatal("expected panic")
    90  				}
    91  				// Good, the panic got propagated.
    92  				if len(rw.Header()) > 0 {
    93  					t.Errorf("ResponseWriter.Header() got %v, want empty", rw.Header())
    94  				}
    95  			}()
    96  			mux.ServeHTTP(rw, req)
    97  		})
    98  	}
    99  }
   100  
   101  func TestFlightHandlerPanic(t *testing.T) {
   102  	mb := safehttp.NewServeMuxConfig(nil)
   103  	mux := mb.Mux()
   104  
   105  	mux.Handle("/search", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   106  		// IMPORTANT: We are setting the header here and expecting to be
   107  		// cleared if a panic occurs.
   108  		w.Header().Set("foo", "bar")
   109  		panic("handler")
   110  	}))
   111  
   112  	req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/search", nil)
   113  	rw := httptest.NewRecorder()
   114  
   115  	defer func() {
   116  		r := recover()
   117  		if r == nil {
   118  			t.Fatalf("expected panic")
   119  		}
   120  		// Good, the panic got propagated.
   121  		if len(rw.Header()) > 0 {
   122  			t.Errorf("ResponseWriter.Header() got %v, want empty", rw.Header())
   123  		}
   124  	}()
   125  	mux.ServeHTTP(rw, req)
   126  }
   127  
   128  func TestFlightDoubleWritePanics(t *testing.T) {
   129  	writeFuncs := map[string]func(safehttp.ResponseWriter, *safehttp.IncomingRequest) safehttp.Result{
   130  		"Write": func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   131  			return w.Write(safehtml.HTMLEscaped("Hello"))
   132  		},
   133  		"WriteError": func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   134  			return w.WriteError(safehttp.StatusPreconditionFailed)
   135  		},
   136  	}
   137  
   138  	for firstWriteName, firstWrite := range writeFuncs {
   139  		for secondWriteName, secondWrite := range writeFuncs {
   140  			t.Run(fmt.Sprintf("%s->%s", firstWriteName, secondWriteName), func(t *testing.T) {
   141  				mb := safehttp.NewServeMuxConfig(nil)
   142  				mux := mb.Mux()
   143  				mux.Handle("/search", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   144  					firstWrite(w, r)
   145  					secondWrite(w, r) // this should panic
   146  					t.Fatal("should never reach this point")
   147  					return safehttp.Result{}
   148  				}))
   149  
   150  				req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/search", nil)
   151  				rw := httptest.NewRecorder()
   152  				defer func() {
   153  					if r := recover(); r == nil {
   154  						t.Fatalf("expected panic")
   155  					}
   156  					// Good, the panic got propagated.
   157  					// Note: we are not testing the response headers here, as the first write might have already succeeded.
   158  				}()
   159  				mux.ServeHTTP(rw, req)
   160  			})
   161  
   162  		}
   163  	}
   164  
   165  }