github.com/segakazzz/buffalo@v0.16.22-0.20210119082501-1f52048d3feb/errors_test.go (about)

     1  package buffalo
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"os"
     7  	"testing"
     8  
     9  	"github.com/gobuffalo/httptest"
    10  	"github.com/gobuffalo/logger"
    11  	"github.com/sirupsen/logrus"
    12  
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  //testLoggerHook is useful to test whats being logged.
    17  type testLoggerHook struct {
    18  	errors []*logrus.Entry
    19  }
    20  
    21  func (lh *testLoggerHook) Fire(entry *logrus.Entry) error {
    22  	lh.errors = append(lh.errors, entry)
    23  	return nil
    24  }
    25  
    26  func (lh *testLoggerHook) Levels() []logrus.Level {
    27  	return []logrus.Level{
    28  		logrus.ErrorLevel,
    29  	}
    30  }
    31  
    32  func Test_defaultErrorHandler_SetsContentType(t *testing.T) {
    33  	r := require.New(t)
    34  	app := New(Options{})
    35  	app.GET("/", func(c Context) error {
    36  		return c.Error(http.StatusUnauthorized, fmt.Errorf("boom"))
    37  	})
    38  
    39  	w := httptest.New(app)
    40  	res := w.HTML("/").Get()
    41  	r.Equal(http.StatusUnauthorized, res.Code)
    42  	ct := res.Header().Get("content-type")
    43  	r.Equal("text/html; charset=utf-8", ct)
    44  }
    45  
    46  func Test_defaultErrorHandler_Logger(t *testing.T) {
    47  	r := require.New(t)
    48  	app := New(Options{})
    49  	app.GET("/", func(c Context) error {
    50  		return c.Error(http.StatusUnauthorized, fmt.Errorf("boom"))
    51  	})
    52  
    53  	testHook := &testLoggerHook{}
    54  	l := logrus.New()
    55  	l.SetOutput(os.Stdout)
    56  	l.AddHook(testHook)
    57  	log := logger.Logrus{
    58  		FieldLogger: l,
    59  	}
    60  	app.Logger = log
    61  
    62  	w := httptest.New(app)
    63  	res := w.HTML("/").Get()
    64  	r.Equal(http.StatusUnauthorized, res.Code)
    65  	r.Equal(http.StatusUnauthorized, testHook.errors[0].Data["status"])
    66  }
    67  
    68  func Test_defaultErrorHandler_JSON(t *testing.T) {
    69  	r := require.New(t)
    70  	app := New(Options{})
    71  	app.GET("/", func(c Context) error {
    72  		return c.Error(http.StatusUnauthorized, fmt.Errorf("boom"))
    73  	})
    74  
    75  	w := httptest.New(app)
    76  	res := w.JSON("/").Get()
    77  	r.Equal(http.StatusUnauthorized, res.Code)
    78  	ct := res.Header().Get("content-type")
    79  	r.Equal("application/json", ct)
    80  	b := res.Body.String()
    81  	r.Contains(b, `"code":401`)
    82  	r.Contains(b, `"error":"boom"`)
    83  	r.Contains(b, `"trace":"`)
    84  }
    85  
    86  func Test_defaultErrorHandler_XML(t *testing.T) {
    87  	r := require.New(t)
    88  	app := New(Options{})
    89  	app.GET("/", func(c Context) error {
    90  		return c.Error(http.StatusUnauthorized, fmt.Errorf("boom"))
    91  	})
    92  
    93  	w := httptest.New(app)
    94  	res := w.XML("/").Get()
    95  	r.Equal(http.StatusUnauthorized, res.Code)
    96  	ct := res.Header().Get("content-type")
    97  	r.Equal("text/xml", ct)
    98  	b := res.Body.String()
    99  	r.Contains(b, `<response code="401">`)
   100  	r.Contains(b, `<error>boom</error>`)
   101  	r.Contains(b, `<trace>`)
   102  	r.Contains(b, `</trace>`)
   103  	r.Contains(b, `</response>`)
   104  }
   105  
   106  func Test_PanicHandler(t *testing.T) {
   107  	app := New(Options{})
   108  	app.GET("/string", func(c Context) error {
   109  		panic("string boom")
   110  	})
   111  	app.GET("/error", func(c Context) error {
   112  		panic(fmt.Errorf("error boom"))
   113  	})
   114  
   115  	table := []struct {
   116  		path     string
   117  		expected string
   118  	}{
   119  		{"/string", "string boom"},
   120  		{"/error", "error boom"},
   121  	}
   122  
   123  	const stack = `github.com/gobuffalo/buffalo.Test_PanicHandler`
   124  
   125  	w := httptest.New(app)
   126  	for _, tt := range table {
   127  		t.Run(tt.path, func(st *testing.T) {
   128  			r := require.New(st)
   129  
   130  			res := w.HTML(tt.path).Get()
   131  			r.Equal(http.StatusInternalServerError, res.Code)
   132  
   133  			body := res.Body.String()
   134  			r.Contains(body, tt.expected)
   135  			r.Contains(body, stack)
   136  		})
   137  	}
   138  }
   139  
   140  func Test_defaultErrorMiddleware(t *testing.T) {
   141  	r := require.New(t)
   142  	app := New(Options{})
   143  	var x string
   144  	var ok bool
   145  	app.ErrorHandlers[http.StatusUnprocessableEntity] = func(code int, err error, c Context) error {
   146  		x, ok = c.Value("T").(string)
   147  		c.Response().WriteHeader(code)
   148  		c.Response().Write([]byte(err.Error()))
   149  		return nil
   150  	}
   151  	app.Use(func(next Handler) Handler {
   152  		return func(c Context) error {
   153  			c.Set("T", "t")
   154  			return c.Error(http.StatusUnprocessableEntity, fmt.Errorf("boom"))
   155  		}
   156  	})
   157  	app.GET("/", func(c Context) error {
   158  		return nil
   159  	})
   160  
   161  	w := httptest.New(app)
   162  	res := w.HTML("/").Get()
   163  	r.Equal(http.StatusUnprocessableEntity, res.Code)
   164  	r.True(ok)
   165  	r.Equal("t", x)
   166  }
   167  
   168  func Test_SetErrorMiddleware(t *testing.T) {
   169  	r := require.New(t)
   170  	app := New(Options{})
   171  	app.ErrorHandlers.Default(func(code int, err error, c Context) error {
   172  		res := c.Response()
   173  		res.WriteHeader(http.StatusTeapot)
   174  		res.Write([]byte("i'm a teapot"))
   175  		return nil
   176  	})
   177  	app.GET("/", func(c Context) error {
   178  		return c.Error(http.StatusUnprocessableEntity, fmt.Errorf("boom"))
   179  	})
   180  
   181  	w := httptest.New(app)
   182  	res := w.HTML("/").Get()
   183  	r.Equal(http.StatusTeapot, res.Code)
   184  	r.Equal("i'm a teapot", res.Body.String())
   185  }