github.com/e154/smart-home@v0.17.2-0.20240311175135-e530a6e5cd45/api/controllers/common.go (about)

     1  package controllers
     2  
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"io"
	"net/http"
	"sort"
	"strings"

	"github.com/iancoleman/strcase"
	"github.com/labstack/echo/v4"
	"github.com/pkg/errors"

	"github.com/e154/smart-home/api/dto"
	"github.com/e154/smart-home/common"
	"github.com/e154/smart-home/common/apperr"
	"github.com/e154/smart-home/common/logger"
	"github.com/e154/smart-home/endpoint"
	m "github.com/e154/smart-home/models"
	"github.com/e154/smart-home/system/access_list"
	"github.com/e154/smart-home/system/validation"
)
    25  
    26  var (
    27  	log = logger.MustGetLogger("controllers")
    28  )
    29  
    30  // ControllerCommon ...
    31  type ControllerCommon struct {
    32  	endpoint   *endpoint.Endpoint
    33  	accessList access_list.AccessListService
    34  	validation *validation.Validate
    35  	dto        dto.Dto
    36  	appConfig  *m.AppConfig
    37  }
    38  
    39  // NewControllerCommon ...
    40  func NewControllerCommon(endpoint *endpoint.Endpoint,
    41  	accessList access_list.AccessListService,
    42  	appConfig *m.AppConfig,
    43  	validation *validation.Validate) *ControllerCommon {
    44  	return &ControllerCommon{
    45  		endpoint:   endpoint,
    46  		appConfig:  appConfig,
    47  		validation: validation,
    48  		accessList: accessList,
    49  		dto:        dto.NewDto(),
    50  	}
    51  }
    52  
    53  func (c ControllerCommon) Body(ctx echo.Context, obj interface{}) error {
    54  	dec := json.NewDecoder(ctx.Request().Body)
    55  	if err := dec.Decode(obj); err != nil {
    56  		if strings.Contains(err.Error(), "unknown field") {
    57  			return apperr.ErrorWithCode("BAD_REQUEST", err.Error(), apperr.ErrUnknownField)
    58  		}
    59  		return apperr.ErrorWithCode("BAD_JSON_REQUEST", err.Error(), apperr.ErrBadJSONRequest)
    60  	}
    61  	return nil
    62  }
    63  
    64  // HTTP200 ...
    65  func (c ControllerCommon) HTTP200(ctx echo.Context, data interface{}) error {
    66  	return ctx.JSON(http.StatusOK, data)
    67  }
    68  
    69  // HTTP201 ...
    70  func (c ControllerCommon) HTTP201(ctx echo.Context, data interface{}) error {
    71  	return ctx.JSON(http.StatusCreated, data)
    72  }
    73  
    74  // HTTP401 ...
    75  func (c ControllerCommon) HTTP401(ctx echo.Context, err error) error {
    76  	e := apperr.GetError(err)
    77  	if e != nil {
    78  		return ctx.JSON(http.StatusUnauthorized, ResponseWithError(ctx, &ErrorBase{
    79  			Code:    common.String(e.Code()),
    80  			Message: common.String(e.Message()),
    81  		}))
    82  	}
    83  	return ctx.JSON(http.StatusUnauthorized, ResponseWithError(ctx, &ErrorBase{
    84  		Code: common.String("UNAUTHORIZED"),
    85  	}))
    86  }
    87  
    88  // HTTP403 ...
    89  func (c ControllerCommon) HTTP403(ctx echo.Context, err error) error {
    90  	e := apperr.GetError(err)
    91  	if e != nil {
    92  		return ctx.JSON(http.StatusForbidden, ResponseWithError(ctx, &ErrorBase{
    93  			Code:    common.String(e.Code()),
    94  			Message: common.String(e.Message()),
    95  		}))
    96  	}
    97  	return ctx.JSON(http.StatusForbidden, ResponseWithError(ctx, &ErrorBase{
    98  		Code: common.String("ACCESS_FORBIDDEN"),
    99  	}))
   100  }
   101  
   102  // HTTP404 ...
   103  func (c ControllerCommon) HTTP404(ctx echo.Context, err error) error {
   104  	code := common.String("NOT_FOUND")
   105  	message := common.String(err.Error())
   106  	e := apperr.GetError(err)
   107  	if e != nil {
   108  		code = common.String(e.Code())
   109  		message = common.String(e.Message())
   110  	}
   111  	return ctx.JSON(http.StatusNotFound, ResponseWithError(ctx, &ErrorBase{
   112  		Code:    code,
   113  		Message: message,
   114  	}))
   115  }
   116  
   117  // HTTP400 ...
   118  func (c ControllerCommon) HTTP400(ctx echo.Context, err error) error {
   119  	code := common.String("BAD_REQUEST")
   120  	message := common.String(err.Error())
   121  	e := apperr.GetError(err)
   122  	if e != nil {
   123  		code = common.String(e.Code())
   124  		message = common.String(e.Message())
   125  	}
   126  	return ctx.JSON(http.StatusBadRequest, ResponseWithError(ctx, &ErrorBase{
   127  		Code:    code,
   128  		Message: message,
   129  	}))
   130  }
   131  
   132  // HTTP409 ...
   133  func (c ControllerCommon) HTTP409(ctx echo.Context, err error) error {
   134  	code := common.String("CONFLICT")
   135  	message := common.String(err.Error())
   136  	e := apperr.GetError(err)
   137  	if e != nil {
   138  		code = common.String(e.Code())
   139  		message = common.String(e.Message())
   140  	}
   141  	return ctx.JSON(http.StatusConflict, ResponseWithError(ctx, &ErrorBase{
   142  		Code:    code,
   143  		Message: message,
   144  	}))
   145  }
   146  
   147  // HTTP500 ...
   148  func (c ControllerCommon) HTTP500(ctx echo.Context, err error) error {
   149  	code := common.String("INTERNAL_ERROR")
   150  	message := common.String(err.Error())
   151  	e := apperr.GetError(err)
   152  	if e != nil {
   153  		code = common.String(e.Code())
   154  		message = common.String(e.Message())
   155  	}
   156  	return ctx.JSON(http.StatusInternalServerError, ResponseWithError(ctx, &ErrorBase{
   157  		Code:    code,
   158  		Message: message,
   159  	}))
   160  }
   161  
   162  // HTTP422 ...
   163  func (c ControllerCommon) HTTP422(ctx echo.Context, err error) error {
   164  
   165  	var fields []ErrorField
   166  
   167  	respErr := ErrorBase{
   168  		Code: common.String("UNPROCESSABLE_ERROR"),
   169  	}
   170  
   171  	e := apperr.GetError(err)
   172  	if e != nil {
   173  		errs := e.ValidationErrors()
   174  
   175  		for fieldName, desc := range errs {
   176  			// update field name
   177  			fieldNameArr := strings.Split(fieldName, ".")
   178  			fieldName = fieldNameArr[len(fieldNameArr)-1]
   179  
   180  			fields = append(fields, ErrorField{
   181  				Name:    common.String(fieldName),
   182  				Message: common.String(desc),
   183  			})
   184  		}
   185  
   186  		respErr.Code = common.String(e.Code())
   187  		respErr.Message = common.String(e.Message())
   188  		respErr.Fields = fields
   189  	}
   190  
   191  	return ctx.JSON(http.StatusUnprocessableEntity, ResponseWithError(ctx, &respErr))
   192  }
   193  
   194  // HTTP501 ...
   195  func (c ControllerCommon) HTTP501(ctx echo.Context, data interface{}) error {
   196  	return ctx.JSON(http.StatusNotImplemented, data)
   197  }
   198  
   199  // Pagination ...
   200  func (c ControllerCommon) Pagination(page, limit *uint64, sort *string) (pagination common.PageParams) {
   201  
   202  	pagination = common.PageParams{
   203  		Limit:   200,
   204  		Offset:  0,
   205  		Order:   "desc",
   206  		SortBy:  "created_at",
   207  		PageReq: 1,
   208  		SortReq: "-created_at",
   209  	}
   210  
   211  	if limit != nil {
   212  		pagination.Limit = int64(*limit)
   213  	}
   214  	if page != nil {
   215  		pagination.PageReq = int64(*page)
   216  	}
   217  
   218  	pagination.Offset = pagination.Limit * (pagination.PageReq - 1)
   219  	if pagination.Offset < 0 {
   220  		pagination.Offset = 0
   221  	}
   222  
   223  	if sort != nil && len(*sort) > 1 {
   224  		pagination.SortReq = *sort
   225  		firstChar := string([]rune(*sort)[0])
   226  		switch firstChar {
   227  		case "+":
   228  			pagination.Order = "asc"
   229  		case "-":
   230  			pagination.Order = "desc"
   231  		}
   232  
   233  		// ToSnake converts a string to snake_case
   234  		pagination.SortBy = strcase.ToSnake(strings.Replace(*sort, firstChar, "", 1))
   235  	}
   236  
   237  	return
   238  }
   239  
   240  // Search ...
   241  func (c ControllerCommon) Search(query *string, limit, offset *int64) (search common.SearchParams) {
   242  
   243  	search = common.SearchParams{
   244  		Query:  common.StringValue(query),
   245  		Limit:  200,
   246  		Offset: 0,
   247  	}
   248  
   249  	if limit != nil {
   250  		search.Limit = common.Int64Value(limit)
   251  	}
   252  	if offset != nil {
   253  		search.Offset = common.Int64Value(offset)
   254  	}
   255  
   256  	return
   257  }
   258  
   259  // ERROR ...
   260  func (c ControllerCommon) ERROR(ctx echo.Context, err error) error {
   261  	switch {
   262  	case errors.Is(err, apperr.ErrUnknownField):
   263  		return c.HTTP400(ctx, err)
   264  	case errors.Is(err, apperr.ErrBadJSONRequest):
   265  		return c.HTTP400(ctx, err)
   266  	case errors.Is(err, apperr.ErrAccessDenied):
   267  		return c.HTTP401(ctx, err)
   268  	case errors.Is(err, apperr.ErrAccessForbidden):
   269  		return c.HTTP403(ctx, err)
   270  	case errors.Is(err, apperr.ErrNotFound):
   271  		return c.HTTP404(ctx, err)
   272  	case errors.Is(err, apperr.ErrAlreadyExists):
   273  		return c.HTTP409(ctx, err)
   274  	case errors.Is(err, apperr.ErrInvalidRequest):
   275  		return c.HTTP422(ctx, err)
   276  	case errors.Is(err, apperr.ErrInternal):
   277  		return c.HTTP500(ctx, err)
   278  	default:
   279  		var bodyStr string
   280  		body, _ := io.ReadAll(ctx.Request().Body)
   281  		if len(body) > 0 {
   282  			bodyStr = string(body)
   283  		}
   284  		url := ctx.Request().URL.String()
   285  		log.Warnf("unknown err type %v for uri %s and body %q", err, url, bodyStr)
   286  	}
   287  	log.Error(err.Error())
   288  	return nil
   289  }
   290  
   291  func (c ControllerCommon) currentUser(ctx echo.Context) (*m.User, error) {
   292  
   293  	user, ok := ctx.Get("currentUser").(*m.User)
   294  	if !ok {
   295  		return nil, errors.Wrap(apperr.ErrBadRequestParams, "bad user object")
   296  	}
   297  
   298  	return user, nil
   299  }
   300  
   301  func (c ControllerCommon) parseBasicAuth(auth string) (username, password string, ok bool) {
   302  	const prefix = "Basic "
   303  	// Case insensitive prefix match. See Issue 22736.
   304  	if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
   305  		return
   306  	}
   307  	str, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
   308  	if err != nil {
   309  		return
   310  	}
   311  	cs := string(str)
   312  	s := strings.IndexByte(cs, ':')
   313  	if s < 0 {
   314  		return
   315  	}
   316  
   317  	return cs[:s], cs[s+1:], true
   318  }
   319  
   320  type contextValue struct {
   321  	echo.Context
   322  }
   323  
   324  func NewMiddlewareContextValue(fn echo.HandlerFunc) echo.HandlerFunc {
   325  	return func(ctx echo.Context) error {
   326  		return fn(contextValue{ctx})
   327  	}
   328  }
   329  
   330  // Get retrieves data from the context.
   331  func (ctx contextValue) Get(key string) interface{} {
   332  	// get old context value
   333  	val := ctx.Context.Get(key)
   334  	if val != nil {
   335  		return val
   336  	}
   337  	return ctx.Request().Context().Value(key)
   338  }
   339  
   340  // Set saves data in the context.
   341  func (ctx contextValue) Set(key string, val interface{}) {
   342  	ctx.SetRequest(ctx.Request().WithContext(context.WithValue(ctx.Request().Context(), key, val)))
   343  }