github.com/msales/pkg/v3@v3.24.0/httpx/middleware/recovery_test.go (about)

     1  package middleware_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"testing"
     9  
    10  	"github.com/msales/pkg/v3/httpx/middleware"
    11  	"github.com/msales/pkg/v3/log"
    12  	"github.com/msales/pkg/v3/mocks"
    13  	"github.com/msales/pkg/v3/stats"
    14  	"github.com/stretchr/testify/mock"
    15  )
    16  
    17  func TestWithRecovery(t *testing.T) {
    18  	h := middleware.WithRecovery(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    19  		panic("panic")
    20  	}))
    21  
    22  	ctx := context.Background()
    23  	logger := new(mocks.Logger)
    24  	logger.On("Error", "panic", "url", "/", "stack", mock.AnythingOfType("string"))
    25  	s := new(MockStats)
    26  	s.On("Inc", "panic_recovery", int64(1), float32(1.0), mock.Anything).Return(nil).Once()
    27  
    28  	req, _ := http.NewRequest("GET", "/", nil)
    29  	req = req.WithContext(stats.WithStats(log.WithLogger(ctx, logger), s))
    30  	resp := httptest.NewRecorder()
    31  
    32  	defer func() {
    33  		if err := recover(); err != nil {
    34  			t.Fatal("Expected the panic to be handled.")
    35  		}
    36  	}()
    37  
    38  	h.ServeHTTP(resp, req)
    39  
    40  	logger.AssertExpectations(t)
    41  	s.AssertExpectations(t)
    42  }
    43  
    44  func TestWithRecovery_WithoutStack(t *testing.T) {
    45  	h := middleware.WithRecovery(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    46  		panic("panic")
    47  	}), middleware.WithoutStack())
    48  
    49  	ctx := context.Background()
    50  	logger := new(mocks.Logger)
    51  	logger.On("Error", "panic", "url", "/")
    52  	s := new(MockStats)
    53  	s.On("Inc", "panic_recovery", int64(1), float32(1.0), mock.Anything).Return(nil).Once()
    54  
    55  	req, _ := http.NewRequest("GET", "/", nil)
    56  	req = req.WithContext(stats.WithStats(log.WithLogger(ctx, logger), s))
    57  	resp := httptest.NewRecorder()
    58  
    59  	defer func() {
    60  		if err := recover(); err != nil {
    61  			t.Fatal("Expected the panic to be handled.")
    62  		}
    63  	}()
    64  
    65  	h.ServeHTTP(resp, req)
    66  
    67  	logger.AssertExpectations(t)
    68  	s.AssertExpectations(t)
    69  }
    70  
    71  func TestWithRecovery_Error(t *testing.T) {
    72  	h := middleware.WithRecovery(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    73  		panic(errors.New("panic"))
    74  	}))
    75  
    76  	req, _ := http.NewRequest("GET", "/", nil)
    77  	req = req.WithContext(stats.WithStats(log.WithLogger(context.Background(), log.Null), stats.Null))
    78  	resp := httptest.NewRecorder()
    79  
    80  	defer func() {
    81  		if err := recover(); err != nil {
    82  			t.Fatal("Expected the panic to be handled.")
    83  		}
    84  	}()
    85  
    86  	h.ServeHTTP(resp, req)
    87  }