github.com/woremacx/kocha@v0.7.1-0.20150731103243-a5889322afc9/middleware.go (about)

     1  package kocha
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"net/http"
     7  	"strconv"
     8  	"time"
     9  
    10  	"github.com/woremacx/kocha/log"
    11  	"github.com/woremacx/kocha/util"
    12  	"github.com/ugorji/go/codec"
    13  )
    14  
    15  // Middleware is the interface that middleware.
    16  type Middleware interface {
    17  	Process(app *Application, c *Context, next func() error) error
    18  }
    19  
    20  // Validator is the interface to validate the middleware.
    21  type Validator interface {
    22  	// Validate validates the middleware.
    23  	// Validate will be called in boot-time of the application.
    24  	Validate() error
    25  }
    26  
    27  // PanicRecoverMiddleware is a middleware to recover a panic where occurred in request sequence.
    28  type PanicRecoverMiddleware struct{}
    29  
    30  func (m *PanicRecoverMiddleware) Process(app *Application, c *Context, next func() error) (err error) {
    31  	defer func() {
    32  		defer func() {
    33  			if perr := recover(); perr != nil {
    34  				app.logStackAndError(perr)
    35  				err = fmt.Errorf("%v", perr)
    36  			}
    37  		}()
    38  		if err != nil {
    39  			app.Logger.Error(err)
    40  			goto ERROR
    41  		} else if perr := recover(); perr != nil {
    42  			app.logStackAndError(perr)
    43  			goto ERROR
    44  		}
    45  		return
    46  	ERROR:
    47  		c.Response.reset()
    48  		if err = internalServerErrorController.GET(c); err != nil {
    49  			app.logStackAndError(err)
    50  		}
    51  	}()
    52  	return next()
    53  }
    54  
    55  // FormMiddleware is a middleware to parse a form data from query string and/or request body.
    56  type FormMiddleware struct{}
    57  
    58  // Process implements the Middleware interface.
    59  func (m *FormMiddleware) Process(app *Application, c *Context, next func() error) error {
    60  	c.Request.Body = http.MaxBytesReader(c.Response, c.Request.Body, app.Config.MaxClientBodySize)
    61  	if err := c.Request.ParseMultipartForm(app.Config.MaxClientBodySize); err != nil && err != http.ErrNotMultipart {
    62  		return err
    63  	}
    64  	c.Params = c.newParams()
    65  	return next()
    66  }
    67  
    68  // SessionMiddleware is a middleware to process a session.
    69  type SessionMiddleware struct {
    70  	// Name of cookie (key)
    71  	Name string
    72  
    73  	// Implementation of session store
    74  	Store SessionStore
    75  
    76  	// Expiration of session cookie, in seconds, from now. (not session expiration)
    77  	// 0 is for persistent.
    78  	CookieExpires time.Duration
    79  
    80  	// Expiration of session data, in seconds, from now. (not cookie expiration)
    81  	// 0 is for persistent.
    82  	SessionExpires time.Duration
    83  	HttpOnly       bool
    84  	ExpiresKey     string
    85  }
    86  
    87  func (m *SessionMiddleware) Process(app *Application, c *Context, next func() error) error {
    88  	if err := m.before(app, c); err != nil {
    89  		return err
    90  	}
    91  	if err := next(); err != nil {
    92  		return err
    93  	}
    94  	return m.after(app, c)
    95  }
    96  
    97  // Validate validates configuration of the session.
    98  func (m *SessionMiddleware) Validate() error {
    99  	if m == nil {
   100  		return fmt.Errorf("kocha: session: middleware is nil")
   101  	}
   102  	if m.Store == nil {
   103  		return fmt.Errorf("kocha: session: because Store is nil, session cannot be used")
   104  	}
   105  	if m.Name == "" {
   106  		return fmt.Errorf("kocha: session: Name must be specified")
   107  	}
   108  	if m.ExpiresKey == "" {
   109  		m.ExpiresKey = "_kocha._sess._expires"
   110  	}
   111  	if v, ok := m.Store.(Validator); ok {
   112  		return v.Validate()
   113  	}
   114  	return nil
   115  }
   116  
   117  func (m *SessionMiddleware) before(app *Application, c *Context) (err error) {
   118  	defer func() {
   119  		switch err.(type) {
   120  		case nil:
   121  			// do nothing.
   122  		case ErrSession:
   123  			app.Logger.Info(err)
   124  		default:
   125  			app.Logger.Error(err)
   126  		}
   127  		if c.Session == nil {
   128  			c.Session = make(Session)
   129  		}
   130  		err = nil
   131  	}()
   132  	cookie, err := c.Request.Cookie(m.Name)
   133  	if err != nil {
   134  		return NewErrSession("new session")
   135  	}
   136  	sess, err := m.Store.Load(cookie.Value)
   137  	if err != nil {
   138  		return err
   139  	}
   140  	expiresStr, ok := sess[m.ExpiresKey]
   141  	if !ok {
   142  		return fmt.Errorf("expires value not found")
   143  	}
   144  	expires, err := strconv.ParseInt(expiresStr, 10, 64)
   145  	if err != nil {
   146  		return err
   147  	}
   148  	if expires < util.Now().Unix() {
   149  		return NewErrSession("session has been expired")
   150  	}
   151  	c.Session = sess
   152  	return nil
   153  }
   154  
   155  func (m *SessionMiddleware) after(app *Application, c *Context) (err error) {
   156  	expires, _ := m.expiresFromDuration(m.SessionExpires)
   157  	c.Session[m.ExpiresKey] = strconv.FormatInt(expires.Unix(), 10)
   158  	cookie := m.newSessionCookie(app, c)
   159  	cookie.Value, err = m.Store.Save(c.Session)
   160  	if err != nil {
   161  		return err
   162  	}
   163  	c.Response.SetCookie(cookie)
   164  	return nil
   165  }
   166  
   167  func (m *SessionMiddleware) newSessionCookie(app *Application, c *Context) *http.Cookie {
   168  	expires, maxAge := m.expiresFromDuration(m.CookieExpires)
   169  	return &http.Cookie{
   170  		Name:     m.Name,
   171  		Value:    "",
   172  		Path:     "/",
   173  		Expires:  expires,
   174  		MaxAge:   maxAge,
   175  		Secure:   c.Request.IsSSL(),
   176  		HttpOnly: m.HttpOnly,
   177  	}
   178  }
   179  
   180  func (m *SessionMiddleware) expiresFromDuration(d time.Duration) (expires time.Time, maxAge int) {
   181  	switch d {
   182  	case -1:
   183  		// persistent
   184  		expires = util.Now().UTC().AddDate(20, 0, 0)
   185  	case 0:
   186  		expires = time.Time{}
   187  	default:
   188  		expires = util.Now().UTC().Add(d)
   189  		maxAge = int(d.Seconds())
   190  	}
   191  	return expires, maxAge
   192  }
   193  
   194  // Flash messages processing middleware.
   195  type FlashMiddleware struct{}
   196  
   197  func (m *FlashMiddleware) Process(app *Application, c *Context, next func() error) error {
   198  	if err := m.before(app, c); err != nil {
   199  		return err
   200  	}
   201  	if err := next(); err != nil {
   202  		return err
   203  	}
   204  	return m.after(app, c)
   205  }
   206  
   207  func (m *FlashMiddleware) before(app *Application, c *Context) error {
   208  	if c.Session == nil {
   209  		app.Logger.Error("kocha: FlashMiddleware hasn't been added after SessionMiddleware; it cannot be used")
   210  		return nil
   211  	}
   212  	c.Flash = Flash{}
   213  	if flash := c.Session["_flash"]; flash != "" {
   214  		if err := codec.NewDecoderBytes([]byte(flash), codecHandler).Decode(&c.Flash); err != nil {
   215  			// make a new Flash instance because there is a possibility that
   216  			// garbage data is set to c.Flash by in-place decoding of Decode().
   217  			c.Flash = Flash{}
   218  			return fmt.Errorf("kocha: flash: unexpected error in decode process: %v", err)
   219  		}
   220  	}
   221  	return nil
   222  }
   223  
   224  func (m *FlashMiddleware) after(app *Application, c *Context) error {
   225  	if c.Session == nil {
   226  		return nil
   227  	}
   228  	if c.Flash.deleteLoaded(); c.Flash.Len() == 0 {
   229  		delete(c.Session, "_flash")
   230  		return nil
   231  	}
   232  	buf := bufPool.Get().(*bytes.Buffer)
   233  	defer func() {
   234  		buf.Reset()
   235  		bufPool.Put(buf)
   236  	}()
   237  	if err := codec.NewEncoder(buf, codecHandler).Encode(c.Flash); err != nil {
   238  		return fmt.Errorf("kocha: flash: unexpected error in encode process: %v", err)
   239  	}
   240  	c.Session["_flash"] = buf.String()
   241  	return nil
   242  }
   243  
   244  // Request logging middleware.
   245  type RequestLoggingMiddleware struct{}
   246  
   247  func (m *RequestLoggingMiddleware) Process(app *Application, c *Context, next func() error) error {
   248  	defer func() {
   249  		app.Logger.With(log.Fields{
   250  			"method":   c.Request.Method,
   251  			"uri":      c.Request.RequestURI,
   252  			"protocol": c.Request.Proto,
   253  			"status":   c.Response.StatusCode,
   254  		}).Info()
   255  	}()
   256  	return next()
   257  }
   258  
   259  // DispatchMiddleware is a middleware to dispatch handler.
   260  // DispatchMiddleware should be set to last of middlewares because doesn't call other middlewares after DispatchMiddleware.
   261  type DispatchMiddleware struct{}
   262  
   263  // Process implements the Middleware interface.
   264  func (m *DispatchMiddleware) Process(app *Application, c *Context, next func() error) error {
   265  	name, handler, params, found := app.Router.dispatch(c.Request)
   266  	if !found {
   267  		handler = (&ErrorController{
   268  			StatusCode: http.StatusNotFound,
   269  		}).GET
   270  	}
   271  	c.Name = name
   272  	if c.Params == nil {
   273  		c.Params = c.newParams()
   274  	}
   275  	for _, param := range params {
   276  		c.Params.Add(param.Name, param.Value)
   277  	}
   278  	return handler(c)
   279  }