github.com/e154/smart-home@v0.17.2-0.20240311175135-e530a6e5cd45/api/controllers/common.go (about) 1 package controllers 2 3 import ( 4 "context" 5 "encoding/base64" 6 "encoding/json" 7 "io" 8 "net/http" 9 "strings" 10 11 "github.com/e154/smart-home/api/dto" 12 "github.com/e154/smart-home/system/access_list" 13 14 "github.com/iancoleman/strcase" 15 "github.com/labstack/echo/v4" 16 "github.com/pkg/errors" 17 18 "github.com/e154/smart-home/common" 19 "github.com/e154/smart-home/common/apperr" 20 "github.com/e154/smart-home/common/logger" 21 "github.com/e154/smart-home/endpoint" 22 m "github.com/e154/smart-home/models" 23 "github.com/e154/smart-home/system/validation" 24 ) 25 26 var ( 27 log = logger.MustGetLogger("controllers") 28 ) 29 30 // ControllerCommon ... 31 type ControllerCommon struct { 32 endpoint *endpoint.Endpoint 33 accessList access_list.AccessListService 34 validation *validation.Validate 35 dto dto.Dto 36 appConfig *m.AppConfig 37 } 38 39 // NewControllerCommon ... 40 func NewControllerCommon(endpoint *endpoint.Endpoint, 41 accessList access_list.AccessListService, 42 appConfig *m.AppConfig, 43 validation *validation.Validate) *ControllerCommon { 44 return &ControllerCommon{ 45 endpoint: endpoint, 46 appConfig: appConfig, 47 validation: validation, 48 accessList: accessList, 49 dto: dto.NewDto(), 50 } 51 } 52 53 func (c ControllerCommon) Body(ctx echo.Context, obj interface{}) error { 54 dec := json.NewDecoder(ctx.Request().Body) 55 if err := dec.Decode(obj); err != nil { 56 if strings.Contains(err.Error(), "unknown field") { 57 return apperr.ErrorWithCode("BAD_REQUEST", err.Error(), apperr.ErrUnknownField) 58 } 59 return apperr.ErrorWithCode("BAD_JSON_REQUEST", err.Error(), apperr.ErrBadJSONRequest) 60 } 61 return nil 62 } 63 64 // HTTP200 ... 65 func (c ControllerCommon) HTTP200(ctx echo.Context, data interface{}) error { 66 return ctx.JSON(http.StatusOK, data) 67 } 68 69 // HTTP201 ... 70 func (c ControllerCommon) HTTP201(ctx echo.Context, data interface{}) error { 71 return ctx.JSON(http.StatusCreated, data) 72 } 73 74 // HTTP401 ... 75 func (c ControllerCommon) HTTP401(ctx echo.Context, err error) error { 76 e := apperr.GetError(err) 77 if e != nil { 78 return ctx.JSON(http.StatusUnauthorized, ResponseWithError(ctx, &ErrorBase{ 79 Code: common.String(e.Code()), 80 Message: common.String(e.Message()), 81 })) 82 } 83 return ctx.JSON(http.StatusUnauthorized, ResponseWithError(ctx, &ErrorBase{ 84 Code: common.String("UNAUTHORIZED"), 85 })) 86 } 87 88 // HTTP403 ... 89 func (c ControllerCommon) HTTP403(ctx echo.Context, err error) error { 90 e := apperr.GetError(err) 91 if e != nil { 92 return ctx.JSON(http.StatusForbidden, ResponseWithError(ctx, &ErrorBase{ 93 Code: common.String(e.Code()), 94 Message: common.String(e.Message()), 95 })) 96 } 97 return ctx.JSON(http.StatusForbidden, ResponseWithError(ctx, &ErrorBase{ 98 Code: common.String("ACCESS_FORBIDDEN"), 99 })) 100 } 101 102 // HTTP404 ... 103 func (c ControllerCommon) HTTP404(ctx echo.Context, err error) error { 104 code := common.String("NOT_FOUND") 105 message := common.String(err.Error()) 106 e := apperr.GetError(err) 107 if e != nil { 108 code = common.String(e.Code()) 109 message = common.String(e.Message()) 110 } 111 return ctx.JSON(http.StatusNotFound, ResponseWithError(ctx, &ErrorBase{ 112 Code: code, 113 Message: message, 114 })) 115 } 116 117 // HTTP400 ... 118 func (c ControllerCommon) HTTP400(ctx echo.Context, err error) error { 119 code := common.String("BAD_REQUEST") 120 message := common.String(err.Error()) 121 e := apperr.GetError(err) 122 if e != nil { 123 code = common.String(e.Code()) 124 message = common.String(e.Message()) 125 } 126 return ctx.JSON(http.StatusBadRequest, ResponseWithError(ctx, &ErrorBase{ 127 Code: code, 128 Message: message, 129 })) 130 } 131 132 // HTTP409 ... 133 func (c ControllerCommon) HTTP409(ctx echo.Context, err error) error { 134 code := common.String("CONFLICT") 135 message := common.String(err.Error()) 136 e := apperr.GetError(err) 137 if e != nil { 138 code = common.String(e.Code()) 139 message = common.String(e.Message()) 140 } 141 return ctx.JSON(http.StatusConflict, ResponseWithError(ctx, &ErrorBase{ 142 Code: code, 143 Message: message, 144 })) 145 } 146 147 // HTTP500 ... 148 func (c ControllerCommon) HTTP500(ctx echo.Context, err error) error { 149 code := common.String("INTERNAL_ERROR") 150 message := common.String(err.Error()) 151 e := apperr.GetError(err) 152 if e != nil { 153 code = common.String(e.Code()) 154 message = common.String(e.Message()) 155 } 156 return ctx.JSON(http.StatusInternalServerError, ResponseWithError(ctx, &ErrorBase{ 157 Code: code, 158 Message: message, 159 })) 160 } 161 162 // HTTP422 ... 163 func (c ControllerCommon) HTTP422(ctx echo.Context, err error) error { 164 165 var fields []ErrorField 166 167 respErr := ErrorBase{ 168 Code: common.String("UNPROCESSABLE_ERROR"), 169 } 170 171 e := apperr.GetError(err) 172 if e != nil { 173 errs := e.ValidationErrors() 174 175 for fieldName, desc := range errs { 176 // update field name 177 fieldNameArr := strings.Split(fieldName, ".") 178 fieldName = fieldNameArr[len(fieldNameArr)-1] 179 180 fields = append(fields, ErrorField{ 181 Name: common.String(fieldName), 182 Message: common.String(desc), 183 }) 184 } 185 186 respErr.Code = common.String(e.Code()) 187 respErr.Message = common.String(e.Message()) 188 respErr.Fields = fields 189 } 190 191 return ctx.JSON(http.StatusUnprocessableEntity, ResponseWithError(ctx, &respErr)) 192 } 193 194 // HTTP501 ... 195 func (c ControllerCommon) HTTP501(ctx echo.Context, data interface{}) error { 196 return ctx.JSON(http.StatusNotImplemented, data) 197 } 198 199 // Pagination ... 200 func (c ControllerCommon) Pagination(page, limit *uint64, sort *string) (pagination common.PageParams) { 201 202 pagination = common.PageParams{ 203 Limit: 200, 204 Offset: 0, 205 Order: "desc", 206 SortBy: "created_at", 207 PageReq: 1, 208 SortReq: "-created_at", 209 } 210 211 if limit != nil { 212 pagination.Limit = int64(*limit) 213 } 214 if page != nil { 215 pagination.PageReq = int64(*page) 216 } 217 218 pagination.Offset = pagination.Limit * (pagination.PageReq - 1) 219 if pagination.Offset < 0 { 220 pagination.Offset = 0 221 } 222 223 if sort != nil && len(*sort) > 1 { 224 pagination.SortReq = *sort 225 firstChar := string([]rune(*sort)[0]) 226 switch firstChar { 227 case "+": 228 pagination.Order = "asc" 229 case "-": 230 pagination.Order = "desc" 231 } 232 233 // ToSnake converts a string to snake_case 234 pagination.SortBy = strcase.ToSnake(strings.Replace(*sort, firstChar, "", 1)) 235 } 236 237 return 238 } 239 240 // Search ... 241 func (c ControllerCommon) Search(query *string, limit, offset *int64) (search common.SearchParams) { 242 243 search = common.SearchParams{ 244 Query: common.StringValue(query), 245 Limit: 200, 246 Offset: 0, 247 } 248 249 if limit != nil { 250 search.Limit = common.Int64Value(limit) 251 } 252 if offset != nil { 253 search.Offset = common.Int64Value(offset) 254 } 255 256 return 257 } 258 259 // ERROR ... 260 func (c ControllerCommon) ERROR(ctx echo.Context, err error) error { 261 switch { 262 case errors.Is(err, apperr.ErrUnknownField): 263 return c.HTTP400(ctx, err) 264 case errors.Is(err, apperr.ErrBadJSONRequest): 265 return c.HTTP400(ctx, err) 266 case errors.Is(err, apperr.ErrAccessDenied): 267 return c.HTTP401(ctx, err) 268 case errors.Is(err, apperr.ErrAccessForbidden): 269 return c.HTTP403(ctx, err) 270 case errors.Is(err, apperr.ErrNotFound): 271 return c.HTTP404(ctx, err) 272 case errors.Is(err, apperr.ErrAlreadyExists): 273 return c.HTTP409(ctx, err) 274 case errors.Is(err, apperr.ErrInvalidRequest): 275 return c.HTTP422(ctx, err) 276 case errors.Is(err, apperr.ErrInternal): 277 return c.HTTP500(ctx, err) 278 default: 279 var bodyStr string 280 body, _ := io.ReadAll(ctx.Request().Body) 281 if len(body) > 0 { 282 bodyStr = string(body) 283 } 284 url := ctx.Request().URL.String() 285 log.Warnf("unknown err type %v for uri %s and body %q", err, url, bodyStr) 286 } 287 log.Error(err.Error()) 288 return nil 289 } 290 291 func (c ControllerCommon) currentUser(ctx echo.Context) (*m.User, error) { 292 293 user, ok := ctx.Get("currentUser").(*m.User) 294 if !ok { 295 return nil, errors.Wrap(apperr.ErrBadRequestParams, "bad user object") 296 } 297 298 return user, nil 299 } 300 301 func (c ControllerCommon) parseBasicAuth(auth string) (username, password string, ok bool) { 302 const prefix = "Basic " 303 // Case insensitive prefix match. See Issue 22736. 304 if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { 305 return 306 } 307 str, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) 308 if err != nil { 309 return 310 } 311 cs := string(str) 312 s := strings.IndexByte(cs, ':') 313 if s < 0 { 314 return 315 } 316 317 return cs[:s], cs[s+1:], true 318 } 319 320 type contextValue struct { 321 echo.Context 322 } 323 324 func NewMiddlewareContextValue(fn echo.HandlerFunc) echo.HandlerFunc { 325 return func(ctx echo.Context) error { 326 return fn(contextValue{ctx}) 327 } 328 } 329 330 // Get retrieves data from the context. 331 func (ctx contextValue) Get(key string) interface{} { 332 // get old context value 333 val := ctx.Context.Get(key) 334 if val != nil { 335 return val 336 } 337 return ctx.Request().Context().Value(key) 338 } 339 340 // Set saves data in the context. 341 func (ctx contextValue) Set(key string, val interface{}) { 342 ctx.SetRequest(ctx.Request().WithContext(context.WithValue(ctx.Request().Context(), key, val))) 343 }