github.com/ethersphere/bee/v2@v2.2.0/pkg/jsonhttp/handlers_test.go (about)

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jsonhttp_test
     6  
     7  import (
     8  	"encoding/json"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"strings"
    14  	"testing"
    15  
    16  	"github.com/ethersphere/bee/v2/pkg/jsonhttp"
    17  )
    18  
    19  func TestMethodHandler(t *testing.T) {
    20  	t.Parallel()
    21  
    22  	contentType := "application/swarm"
    23  
    24  	h := jsonhttp.MethodHandler{
    25  		"POST": http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    26  			got, err := io.ReadAll(r.Body)
    27  			if err != nil {
    28  				t.Fatal(err)
    29  			}
    30  			w.Header().Set("Content-Type", contentType)
    31  			fmt.Fprint(w, "got: ", string(got))
    32  		}),
    33  	}
    34  
    35  	t.Run("method allowed", func(t *testing.T) {
    36  		t.Parallel()
    37  
    38  		body := "test body"
    39  
    40  		r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
    41  		w := httptest.NewRecorder()
    42  
    43  		h.ServeHTTP(w, r)
    44  
    45  		statusCode := w.Result().StatusCode
    46  		if statusCode != http.StatusOK {
    47  			t.Errorf("got status code %d, want %d", statusCode, http.StatusOK)
    48  		}
    49  
    50  		wantBody := "got: " + body
    51  		gotBody := w.Body.String()
    52  
    53  		if gotBody != wantBody {
    54  			t.Errorf("got body %q, want %q", gotBody, wantBody)
    55  		}
    56  
    57  		if got := w.Header().Get("Content-Type"); got != contentType {
    58  			t.Errorf("got content type %q, want %q", got, contentType)
    59  		}
    60  	})
    61  
    62  	t.Run("method not allowed", func(t *testing.T) {
    63  		t.Parallel()
    64  
    65  		r := httptest.NewRequest(http.MethodGet, "/", nil)
    66  		w := httptest.NewRecorder()
    67  
    68  		h.ServeHTTP(w, r)
    69  
    70  		statusCode := w.Result().StatusCode
    71  		wantCode := http.StatusMethodNotAllowed
    72  		if statusCode != wantCode {
    73  			t.Errorf("got status code %d, want %d", statusCode, wantCode)
    74  		}
    75  
    76  		var m *jsonhttp.StatusResponse
    77  
    78  		if err := json.Unmarshal(w.Body.Bytes(), &m); err != nil {
    79  			t.Errorf("json unmarshal response body: %s", err)
    80  		}
    81  
    82  		if m.Code != wantCode {
    83  			t.Errorf("got message code %d, want %d", m.Code, wantCode)
    84  		}
    85  
    86  		wantMessage := http.StatusText(wantCode)
    87  		if m.Message != wantMessage {
    88  			t.Errorf("got message message %q, want %q", m.Message, wantMessage)
    89  		}
    90  
    91  		testContentType(t, w)
    92  	})
    93  }
    94  
    95  func TestNotFoundHandler(t *testing.T) {
    96  	t.Parallel()
    97  
    98  	w := httptest.NewRecorder()
    99  
   100  	jsonhttp.NotFoundHandler(w, nil)
   101  
   102  	statusCode := w.Result().StatusCode
   103  	wantCode := http.StatusNotFound
   104  	if statusCode != wantCode {
   105  		t.Errorf("got status code %d, want %d", statusCode, wantCode)
   106  	}
   107  
   108  	var m *jsonhttp.StatusResponse
   109  
   110  	if err := json.Unmarshal(w.Body.Bytes(), &m); err != nil {
   111  		t.Errorf("json unmarshal response body: %s", err)
   112  	}
   113  
   114  	if m.Code != wantCode {
   115  		t.Errorf("got message code %d, want %d", m.Code, wantCode)
   116  	}
   117  
   118  	wantMessage := http.StatusText(wantCode)
   119  	if m.Message != wantMessage {
   120  		t.Errorf("got message message %q, want %q", m.Message, wantMessage)
   121  	}
   122  
   123  	testContentType(t, w)
   124  }
   125  
   126  func TestNewMaxBodyBytesHandler(t *testing.T) {
   127  	t.Parallel()
   128  
   129  	var limit int64 = 10
   130  
   131  	h := jsonhttp.NewMaxBodyBytesHandler(limit)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   132  		_, err := io.ReadAll(r.Body)
   133  		if err != nil {
   134  			if jsonhttp.HandleBodyReadError(err, w) {
   135  				return
   136  			}
   137  			jsonhttp.InternalServerError(w, nil)
   138  			return
   139  		}
   140  		jsonhttp.OK(w, nil)
   141  	}))
   142  
   143  	for _, tc := range []struct {
   144  		name                 string
   145  		body                 string
   146  		withoutContentLength bool
   147  		wantCode             int
   148  	}{
   149  		{
   150  			name:     "empty",
   151  			wantCode: http.StatusOK,
   152  		},
   153  		{
   154  			name:                 "within limit without content length header",
   155  			body:                 "data",
   156  			withoutContentLength: true,
   157  			wantCode:             http.StatusOK,
   158  		},
   159  		{
   160  			name:     "within limit",
   161  			body:     "data",
   162  			wantCode: http.StatusOK,
   163  		},
   164  		{
   165  			name:     "over limit",
   166  			body:     "long test data",
   167  			wantCode: http.StatusRequestEntityTooLarge,
   168  		},
   169  		{
   170  			name:                 "over limit without content length header",
   171  			body:                 "long test data",
   172  			withoutContentLength: true,
   173  			wantCode:             http.StatusRequestEntityTooLarge,
   174  		},
   175  	} {
   176  		tc := tc
   177  		t.Run(tc.name, func(t *testing.T) {
   178  			t.Parallel()
   179  
   180  			r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tc.body))
   181  			if tc.withoutContentLength {
   182  				r.Header.Del("Content-Length")
   183  				r.ContentLength = 0
   184  			}
   185  			w := httptest.NewRecorder()
   186  
   187  			h.ServeHTTP(w, r)
   188  
   189  			if w.Code != tc.wantCode {
   190  				t.Errorf("got http response code %d, want %d", w.Code, tc.wantCode)
   191  			}
   192  
   193  			var m *jsonhttp.StatusResponse
   194  
   195  			if err := json.Unmarshal(w.Body.Bytes(), &m); err != nil {
   196  				t.Errorf("json unmarshal response body: %s", err)
   197  			}
   198  
   199  			if m.Code != tc.wantCode {
   200  				t.Errorf("got message code %d, want %d", m.Code, tc.wantCode)
   201  			}
   202  
   203  			wantMessage := http.StatusText(tc.wantCode)
   204  			if m.Message != wantMessage {
   205  				t.Errorf("got message message %q, want %q", m.Message, wantMessage)
   206  			}
   207  		})
   208  	}
   209  }