github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/migration_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safehttp_test
    16  
    17  import (
    18  	"fmt"
    19  	"net/http"
    20  	"net/http/httptest"
    21  	"strings"
    22  	"testing"
    23  
    24  	"github.com/google/go-safeweb/safehttp"
    25  	"github.com/google/safehtml"
    26  )
    27  
    28  func TestRegisteredHandler(t *testing.T) {
    29  	mb := safehttp.NewServeMuxConfig(nil)
    30  	safeMux := mb.Mux()
    31  
    32  	safeMux.Handle("/abc", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
    33  		return w.Write(safehtml.HTMLEscaped("Welcome!"))
    34  	}))
    35  	safeMux.Handle("/abc", safehttp.MethodPost, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
    36  		f, err := r.PostForm()
    37  		if err != nil {
    38  			return w.WriteError(safehttp.StatusBadRequest)
    39  		}
    40  		animal := f.String("animal", "")
    41  		if animal == "" {
    42  			return w.WriteError(safehttp.StatusBadRequest)
    43  		}
    44  		return w.Write(safehtml.HTMLEscaped(fmt.Sprintf("Added %s.", animal)))
    45  	}))
    46  	safeMux.Handle("/def", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
    47  		return w.Write(safehtml.HTMLEscaped("Bye!"))
    48  	}))
    49  
    50  	mux := http.NewServeMux()
    51  	mux.Handle("/abc", safehttp.RegisteredHandler(safeMux, "/abc"))
    52  	mux.Handle("/def", safehttp.RegisteredHandler(safeMux, "/def"))
    53  
    54  	var tests = []struct {
    55  		name       string
    56  		req        *http.Request
    57  		wantStatus safehttp.StatusCode
    58  		wantBody   string
    59  	}{
    60  		{
    61  			name:       "Valid GET Request",
    62  			req:        httptest.NewRequest(safehttp.MethodGet, "http://foo.com/abc", nil),
    63  			wantStatus: safehttp.StatusOK,
    64  			wantBody:   "Welcome!",
    65  		},
    66  		{
    67  			name: "Valid POST Request",
    68  			req: func() *http.Request {
    69  				req := httptest.NewRequest(safehttp.MethodPost, "http://foo.com/abc", strings.NewReader("animal=cat"))
    70  				req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
    71  				return req
    72  			}(),
    73  			wantStatus: safehttp.StatusOK,
    74  			wantBody:   "Added cat.",
    75  		},
    76  		{
    77  			name:       "Different handler",
    78  			req:        httptest.NewRequest(safehttp.MethodGet, "http://foo.com/def", nil),
    79  			wantStatus: safehttp.StatusOK,
    80  			wantBody:   "Bye!",
    81  		},
    82  		{
    83  			name:       "Invalid Method",
    84  			req:        httptest.NewRequest(safehttp.MethodHead, "http://foo.com/abc", nil),
    85  			wantStatus: safehttp.StatusMethodNotAllowed,
    86  			wantBody:   "Method Not Allowed\n",
    87  		},
    88  	}
    89  
    90  	for _, tt := range tests {
    91  		t.Run(tt.name, func(t *testing.T) {
    92  			rw := httptest.NewRecorder()
    93  
    94  			mux.ServeHTTP(rw, tt.req)
    95  
    96  			if rw.Code != int(tt.wantStatus) {
    97  				t.Errorf("rw.Code: got %v want %v", rw.Code, tt.wantStatus)
    98  			}
    99  
   100  			if got := rw.Body.String(); got != tt.wantBody {
   101  				t.Errorf("response body: got %q want %q", got, tt.wantBody)
   102  			}
   103  		})
   104  	}
   105  }
   106  
   107  func TestRegisteredHandler_StrictPatterns(t *testing.T) {
   108  	mb := safehttp.NewServeMuxConfig(nil)
   109  	safeMux := mb.Mux()
   110  
   111  	safeMux.Handle("/foo/", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
   112  		return w.Write(safehtml.HTMLEscaped("Homepage!"))
   113  	}))
   114  
   115  	if safehttp.RegisteredHandler(safeMux, "/foo/") == nil {
   116  		t.Error(`RegisteredHandler(_, "/foo/") got nil, want non-nil`)
   117  	}
   118  	if got := safehttp.RegisteredHandler(safeMux, "/foo/subpath"); got != nil {
   119  		t.Errorf(`RegisteredHandler(_, "/foo/subpath") got %v, want nil`, got)
   120  	}
   121  }