github.com/seeker-insurance/kit@v0.0.13/web/api_context.go (about) 1 package web 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "net/http" 10 "reflect" 11 "regexp" 12 "strconv" 13 "strings" 14 15 "errors" 16 17 "github.com/labstack/echo" 18 "github.com/seeker-insurance/kit/flect" 19 "github.com/seeker-insurance/kit/jsonapi" 20 "github.com/seeker-insurance/kit/maputil" 21 "github.com/seeker-insurance/kit/web/pagination" 22 ) 23 24 var reNotJsonApi = regexp.MustCompile("not a jsonapi|EOF") 25 26 func notJsonApi(err error) bool { 27 return reNotJsonApi.MatchString(err.Error()) 28 } 29 30 type ( 31 ApiContext interface { 32 echo.Context 33 34 Payload() *jsonapi.OnePayload 35 Attrs(permitted ...string) map[string]interface{} 36 AttrKeys() []string 37 RequireAttrs(...string) error 38 BindAndValidate(interface{}) error 39 BindMulti(interface{}) ([]interface{}, error) 40 BindIdParam(*int, ...string) error 41 JsonApi(interface{}, int) error 42 JsonApiOK(interface{}, ...interface{}) error 43 JsonApiOKPaged(interface{}, *pagination.Pagination, ...interface{}) error 44 ApiError(string, ...int) *echo.HTTPError 45 JsonAPIError(string, int, string) *jsonapi.ErrorObject 46 QueryParamTrue(string) (bool, bool) 47 48 RequiredQueryParams(...string) (map[string]string, error) 49 OptionalQueryParams(...string) map[string]string 50 QParams(...string) (map[string]string, error) 51 } 52 53 apiContext struct { 54 echo.Context 55 56 payload *jsonapi.OnePayload 57 manyPayload *jsonapi.ManyPayload 58 } 59 60 CommonExtendable interface { 61 CommonExtend(interface{}) error 62 } 63 64 Extendable interface { 65 Extend(interface{}) error 66 } 67 68 CommonMetable interface { 69 CommonMeta() error 70 } 71 72 Metable interface { 73 Meta() error 74 } 75 ) 76 77 func (c *apiContext) Payload() *jsonapi.OnePayload { 78 return c.payload 79 } 80 81 func (c *apiContext) Attrs(permitted ...string) map[string]interface{} { 82 //TODO: remove this once all refactoring is complete 83 if len(permitted) == 0 { 84 return c.payload.Data.Attributes 85 } 86 87 permittedAttrs := make(map[string]interface{}) 88 for _, p := range permitted { 89 if val, ok := c.payload.Data.Attributes[p]; ok { 90 permittedAttrs[p] = val 91 } 92 } 93 return permittedAttrs 94 } 95 96 func (c *apiContext) AttrKeys() []string { 97 return maputil.Keys(c.Attrs()) 98 } 99 100 func (c *apiContext) RequireAttrs(required ...string) error { 101 missing := make([]string, 0, len(required)) 102 103 for _, key := range required { 104 if c.payload.Data.Attributes[key] == nil { 105 missing = append(missing, key) 106 continue 107 } 108 } 109 110 if len(missing) > 0 { 111 return fmt.Errorf("missing required attributes: %v", missing) 112 } 113 114 return nil 115 } 116 117 //Before binding we make a copy of the req body and restore it after binding. 118 //This allows the body to be used again later 119 func (c *apiContext) Bind(i interface{}) error { 120 body, err := c.readRestoreBody() 121 if err != nil { 122 return err 123 } 124 125 ctype := c.Request().Header.Get(echo.HeaderContentType) 126 if isJSONAPI(ctype) { 127 err = jsonAPIBind(c, i) 128 } else { 129 err = c.defaultBind(i) 130 } 131 132 c.restoreBody(body) 133 134 return err 135 } 136 137 func (c *apiContext) BindMulti(containedType interface{}) ([]interface{}, error) { 138 body, err := c.readRestoreBody() 139 if err != nil { 140 return nil, err 141 } 142 143 ctype := c.Request().Header.Get(echo.HeaderContentType) 144 145 if !isJSONAPI(ctype) { 146 return nil, errors.New("BindMulti only supports JSONApi, use Bind") 147 } 148 149 i, err := jsonAPIBindMulti(c, containedType) 150 151 c.restoreBody(body) 152 153 return i, err 154 } 155 156 func (c *apiContext) readRestoreBody() ([]byte, error) { 157 b, err := ioutil.ReadAll(c.Request().Body) 158 c.restoreBody(b) 159 return b, err 160 } 161 162 func (c *apiContext) restoreBody(b []byte) { 163 c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(b)) 164 } 165 166 func (c *apiContext) defaultBind(i interface{}) error { 167 db := new(echo.DefaultBinder) 168 return db.Bind(i, c) 169 } 170 171 func isJSONAPI(s string) bool { 172 const MIMEJsonAPI = "application/vnd.api+json" 173 return strings.HasPrefix(s, MIMEJsonAPI) 174 } 175 176 func (c *apiContext) BindAndValidate(i interface{}) error { 177 if err := c.Bind(i); err != nil { 178 return err 179 } 180 if err := c.Validate(i); err != nil { 181 return err 182 } 183 return nil 184 } 185 186 func (c *apiContext) JsonApiPaged(i interface{}, status int, page *pagination.Pagination) error { 187 var buf bytes.Buffer 188 if err := jsonapi.MarshalPayloadPaged(&buf, i, page); err != nil { 189 return err 190 } 191 192 // These methods have to be the last thing called, *after* any error checks. 193 // Once any of the Write methods are called, the response is "committed" and 194 // cannot be changed. This causes error responses with 200 statuses. 195 c.Response().Header().Set(echo.HeaderContentType, jsonapi.MediaType) 196 c.Response().WriteHeader(status) 197 c.Response().Write(buf.Bytes()) 198 return nil 199 } 200 201 func (c *apiContext) JsonApi(i interface{}, status int) error { 202 var buf bytes.Buffer 203 if err := jsonapi.MarshalPayload(&buf, i); err != nil { 204 return err 205 } 206 207 // These methods have to be the last thing called, *after* any error checks. 208 // Once any of the Write methods are called, the response is "committed" and 209 // cannot be changed. This causes error responses with 200 statuses. 210 c.Response().Header().Set(echo.HeaderContentType, jsonapi.MediaType) 211 c.Response().WriteHeader(status) 212 c.Response().Write(buf.Bytes()) 213 return nil 214 } 215 216 func applyCommon(i interface{}, page *pagination.Pagination, extendData interface{}) error { 217 if casted, ok := i.(CommonExtendable); ok { 218 if err := casted.CommonExtend(extendData); err != nil { 219 return err 220 } 221 } 222 223 if casted, ok := i.(CommonMetable); ok { 224 if err := casted.CommonMeta(); err != nil { 225 return err 226 } 227 } 228 return nil 229 } 230 231 func apply(i interface{}, page *pagination.Pagination, extendData interface{}) error { 232 if casted, ok := i.(Extendable); ok { 233 if err := casted.Extend(extendData); err != nil { 234 return err 235 } 236 } 237 238 if casted, ok := i.(Metable); ok { 239 if err := casted.Meta(); err != nil { 240 return err 241 } 242 } 243 return nil 244 } 245 246 func extendAndExtract(i interface{}, page *pagination.Pagination, extendData interface{}) (data interface{}, err error) { 247 if flect.IsSlice(i) { 248 slice := reflect.ValueOf(i) 249 for idx := 0; idx < slice.Len(); idx++ { 250 elementInterface := slice.Index(idx).Interface() 251 if err := applyCommon(elementInterface, page, extendData); err != nil { 252 return nil, err 253 } 254 } 255 return i, nil 256 } 257 258 if err := applyCommon(i, page, extendData); err != nil { 259 return nil, err 260 } 261 262 if err := apply(i, page, extendData); err != nil { 263 return nil, err 264 } 265 return i, nil 266 } 267 268 func (c *apiContext) JsonApiOK(i interface{}, extendData ...interface{}) error { 269 var ed interface{} 270 if len(extendData) > 0 { 271 ed = extendData[0] 272 } 273 data, err := extendAndExtract(i, nil, ed) 274 if err != nil { 275 return err 276 } 277 return c.JsonApi(data, http.StatusOK) 278 } 279 280 func (c *apiContext) JsonApiOKPaged(i interface{}, page *pagination.Pagination, extendData ...interface{}) error { 281 var ed interface{} 282 if len(extendData) > 0 { 283 ed = extendData[0] 284 } 285 data, err := extendAndExtract(i, page, ed) 286 if err != nil { 287 return err 288 } 289 page.Url = *c.Request().URL 290 return c.JsonApiPaged(data, http.StatusOK, page) 291 } 292 293 func (c *apiContext) BindIdParam(idValue *int, named ...string) (err error) { 294 paramName := "id" 295 if len(named) > 0 { 296 paramName = named[0] 297 } 298 *idValue, err = strconv.Atoi(c.Param(paramName)) 299 return err 300 } 301 302 func (c *apiContext) QueryParamTrue(name string) (val, ok bool) { 303 switch strings.ToLower(c.QueryParam(name)) { 304 case "true", "1": 305 return true, true 306 case "false", "0": 307 return false, true 308 default: 309 return false, false 310 } 311 } 312 313 func jsonAPIBindMulti(c *apiContext, elementType interface{}) ([]interface{}, error) { 314 buf := new(bytes.Buffer) 315 tee := io.TeeReader(c.Request().Body, buf) 316 317 unmarshaled, err := jsonapi.UnmarshalManyPayload(tee, reflect.TypeOf(elementType)) 318 if err != nil { 319 return nil, err 320 } 321 322 c.manyPayload = new(jsonapi.ManyPayload) 323 return unmarshaled, json.Unmarshal(buf.Bytes(), c.manyPayload) 324 } 325 326 func jsonAPIBind(c *apiContext, i interface{}) error { 327 buf := new(bytes.Buffer) 328 tee := io.TeeReader(c.Request().Body, buf) 329 330 rType := reflect.TypeOf(i) 331 332 if rType.Kind() == reflect.Slice { 333 value := reflect.TypeOf(rType.Elem()) 334 335 unmarshaled, err := jsonapi.UnmarshalManyPayload(tee, value) 336 if err != nil { 337 return err 338 } 339 i = unmarshaled 340 } else { 341 if err := jsonapi.UnmarshalPayload(tee, i); err != nil { 342 if notJsonApi(err) { 343 return c.ApiError("Request Body is not valid JsonAPI") 344 } 345 return err 346 } 347 } 348 349 c.payload = new(jsonapi.OnePayload) 350 return json.Unmarshal(buf.Bytes(), c.payload) 351 } 352 353 func (c *apiContext) ApiError(msg string, codes ...int) *echo.HTTPError { 354 if len(codes) > 0 { 355 return echo.NewHTTPError(codes[0], msg) 356 } 357 // TODO: return jsonapi error instead 358 return echo.NewHTTPError(http.StatusBadRequest, msg) 359 } 360 361 func (c *apiContext) JsonAPIError(msg string, code int, param string) *jsonapi.ErrorObject { 362 return &jsonapi.ErrorObject{ 363 Status: fmt.Sprintf("%d", code), 364 Title: http.StatusText(code), 365 Detail: msg, 366 Meta: &map[string]interface{}{ 367 "parameter": param, 368 }, 369 } 370 } 371 372 func (c *apiContext) RequiredQueryParams(required ...string) (map[string]string, error) { 373 missing := make([]string, 0, len(required)) 374 params := make(map[string]string) 375 376 for _, key := range required { 377 val := c.QueryParam(key) 378 if val == "" { 379 missing = append(missing, key) 380 continue 381 } 382 params[key] = val 383 } 384 385 if len(missing) > 0 { 386 return nil, fmt.Errorf("missing required params: %v", missing) 387 } 388 389 return params, nil 390 } 391 392 func (c *apiContext) QParams(required ...string) (map[string]string, error) { 393 return QParams(c, required...) 394 } 395 396 func QParams(c echo.Context, required ...string) (map[string]string, error) { 397 missing := make([]string, 0, len(required)) 398 params := make(map[string]string) 399 400 for k := range c.QueryParams() { 401 params[k] = c.QueryParam(k) 402 } 403 404 for _, k := range required { 405 if _, ok := params[k]; !ok { 406 missing = append(missing, k) 407 } 408 } 409 410 if len(missing) > 0 { 411 return nil, fmt.Errorf("missing required params: %v", missing) 412 } 413 414 return params, nil 415 } 416 417 func (c *apiContext) OptionalQueryParams(optional ...string) map[string]string { 418 params := make(map[string]string) 419 for _, key := range optional { 420 val := c.QueryParam(key) 421 params[key] = val 422 } 423 return params 424 } 425 426 func ApiContextMiddleWare() func(echo.HandlerFunc) echo.HandlerFunc { 427 return func(next echo.HandlerFunc) echo.HandlerFunc { 428 return func(c echo.Context) error { 429 return next(&apiContext{c, nil, nil}) 430 } 431 } 432 } 433 434 func restrictedValue(value string, allowed []string, errorText string) (string, error) { 435 if contains(allowed, value) { 436 return value, nil 437 } 438 return "", fmt.Errorf(errorText, value) 439 } 440 441 func contains(set []string, s string) bool { 442 for _, v := range set { 443 if s == v { 444 return true 445 } 446 } 447 return false 448 }