github.com/segakazzz/buffalo@v0.16.22-0.20210119082501-1f52048d3feb/default_context.go (about)

     1  package buffalo
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"reflect"
    12  	"sort"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/gobuffalo/buffalo/binding"
    18  	"github.com/gobuffalo/buffalo/render"
    19  	"github.com/segakazzz/buffalo/internal/takeon/github.com/markbates/errx"
    20  )
    21  
    22  // assert that DefaultContext is implementing Context
    23  var _ Context = &DefaultContext{}
    24  var _ context.Context = &DefaultContext{}
    25  
    26  // DefaultContext is, as its name implies, a default
    27  // implementation of the Context interface.
    28  type DefaultContext struct {
    29  	context.Context
    30  	response    http.ResponseWriter
    31  	request     *http.Request
    32  	params      url.Values
    33  	logger      Logger
    34  	session     *Session
    35  	contentType string
    36  	data        *sync.Map
    37  	flash       *Flash
    38  }
    39  
    40  // Response returns the original Response for the request.
    41  func (d *DefaultContext) Response() http.ResponseWriter {
    42  	return d.response
    43  }
    44  
    45  // Request returns the original Request.
    46  func (d *DefaultContext) Request() *http.Request {
    47  	return d.request
    48  }
    49  
    50  // Params returns all of the parameters for the request,
    51  // including both named params and query string parameters.
    52  func (d *DefaultContext) Params() ParamValues {
    53  	return d.params
    54  }
    55  
    56  // Logger returns the Logger for this context.
    57  func (d *DefaultContext) Logger() Logger {
    58  	return d.logger
    59  }
    60  
    61  // Param returns a param, either named or query string,
    62  // based on the key.
    63  func (d *DefaultContext) Param(key string) string {
    64  	return d.Params().Get(key)
    65  }
    66  
    67  // Set a value onto the Context. Any value set onto the Context
    68  // will be automatically available in templates.
    69  func (d *DefaultContext) Set(key string, value interface{}) {
    70  	d.data.Store(key, value)
    71  }
    72  
    73  // Value that has previously stored on the context.
    74  func (d *DefaultContext) Value(key interface{}) interface{} {
    75  	if k, ok := key.(string); ok {
    76  		if v, ok := d.data.Load(k); ok {
    77  			return v
    78  		}
    79  	}
    80  	return d.Context.Value(key)
    81  }
    82  
    83  // Session for the associated Request.
    84  func (d *DefaultContext) Session() *Session {
    85  	return d.session
    86  }
    87  
    88  // Cookies for the associated request and response.
    89  func (d *DefaultContext) Cookies() *Cookies {
    90  	return &Cookies{d.request, d.response}
    91  }
    92  
    93  // Flash messages for the associated Request.
    94  func (d *DefaultContext) Flash() *Flash {
    95  	return d.flash
    96  }
    97  
    98  type paginable interface {
    99  	Paginate() string
   100  }
   101  
   102  // Render a status code and render.Renderer to the associated Response.
   103  // The request parameters will be made available to the render.Renderer
   104  // "{{.params}}". Any values set onto the Context will also automatically
   105  // be made available to the render.Renderer. To render "no content" pass
   106  // in a nil render.Renderer.
   107  func (d *DefaultContext) Render(status int, rr render.Renderer) error {
   108  	start := time.Now()
   109  	defer func() {
   110  		d.LogField("render", time.Since(start))
   111  	}()
   112  	if rr != nil {
   113  		data := d.Data()
   114  		pp := map[string]string{}
   115  		for k, v := range d.params {
   116  			pp[k] = v[0]
   117  		}
   118  		data["params"] = pp
   119  		data["flash"] = d.Flash().data
   120  		data["session"] = d.Session()
   121  		data["request"] = d.Request()
   122  		data["status"] = status
   123  		bb := &bytes.Buffer{}
   124  
   125  		err := rr.Render(bb, data)
   126  		if err != nil {
   127  			if er, ok := errx.Unwrap(err).(render.ErrRedirect); ok {
   128  				return d.Redirect(er.Status, er.URL)
   129  			}
   130  			return HTTPError{Status: http.StatusInternalServerError, Cause: err}
   131  		}
   132  
   133  		if d.Session() != nil {
   134  			d.Flash().Clear()
   135  			d.Flash().persist(d.Session())
   136  		}
   137  
   138  		d.Response().Header().Set("Content-Type", rr.ContentType())
   139  		if p, ok := data["pagination"].(paginable); ok {
   140  			d.Response().Header().Set("X-Pagination", p.Paginate())
   141  		}
   142  		d.Response().WriteHeader(status)
   143  		_, err = io.Copy(d.Response(), bb)
   144  		if err != nil {
   145  			return HTTPError{Status: http.StatusInternalServerError, Cause: err}
   146  		}
   147  
   148  		return nil
   149  	}
   150  	d.Response().WriteHeader(status)
   151  	return nil
   152  }
   153  
   154  // Bind the interface to the request.Body. The type of binding
   155  // is dependent on the "Content-Type" for the request. If the type
   156  // is "application/json" it will use "json.NewDecoder". If the type
   157  // is "application/xml" it will use "xml.NewDecoder". See the
   158  // github.com/gobuffalo/buffalo/binding package for more details.
   159  func (d *DefaultContext) Bind(value interface{}) error {
   160  	return binding.Exec(d.Request(), value)
   161  }
   162  
   163  // LogField adds the key/value pair onto the Logger to be printed out
   164  // as part of the request logging. This allows you to easily add things
   165  // like metrics (think DB times) to your request.
   166  func (d *DefaultContext) LogField(key string, value interface{}) {
   167  	d.logger = d.logger.WithField(key, value)
   168  }
   169  
   170  // LogFields adds the key/value pairs onto the Logger to be printed out
   171  // as part of the request logging. This allows you to easily add things
   172  // like metrics (think DB times) to your request.
   173  func (d *DefaultContext) LogFields(values map[string]interface{}) {
   174  	d.logger = d.logger.WithFields(values)
   175  }
   176  
   177  func (d *DefaultContext) Error(status int, err error) error {
   178  	return HTTPError{Status: status, Cause: err}
   179  }
   180  
   181  var mapType = reflect.ValueOf(map[string]interface{}{}).Type()
   182  
   183  // Redirect a request with the given status to the given URL.
   184  func (d *DefaultContext) Redirect(status int, url string, args ...interface{}) error {
   185  	d.Flash().persist(d.Session())
   186  
   187  	if strings.HasSuffix(url, "Path()") {
   188  		if len(args) > 1 {
   189  			return fmt.Errorf("you must pass only a map[string]interface{} to a route path: %T", args)
   190  		}
   191  		var m map[string]interface{}
   192  		if len(args) == 1 {
   193  			rv := reflect.Indirect(reflect.ValueOf(args[0]))
   194  			if !rv.Type().ConvertibleTo(mapType) {
   195  				return fmt.Errorf("you must pass only a map[string]interface{} to a route path: %T", args)
   196  			}
   197  			m = rv.Convert(mapType).Interface().(map[string]interface{})
   198  		}
   199  		h, ok := d.Value(strings.TrimSuffix(url, "()")).(RouteHelperFunc)
   200  		if !ok {
   201  			return fmt.Errorf("could not find a route helper named %s", url)
   202  		}
   203  		url, err := h(m)
   204  		if err != nil {
   205  			return err
   206  		}
   207  		http.Redirect(d.Response(), d.Request(), string(url), status)
   208  		return nil
   209  	}
   210  
   211  	if len(args) > 0 {
   212  		url = fmt.Sprintf(url, args...)
   213  	}
   214  	http.Redirect(d.Response(), d.Request(), url, status)
   215  	return nil
   216  }
   217  
   218  // Data contains all the values set through Get/Set.
   219  func (d *DefaultContext) Data() map[string]interface{} {
   220  	m := map[string]interface{}{}
   221  	d.data.Range(func(k, v interface{}) bool {
   222  		s, ok := k.(string)
   223  		if !ok {
   224  			return false
   225  		}
   226  		m[s] = v
   227  		return true
   228  	})
   229  	return m
   230  }
   231  
   232  func (d *DefaultContext) String() string {
   233  	data := d.Data()
   234  	bb := make([]string, 0, len(data))
   235  
   236  	for k, v := range data {
   237  		if _, ok := v.(RouteHelperFunc); !ok {
   238  			bb = append(bb, fmt.Sprintf("%s: %s", k, v))
   239  		}
   240  	}
   241  	sort.Strings(bb)
   242  	return strings.Join(bb, "\n\n")
   243  }
   244  
   245  // File returns an uploaded file by name, or an error
   246  func (d *DefaultContext) File(name string) (binding.File, error) {
   247  	req := d.Request()
   248  	if err := req.ParseMultipartForm(5 * 1024 * 1024); err != nil {
   249  		return binding.File{}, err
   250  	}
   251  	f, h, err := req.FormFile(name)
   252  	bf := binding.File{
   253  		File:       f,
   254  		FileHeader: h,
   255  	}
   256  	if err != nil {
   257  		return bf, err
   258  	}
   259  	return bf, nil
   260  }
   261  
   262  // MarshalJSON implements json marshaling for the context
   263  func (d *DefaultContext) MarshalJSON() ([]byte, error) {
   264  	m := map[string]interface{}{}
   265  	data := d.Data()
   266  	for k, v := range data {
   267  		// don't try and marshal ourself
   268  		if _, ok := v.(*DefaultContext); ok {
   269  			continue
   270  		}
   271  		if _, err := json.Marshal(v); err == nil {
   272  			// it can be marshaled, so add it:
   273  			m[k] = v
   274  		}
   275  	}
   276  	return json.Marshal(m)
   277  }