github.com/seeker-insurance/kit@v0.0.13/web/api_context.go (about)

     1  package web
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"reflect"
    11  	"regexp"
    12  	"strconv"
    13  	"strings"
    14  
    15  	"errors"
    16  
    17  	"github.com/labstack/echo"
    18  	"github.com/seeker-insurance/kit/flect"
    19  	"github.com/seeker-insurance/kit/jsonapi"
    20  	"github.com/seeker-insurance/kit/maputil"
    21  	"github.com/seeker-insurance/kit/web/pagination"
    22  )
    23  
    24  var reNotJsonApi = regexp.MustCompile("not a jsonapi|EOF")
    25  
    26  func notJsonApi(err error) bool {
    27  	return reNotJsonApi.MatchString(err.Error())
    28  }
    29  
    30  type (
    31  	ApiContext interface {
    32  		echo.Context
    33  
    34  		Payload() *jsonapi.OnePayload
    35  		Attrs(permitted ...string) map[string]interface{}
    36  		AttrKeys() []string
    37  		RequireAttrs(...string) error
    38  		BindAndValidate(interface{}) error
    39  		BindMulti(interface{}) ([]interface{}, error)
    40  		BindIdParam(*int, ...string) error
    41  		JsonApi(interface{}, int) error
    42  		JsonApiOK(interface{}, ...interface{}) error
    43  		JsonApiOKPaged(interface{}, *pagination.Pagination, ...interface{}) error
    44  		ApiError(string, ...int) *echo.HTTPError
    45  		JsonAPIError(string, int, string) *jsonapi.ErrorObject
    46  		QueryParamTrue(string) (bool, bool)
    47  
    48  		RequiredQueryParams(...string) (map[string]string, error)
    49  		OptionalQueryParams(...string) map[string]string
    50  		QParams(...string) (map[string]string, error)
    51  	}
    52  
    53  	apiContext struct {
    54  		echo.Context
    55  
    56  		payload     *jsonapi.OnePayload
    57  		manyPayload *jsonapi.ManyPayload
    58  	}
    59  
    60  	CommonExtendable interface {
    61  		CommonExtend(interface{}) error
    62  	}
    63  
    64  	Extendable interface {
    65  		Extend(interface{}) error
    66  	}
    67  
    68  	CommonMetable interface {
    69  		CommonMeta() error
    70  	}
    71  
    72  	Metable interface {
    73  		Meta() error
    74  	}
    75  )
    76  
    77  func (c *apiContext) Payload() *jsonapi.OnePayload {
    78  	return c.payload
    79  }
    80  
    81  func (c *apiContext) Attrs(permitted ...string) map[string]interface{} {
    82  	//TODO: remove this once all refactoring is complete
    83  	if len(permitted) == 0 {
    84  		return c.payload.Data.Attributes
    85  	}
    86  
    87  	permittedAttrs := make(map[string]interface{})
    88  	for _, p := range permitted {
    89  		if val, ok := c.payload.Data.Attributes[p]; ok {
    90  			permittedAttrs[p] = val
    91  		}
    92  	}
    93  	return permittedAttrs
    94  }
    95  
    96  func (c *apiContext) AttrKeys() []string {
    97  	return maputil.Keys(c.Attrs())
    98  }
    99  
   100  func (c *apiContext) RequireAttrs(required ...string) error {
   101  	missing := make([]string, 0, len(required))
   102  
   103  	for _, key := range required {
   104  		if c.payload.Data.Attributes[key] == nil {
   105  			missing = append(missing, key)
   106  			continue
   107  		}
   108  	}
   109  
   110  	if len(missing) > 0 {
   111  		return fmt.Errorf("missing required attributes: %v", missing)
   112  	}
   113  
   114  	return nil
   115  }
   116  
   117  //Before binding we make a copy of the req body and restore it after binding.
   118  //This allows the body to be used again later
   119  func (c *apiContext) Bind(i interface{}) error {
   120  	body, err := c.readRestoreBody()
   121  	if err != nil {
   122  		return err
   123  	}
   124  
   125  	ctype := c.Request().Header.Get(echo.HeaderContentType)
   126  	if isJSONAPI(ctype) {
   127  		err = jsonAPIBind(c, i)
   128  	} else {
   129  		err = c.defaultBind(i)
   130  	}
   131  
   132  	c.restoreBody(body)
   133  
   134  	return err
   135  }
   136  
   137  func (c *apiContext) BindMulti(containedType interface{}) ([]interface{}, error) {
   138  	body, err := c.readRestoreBody()
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	ctype := c.Request().Header.Get(echo.HeaderContentType)
   144  
   145  	if !isJSONAPI(ctype) {
   146  		return nil, errors.New("BindMulti only supports JSONApi, use Bind")
   147  	}
   148  
   149  	i, err := jsonAPIBindMulti(c, containedType)
   150  
   151  	c.restoreBody(body)
   152  
   153  	return i, err
   154  }
   155  
   156  func (c *apiContext) readRestoreBody() ([]byte, error) {
   157  	b, err := ioutil.ReadAll(c.Request().Body)
   158  	c.restoreBody(b)
   159  	return b, err
   160  }
   161  
   162  func (c *apiContext) restoreBody(b []byte) {
   163  	c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(b))
   164  }
   165  
   166  func (c *apiContext) defaultBind(i interface{}) error {
   167  	db := new(echo.DefaultBinder)
   168  	return db.Bind(i, c)
   169  }
   170  
   171  func isJSONAPI(s string) bool {
   172  	const MIMEJsonAPI = "application/vnd.api+json"
   173  	return strings.HasPrefix(s, MIMEJsonAPI)
   174  }
   175  
   176  func (c *apiContext) BindAndValidate(i interface{}) error {
   177  	if err := c.Bind(i); err != nil {
   178  		return err
   179  	}
   180  	if err := c.Validate(i); err != nil {
   181  		return err
   182  	}
   183  	return nil
   184  }
   185  
   186  func (c *apiContext) JsonApiPaged(i interface{}, status int, page *pagination.Pagination) error {
   187  	var buf bytes.Buffer
   188  	if err := jsonapi.MarshalPayloadPaged(&buf, i, page); err != nil {
   189  		return err
   190  	}
   191  
   192  	// These methods have to be the last thing called, *after* any error checks.
   193  	// Once any of the Write methods are called, the response is "committed" and
   194  	// cannot be changed. This causes error responses with 200 statuses.
   195  	c.Response().Header().Set(echo.HeaderContentType, jsonapi.MediaType)
   196  	c.Response().WriteHeader(status)
   197  	c.Response().Write(buf.Bytes())
   198  	return nil
   199  }
   200  
   201  func (c *apiContext) JsonApi(i interface{}, status int) error {
   202  	var buf bytes.Buffer
   203  	if err := jsonapi.MarshalPayload(&buf, i); err != nil {
   204  		return err
   205  	}
   206  
   207  	// These methods have to be the last thing called, *after* any error checks.
   208  	// Once any of the Write methods are called, the response is "committed" and
   209  	// cannot be changed. This causes error responses with 200 statuses.
   210  	c.Response().Header().Set(echo.HeaderContentType, jsonapi.MediaType)
   211  	c.Response().WriteHeader(status)
   212  	c.Response().Write(buf.Bytes())
   213  	return nil
   214  }
   215  
   216  func applyCommon(i interface{}, page *pagination.Pagination, extendData interface{}) error {
   217  	if casted, ok := i.(CommonExtendable); ok {
   218  		if err := casted.CommonExtend(extendData); err != nil {
   219  			return err
   220  		}
   221  	}
   222  
   223  	if casted, ok := i.(CommonMetable); ok {
   224  		if err := casted.CommonMeta(); err != nil {
   225  			return err
   226  		}
   227  	}
   228  	return nil
   229  }
   230  
   231  func apply(i interface{}, page *pagination.Pagination, extendData interface{}) error {
   232  	if casted, ok := i.(Extendable); ok {
   233  		if err := casted.Extend(extendData); err != nil {
   234  			return err
   235  		}
   236  	}
   237  
   238  	if casted, ok := i.(Metable); ok {
   239  		if err := casted.Meta(); err != nil {
   240  			return err
   241  		}
   242  	}
   243  	return nil
   244  }
   245  
   246  func extendAndExtract(i interface{}, page *pagination.Pagination, extendData interface{}) (data interface{}, err error) {
   247  	if flect.IsSlice(i) {
   248  		slice := reflect.ValueOf(i)
   249  		for idx := 0; idx < slice.Len(); idx++ {
   250  			elementInterface := slice.Index(idx).Interface()
   251  			if err := applyCommon(elementInterface, page, extendData); err != nil {
   252  				return nil, err
   253  			}
   254  		}
   255  		return i, nil
   256  	}
   257  
   258  	if err := applyCommon(i, page, extendData); err != nil {
   259  		return nil, err
   260  	}
   261  
   262  	if err := apply(i, page, extendData); err != nil {
   263  		return nil, err
   264  	}
   265  	return i, nil
   266  }
   267  
   268  func (c *apiContext) JsonApiOK(i interface{}, extendData ...interface{}) error {
   269  	var ed interface{}
   270  	if len(extendData) > 0 {
   271  		ed = extendData[0]
   272  	}
   273  	data, err := extendAndExtract(i, nil, ed)
   274  	if err != nil {
   275  		return err
   276  	}
   277  	return c.JsonApi(data, http.StatusOK)
   278  }
   279  
   280  func (c *apiContext) JsonApiOKPaged(i interface{}, page *pagination.Pagination, extendData ...interface{}) error {
   281  	var ed interface{}
   282  	if len(extendData) > 0 {
   283  		ed = extendData[0]
   284  	}
   285  	data, err := extendAndExtract(i, page, ed)
   286  	if err != nil {
   287  		return err
   288  	}
   289  	page.Url = *c.Request().URL
   290  	return c.JsonApiPaged(data, http.StatusOK, page)
   291  }
   292  
   293  func (c *apiContext) BindIdParam(idValue *int, named ...string) (err error) {
   294  	paramName := "id"
   295  	if len(named) > 0 {
   296  		paramName = named[0]
   297  	}
   298  	*idValue, err = strconv.Atoi(c.Param(paramName))
   299  	return err
   300  }
   301  
   302  func (c *apiContext) QueryParamTrue(name string) (val, ok bool) {
   303  	switch strings.ToLower(c.QueryParam(name)) {
   304  	case "true", "1":
   305  		return true, true
   306  	case "false", "0":
   307  		return false, true
   308  	default:
   309  		return false, false
   310  	}
   311  }
   312  
   313  func jsonAPIBindMulti(c *apiContext, elementType interface{}) ([]interface{}, error) {
   314  	buf := new(bytes.Buffer)
   315  	tee := io.TeeReader(c.Request().Body, buf)
   316  
   317  	unmarshaled, err := jsonapi.UnmarshalManyPayload(tee, reflect.TypeOf(elementType))
   318  	if err != nil {
   319  		return nil, err
   320  	}
   321  
   322  	c.manyPayload = new(jsonapi.ManyPayload)
   323  	return unmarshaled, json.Unmarshal(buf.Bytes(), c.manyPayload)
   324  }
   325  
   326  func jsonAPIBind(c *apiContext, i interface{}) error {
   327  	buf := new(bytes.Buffer)
   328  	tee := io.TeeReader(c.Request().Body, buf)
   329  
   330  	rType := reflect.TypeOf(i)
   331  
   332  	if rType.Kind() == reflect.Slice {
   333  		value := reflect.TypeOf(rType.Elem())
   334  
   335  		unmarshaled, err := jsonapi.UnmarshalManyPayload(tee, value)
   336  		if err != nil {
   337  			return err
   338  		}
   339  		i = unmarshaled
   340  	} else {
   341  		if err := jsonapi.UnmarshalPayload(tee, i); err != nil {
   342  			if notJsonApi(err) {
   343  				return c.ApiError("Request Body is not valid JsonAPI")
   344  			}
   345  			return err
   346  		}
   347  	}
   348  
   349  	c.payload = new(jsonapi.OnePayload)
   350  	return json.Unmarshal(buf.Bytes(), c.payload)
   351  }
   352  
   353  func (c *apiContext) ApiError(msg string, codes ...int) *echo.HTTPError {
   354  	if len(codes) > 0 {
   355  		return echo.NewHTTPError(codes[0], msg)
   356  	}
   357  	// TODO: return jsonapi error instead
   358  	return echo.NewHTTPError(http.StatusBadRequest, msg)
   359  }
   360  
   361  func (c *apiContext) JsonAPIError(msg string, code int, param string) *jsonapi.ErrorObject {
   362  	return &jsonapi.ErrorObject{
   363  		Status: fmt.Sprintf("%d", code),
   364  		Title:  http.StatusText(code),
   365  		Detail: msg,
   366  		Meta: &map[string]interface{}{
   367  			"parameter": param,
   368  		},
   369  	}
   370  }
   371  
   372  func (c *apiContext) RequiredQueryParams(required ...string) (map[string]string, error) {
   373  	missing := make([]string, 0, len(required))
   374  	params := make(map[string]string)
   375  
   376  	for _, key := range required {
   377  		val := c.QueryParam(key)
   378  		if val == "" {
   379  			missing = append(missing, key)
   380  			continue
   381  		}
   382  		params[key] = val
   383  	}
   384  
   385  	if len(missing) > 0 {
   386  		return nil, fmt.Errorf("missing required params: %v", missing)
   387  	}
   388  
   389  	return params, nil
   390  }
   391  
   392  func (c *apiContext) QParams(required ...string) (map[string]string, error) {
   393  	return QParams(c, required...)
   394  }
   395  
   396  func QParams(c echo.Context, required ...string) (map[string]string, error) {
   397  	missing := make([]string, 0, len(required))
   398  	params := make(map[string]string)
   399  
   400  	for k := range c.QueryParams() {
   401  		params[k] = c.QueryParam(k)
   402  	}
   403  
   404  	for _, k := range required {
   405  		if _, ok := params[k]; !ok {
   406  			missing = append(missing, k)
   407  		}
   408  	}
   409  
   410  	if len(missing) > 0 {
   411  		return nil, fmt.Errorf("missing required params: %v", missing)
   412  	}
   413  
   414  	return params, nil
   415  }
   416  
   417  func (c *apiContext) OptionalQueryParams(optional ...string) map[string]string {
   418  	params := make(map[string]string)
   419  	for _, key := range optional {
   420  		val := c.QueryParam(key)
   421  		params[key] = val
   422  	}
   423  	return params
   424  }
   425  
   426  func ApiContextMiddleWare() func(echo.HandlerFunc) echo.HandlerFunc {
   427  	return func(next echo.HandlerFunc) echo.HandlerFunc {
   428  		return func(c echo.Context) error {
   429  			return next(&apiContext{c, nil, nil})
   430  		}
   431  	}
   432  }
   433  
   434  func restrictedValue(value string, allowed []string, errorText string) (string, error) {
   435  	if contains(allowed, value) {
   436  		return value, nil
   437  	}
   438  	return "", fmt.Errorf(errorText, value)
   439  }
   440  
   441  func contains(set []string, s string) bool {
   442  	for _, v := range set {
   443  		if s == v {
   444  			return true
   445  		}
   446  	}
   447  	return false
   448  }