github.com/godaddy-x/freego@v1.0.156/node/filter.go (about)

     1  package node
     2  
     3  import (
     4  	"fmt"
     5  	rate "github.com/godaddy-x/freego/cache/limiter"
     6  	"github.com/godaddy-x/freego/ex"
     7  	"github.com/godaddy-x/freego/utils"
     8  	"github.com/godaddy-x/freego/utils/concurrent"
     9  	"github.com/godaddy-x/freego/zlog"
    10  	"math"
    11  	"net/http"
    12  )
    13  
    14  const (
    15  	GatewayRateLimiterFilterName = "GatewayRateLimiterFilter"
    16  	ParameterFilterName          = "ParameterFilter"
    17  	SessionFilterName            = "SessionFilter"
    18  	UserRateLimiterFilterName    = "UserRateLimiterFilter"
    19  	RoleFilterName               = "RoleFilter"
    20  	PostHandleFilterName         = "PostHandleFilter"
    21  	RenderHandleFilterName       = "RenderHandleFilter"
    22  )
    23  
    24  var filterMap = map[string]*FilterObject{
    25  	GatewayRateLimiterFilterName: {Name: GatewayRateLimiterFilterName, Order: -100, Filter: &GatewayRateLimiterFilter{}},
    26  	ParameterFilterName:          {Name: ParameterFilterName, Order: -90, Filter: &ParameterFilter{}},
    27  	SessionFilterName:            {Name: SessionFilterName, Order: -80, Filter: &SessionFilter{}},
    28  	UserRateLimiterFilterName:    {Name: UserRateLimiterFilterName, Order: -70, Filter: &UserRateLimiterFilter{}},
    29  	RoleFilterName:               {Name: RoleFilterName, Order: -60, Filter: &RoleFilter{}},
    30  	PostHandleFilterName:         {Name: PostHandleFilterName, Order: math.MaxInt, Filter: &PostHandleFilter{}},
    31  	RenderHandleFilterName:       {Name: RenderHandleFilterName, Order: math.MinInt, Filter: &RenderHandleFilter{}},
    32  }
    33  
    34  type FilterObject struct {
    35  	Name         string
    36  	Order        int
    37  	Filter       Filter
    38  	MatchPattern []string
    39  }
    40  
    41  func createFilterChain(extFilters []*FilterObject) ([]*FilterObject, error) {
    42  	var filters []*FilterObject
    43  	var fs []interface{}
    44  	for _, v := range filterMap {
    45  		extFilters = append(extFilters, v)
    46  	}
    47  	for _, v := range extFilters {
    48  		for _, check := range fs {
    49  			if check.(*FilterObject).Name == v.Name {
    50  				panic("filter name exist: " + v.Name)
    51  			}
    52  		}
    53  		fs = append(fs, v)
    54  	}
    55  	fs = concurrent.NewSorter(fs, func(a, b interface{}) bool {
    56  		o1 := a.(*FilterObject)
    57  		o2 := b.(*FilterObject)
    58  		return o1.Order < o2.Order
    59  	}).Sort()
    60  	for _, f := range fs {
    61  		v := f.(*FilterObject)
    62  		filters = append(filters, v)
    63  		zlog.Printf("add filter [%s] successful", v.Name)
    64  	}
    65  	if len(filters) == 0 {
    66  		return nil, utils.Error("filter chain is nil")
    67  	}
    68  	return filters, nil
    69  }
    70  
    71  type Filter interface {
    72  	DoFilter(chain Filter, ctx *Context, args ...interface{}) error
    73  }
    74  
    75  type filterChain struct {
    76  	pos     int
    77  	filters []*FilterObject
    78  }
    79  
    80  func (self *filterChain) DoFilter(chain Filter, ctx *Context, args ...interface{}) error {
    81  	fs := self.filters
    82  	for self.pos < len(fs) {
    83  		f := fs[self.pos]
    84  		if f == nil || f.Filter == nil {
    85  			return ex.Throw{Code: ex.SYSTEM, Msg: fmt.Sprintf("filter [%s] is nil", f.Name)}
    86  		}
    87  		self.pos++
    88  		if !utils.MatchFilterURL(ctx.Path, f.MatchPattern) {
    89  			continue
    90  		}
    91  		return f.Filter.DoFilter(chain, ctx, args...)
    92  	}
    93  	return nil
    94  }
    95  
    96  type GatewayRateLimiterFilter struct{}
    97  type ParameterFilter struct{}
    98  type SessionFilter struct{}
    99  type UserRateLimiterFilter struct{}
   100  type RoleFilter struct{}
   101  type PostHandleFilter struct{}
   102  type RenderHandleFilter struct{}
   103  
   104  var (
   105  	gatewayRateLimiter = rate.NewRateLimiter(rate.Option{Limit: 200, Bucket: 2000, Expire: 30, Distributed: true})
   106  	methodRateLimiter  = rate.NewRateLimiter(rate.Option{Limit: 200, Bucket: 2000, Expire: 30, Distributed: true})
   107  	userRateLimiter    = rate.NewRateLimiter(rate.Option{Limit: 5, Bucket: 10, Expire: 30, Distributed: true})
   108  )
   109  
   110  func SetGatewayRateLimiter(option rate.Option) {
   111  	gatewayRateLimiter = rate.NewRateLimiter(option)
   112  }
   113  
   114  func SetMethodRateLimiter(option rate.Option) {
   115  	methodRateLimiter = rate.NewRateLimiter(option)
   116  }
   117  
   118  func SetUserRateLimiter(option rate.Option) {
   119  	userRateLimiter = rate.NewRateLimiter(option)
   120  }
   121  
   122  func (self *GatewayRateLimiterFilter) DoFilter(chain Filter, ctx *Context, args ...interface{}) error {
   123  	//if b := gatewayRateLimiter.Allow("HttpThreshold"); !b {
   124  	//	return ex.Throw{Code: 429, Msg: "the gateway request is full, please try again later"}
   125  	//}
   126  	return chain.DoFilter(chain, ctx, args...)
   127  }
   128  
   129  func (self *ParameterFilter) DoFilter(chain Filter, ctx *Context, args ...interface{}) error {
   130  	if err := ctx.readParams(); err != nil {
   131  		return err
   132  	}
   133  	return chain.DoFilter(chain, ctx, args...)
   134  }
   135  
   136  func (self *SessionFilter) DoFilter(chain Filter, ctx *Context, args ...interface{}) error {
   137  	if ctx.RouterConfig.UseRSA || ctx.RouterConfig.UseHAX || ctx.RouterConfig.Guest { // 登录接口和游客模式跳过会话认证
   138  		return chain.DoFilter(chain, ctx, args...)
   139  	}
   140  	if len(ctx.Subject.GetRawBytes()) == 0 {
   141  		return ex.Throw{Code: http.StatusUnauthorized, Msg: "token is nil"}
   142  	}
   143  	if err := ctx.Subject.Verify(utils.Bytes2Str(ctx.Subject.GetRawBytes()), ctx.GetJwtConfig().TokenKey, true); err != nil {
   144  		return ex.Throw{Code: http.StatusUnauthorized, Msg: "token invalid or expired", Err: err}
   145  	}
   146  	return chain.DoFilter(chain, ctx, args...)
   147  }
   148  
   149  func (self *UserRateLimiterFilter) DoFilter(chain Filter, ctx *Context, args ...interface{}) error {
   150  	//if b := methodRateLimiter.Allow(ctx.Path); !b {
   151  	//	return ex.Throw{Code: 429, Msg: "the method request is full, please try again later"}
   152  	//}
   153  	//if ctx.Authenticated() {
   154  	//	if b := userRateLimiter.Allow(ctx.Subject.Sub); !b {
   155  	//		return ex.Throw{Code: 429, Msg: "the access frequency is too fast, please try again later"}
   156  	//	}
   157  	//}
   158  	return chain.DoFilter(chain, ctx, args...)
   159  }
   160  
   161  func (self *RoleFilter) DoFilter(chain Filter, ctx *Context, args ...interface{}) error {
   162  	if ctx.roleRealm == nil || !ctx.Authenticated() { // 未配置权限方法或非登录状态跳过
   163  		return chain.DoFilter(chain, ctx, args...)
   164  	}
   165  	need, err := ctx.roleRealm(ctx, false)
   166  	if err != nil {
   167  		return err
   168  	}
   169  	if need == nil { // 无授权资源配置,跳过
   170  		return chain.DoFilter(chain, ctx, args...)
   171  	}
   172  	if len(need.NeedRole) == 0 { // 无授权角色配置跳过
   173  		return chain.DoFilter(chain, ctx, args...)
   174  	}
   175  	//if !need.NeedLogin { // 无登录状态,跳过
   176  	//	return chain.DoFilter(chain, ctx, args...)
   177  	//} else if !ctx.Authenticated() { // 需要登录状态,会话为空,抛出异常
   178  	//	return ex.Throw{Code: http.StatusUnauthorized, Msg: "login status required"}
   179  	//}
   180  	has, err := ctx.roleRealm(ctx, true)
   181  	if err != nil {
   182  		return err
   183  	}
   184  	var hasRoles []int64
   185  	if has != nil && len(has.HasRole) > 0 {
   186  		hasRoles = has.HasRole
   187  	}
   188  	accessCount := 0
   189  	needAccess := len(need.NeedRole)
   190  	for _, hasRole := range hasRoles {
   191  		for _, needRole := range need.NeedRole {
   192  			if hasRole == needRole {
   193  				accessCount++
   194  				if !need.MatchAll || accessCount == needAccess { // 任意授权通过则放行,或已满足授权长度
   195  					return chain.DoFilter(chain, ctx, args...)
   196  				}
   197  			}
   198  		}
   199  	}
   200  	return ex.Throw{Code: http.StatusUnauthorized, Msg: "access defined"}
   201  }
   202  
   203  func (self *PostHandleFilter) DoFilter(chain Filter, ctx *Context, args ...interface{}) error {
   204  	if err := ctx.Handle(); err != nil {
   205  		return err
   206  	}
   207  	return chain.DoFilter(chain, ctx, args...)
   208  }
   209  
   210  func (self *RenderHandleFilter) DoFilter(chain Filter, ctx *Context, args ...interface{}) error {
   211  	err := chain.DoFilter(chain, ctx, args...)
   212  	if err == nil {
   213  		err = defaultRenderPre(ctx)
   214  	}
   215  	if err != nil {
   216  		err = defaultRenderError(ctx, err)
   217  	}
   218  	return defaultRenderTo(ctx)
   219  }