github.com/decred/politeia@v1.4.0/politeiawww/middleware_test.go (about)

     1  // Copyright (c) 2021 The Decred developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"io"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"strings"
    12  	"testing"
    13  
    14  	"github.com/gorilla/mux"
    15  )
    16  
    17  func TestReqBodySizeMiddleware(t *testing.T) {
    18  	// Setup the test router
    19  	router := mux.NewRouter()
    20  	m := middleware{
    21  		reqBodySizeLimit: 5,
    22  	}
    23  	router.Use(closeBodyMiddleware)
    24  	router.Use(m.reqBodySizeLimitMiddleware)
    25  
    26  	// Setup a route handler that reads the request body. Reading
    27  	// the request body is required in order to trigger the error.
    28  	testRoute := "/test"
    29  	router.HandleFunc(testRoute, func(w http.ResponseWriter, r *http.Request) {
    30  		_, err := io.ReadAll(r.Body)
    31  		if err != nil {
    32  			w.WriteHeader(http.StatusBadRequest)
    33  			return
    34  		}
    35  		w.WriteHeader(http.StatusOK)
    36  	})
    37  
    38  	// Setup test request bodies
    39  	const (
    40  		fourBytes = "1234"
    41  		fiveBytes = "12345"
    42  		sixBytes  = "123456"
    43  	)
    44  
    45  	// Setup tests
    46  	var tests = []struct {
    47  		name     string
    48  		reqBody  string
    49  		wantCode int
    50  	}{
    51  		{
    52  			"no request body",
    53  			"",
    54  			http.StatusOK,
    55  		},
    56  		{
    57  			"under the req body limit",
    58  			fourBytes,
    59  			http.StatusOK,
    60  		},
    61  		{
    62  			"at the req body limit",
    63  			fiveBytes,
    64  			http.StatusOK,
    65  		},
    66  		{
    67  			"over the req body limit",
    68  			sixBytes,
    69  			http.StatusBadRequest,
    70  		},
    71  	}
    72  
    73  	// Run tests
    74  	for _, tc := range tests {
    75  		t.Run(tc.name, func(t *testing.T) {
    76  			// Setup the test request
    77  			req, err := http.NewRequest(http.MethodPost,
    78  				testRoute, strings.NewReader(tc.reqBody))
    79  			if err != nil {
    80  				t.Fatal(err)
    81  			}
    82  
    83  			// Send the test request
    84  			rr := httptest.NewRecorder()
    85  			router.ServeHTTP(rr, req)
    86  
    87  			// Verify the response
    88  			if rr.Code != tc.wantCode {
    89  				t.Errorf("wrong http response code: got %v, want %v",
    90  					rr.Code, tc.wantCode)
    91  			}
    92  		})
    93  	}
    94  }