
     1  package waf
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	""
     7  	teaconst ""
     8  	""
     9  	""
    10  	""
    11  	""
    12  	""
    13  	"net/http"
    14  	"os"
    15  	"reflect"
    16  )
// WAF is a web application firewall policy: an ordered list of inbound rule
// groups (checked by MatchRequest) and outbound rule groups (checked by
// MatchResponse), plus policy-level settings. It is read from YAML by
// NewWAFFromFile and written as YAML by Save.
//
// Init must be called to build the checkpoint/action tables and the cached
// has-rules flags; until then MatchRequest and MatchResponse match nothing.
type WAF struct {
	Id               int64                           `yaml:"id" json:"id"`
	IsOn             bool                            `yaml:"isOn" json:"isOn"`
	Name             string                          `yaml:"name" json:"name"`
	Inbound          []*RuleGroup                    `yaml:"inbound" json:"inbound"`
	Outbound         []*RuleGroup                    `yaml:"outbound" json:"outbound"`
	CreatedVersion   string                          `yaml:"createdVersion" json:"createdVersion"`
	Mode             firewallconfigs.FirewallMode    `yaml:"mode" json:"mode"`
	UseLocalFirewall bool                            `yaml:"useLocalFirewall" json:"useLocalFirewall"`
	SYNFlood         *firewallconfigs.SYNFloodConfig `yaml:"synFlood" json:"synFlood"`

	// Default action instances.
	// NOTE(review): these exported fields carry no yaml/json tags and are not
	// assigned anywhere in this file — presumably set by the policy loader; confirm.
	DefaultBlockAction    *BlockAction
	DefaultPageAction     *PageAction
	DefaultCaptchaAction  *CaptchaAction
	DefaultJSCookieAction *JSCookieAction
	DefaultPost307Action  *Post307Action
	DefaultGet302Action   *Get302Action

	// Cached "has any groups" flags, computed by Init and checked first by
	// MatchRequest / MatchResponse.
	hasInboundRules  bool
	hasOutboundRules bool

	// Lookup tables rebuilt by Init.
	checkpointsMap map[string]checkpoints.CheckpointInterface // prefix => checkpoint
	actionMap      map[int64]ActionInterface                  // actionId => ActionInterface
}
    43  func NewWAF() *WAF {
    44  	return &WAF{
    45  		IsOn:      true,
    46  		actionMap: map[int64]ActionInterface{},
    47  	}
    48  }
    50  func NewWAFFromFile(path string) (waf *WAF, err error) {
    51  	if len(path) == 0 {
    52  		return nil, errors.New("'path' should not be empty")
    53  	}
    54  	file := files.NewFile(path)
    55  	if !file.Exists() {
    56  		return nil, errors.New("'" + path + "' not exist")
    57  	}
    59  	reader, err := file.Reader()
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  	defer func() {
    64  		_ = reader.Close()
    65  	}()
    67  	waf = &WAF{}
    68  	err = reader.ReadYAML(waf)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	return waf, nil
    73  }
    75  func (this *WAF) Init() (resultErrors []error) {
    76  	// checkpoint
    77  	this.checkpointsMap = map[string]checkpoints.CheckpointInterface{}
    78  	for _, def := range checkpoints.AllCheckpoints {
    79  		instance := reflect.New(reflect.Indirect(reflect.ValueOf(def.Instance)).Type()).Interface().(checkpoints.CheckpointInterface)
    80  		instance.Init()
    81  		instance.SetPriority(def.Priority)
    82  		this.checkpointsMap[def.Prefix] = instance
    83  	}
    85  	// action map
    86  	this.actionMap = map[int64]ActionInterface{}
    88  	// rules
    89  	this.hasInboundRules = len(this.Inbound) > 0
    90  	this.hasOutboundRules = len(this.Outbound) > 0
    92  	if this.hasInboundRules {
    93  		for _, group := range this.Inbound {
    94  			// finder
    95  			for _, set := range group.RuleSets {
    96  				for _, rule := range set.Rules {
    97  					rule.SetCheckpointFinder(this.FindCheckpointInstance)
    98  				}
    99  			}
   101  			err := group.Init(this)
   102  			if err != nil {
   103  				// 这里我们不阻止其他规则正常加入
   104  				resultErrors = append(resultErrors, fmt.Errorf("init group '%d' failed: %w", group.Id, err))
   105  			}
   106  		}
   107  	}
   109  	if this.hasOutboundRules {
   110  		for _, group := range this.Outbound {
   111  			// finder
   112  			for _, set := range group.RuleSets {
   113  				for _, rule := range set.Rules {
   114  					rule.SetCheckpointFinder(this.FindCheckpointInstance)
   115  				}
   116  			}
   118  			err := group.Init(this)
   119  			if err != nil {
   120  				// 这里我们不阻止其他规则正常加入
   121  				resultErrors = append(resultErrors, err)
   122  			}
   123  		}
   124  	}
   126  	return nil
   127  }
   129  func (this *WAF) AddRuleGroup(ruleGroup *RuleGroup) {
   130  	if ruleGroup.IsInbound {
   131  		this.Inbound = append(this.Inbound, ruleGroup)
   132  	} else {
   133  		this.Outbound = append(this.Outbound, ruleGroup)
   134  	}
   135  }
   137  func (this *WAF) RemoveRuleGroup(ruleGroupId int64) {
   138  	{
   139  		result := []*RuleGroup{}
   140  		for _, group := range this.Inbound {
   141  			if group.Id == ruleGroupId {
   142  				continue
   143  			}
   144  			result = append(result, group)
   145  		}
   146  		this.Inbound = result
   147  	}
   149  	{
   150  		result := []*RuleGroup{}
   151  		for _, group := range this.Outbound {
   152  			if group.Id == ruleGroupId {
   153  				continue
   154  			}
   155  			result = append(result, group)
   156  		}
   157  		this.Outbound = result
   158  	}
   159  }
   161  func (this *WAF) FindRuleGroup(ruleGroupId int64) *RuleGroup {
   162  	for _, group := range this.Inbound {
   163  		if group.Id == ruleGroupId {
   164  			return group
   165  		}
   166  	}
   167  	for _, group := range this.Outbound {
   168  		if group.Id == ruleGroupId {
   169  			return group
   170  		}
   171  	}
   172  	return nil
   173  }
   175  func (this *WAF) FindRuleGroupWithCode(ruleGroupCode string) *RuleGroup {
   176  	if len(ruleGroupCode) == 0 {
   177  		return nil
   178  	}
   179  	for _, group := range this.Inbound {
   180  		if group.Code == ruleGroupCode {
   181  			return group
   182  		}
   183  	}
   184  	for _, group := range this.Outbound {
   185  		if group.Code == ruleGroupCode {
   186  			return group
   187  		}
   188  	}
   189  	return nil
   190  }
   192  func (this *WAF) MoveInboundRuleGroup(fromIndex int, toIndex int) {
   193  	if fromIndex < 0 || fromIndex >= len(this.Inbound) {
   194  		return
   195  	}
   196  	if toIndex < 0 || toIndex >= len(this.Inbound) {
   197  		return
   198  	}
   199  	if fromIndex == toIndex {
   200  		return
   201  	}
   203  	group := this.Inbound[fromIndex]
   204  	result := []*RuleGroup{}
   205  	for i := 0; i < len(this.Inbound); i++ {
   206  		if i == fromIndex {
   207  			continue
   208  		}
   209  		if fromIndex > toIndex && i == toIndex {
   210  			result = append(result, group)
   211  		}
   212  		result = append(result, this.Inbound[i])
   213  		if fromIndex < toIndex && i == toIndex {
   214  			result = append(result, group)
   215  		}
   216  	}
   218  	this.Inbound = result
   219  }
   221  func (this *WAF) MoveOutboundRuleGroup(fromIndex int, toIndex int) {
   222  	if fromIndex < 0 || fromIndex >= len(this.Outbound) {
   223  		return
   224  	}
   225  	if toIndex < 0 || toIndex >= len(this.Outbound) {
   226  		return
   227  	}
   228  	if fromIndex == toIndex {
   229  		return
   230  	}
   232  	group := this.Outbound[fromIndex]
   233  	result := []*RuleGroup{}
   234  	for i := 0; i < len(this.Outbound); i++ {
   235  		if i == fromIndex {
   236  			continue
   237  		}
   238  		if fromIndex > toIndex && i == toIndex {
   239  			result = append(result, group)
   240  		}
   241  		result = append(result, this.Outbound[i])
   242  		if fromIndex < toIndex && i == toIndex {
   243  			result = append(result, group)
   244  		}
   245  	}
   247  	this.Outbound = result
   248  }
   250  func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter, defaultCaptchaType firewallconfigs.ServerCaptchaType) (result MatchResult, err error) {
   251  	if !this.hasInboundRules {
   252  		return MatchResult{
   253  			GoNext: true,
   254  		}, nil
   255  	}
   257  	// validate captcha
   258  	var rawPath = req.WAFRaw().URL.Path
   259  	if rawPath == CaptchaPath {
   260  		req.DisableAccessLog()
   261  		req.DisableStat()
   262  		captchaValidator.Run(req, writer, defaultCaptchaType)
   263  		return
   264  	}
   266  	// Get 302验证
   267  	if rawPath == Get302Path {
   268  		req.DisableAccessLog()
   269  		req.DisableStat()
   270  		get302Validator.Run(req, writer)
   271  		return
   272  	}
   274  	// match rules
   275  	var hasRequestBody bool
   276  	for _, group := range this.Inbound {
   277  		if !group.IsOn {
   278  			continue
   279  		}
   280  		b, hasCheckedRequestBody, set, matchErr := group.MatchRequest(req)
   281  		if hasCheckedRequestBody {
   282  			hasRequestBody = true
   283  		}
   284  		if matchErr != nil {
   285  			return MatchResult{
   286  				GoNext:         true,
   287  				HasRequestBody: hasRequestBody,
   288  			}, matchErr
   289  		}
   290  		if b {
   291  			var performResult = set.PerformActions(this, group, req, writer)
   292  			if !performResult.GoNextSet {
   293  				if performResult.GoNextGroup {
   294  					continue
   295  				}
   296  				return MatchResult{
   297  					GoNext:         performResult.ContinueRequest,
   298  					HasRequestBody: hasRequestBody,
   299  					Group:          group,
   300  					Set:            set,
   301  					IsAllowed:      performResult.IsAllowed,
   302  					AllowScope:     performResult.AllowScope,
   303  				}, nil
   304  			}
   305  		}
   306  	}
   307  	return MatchResult{
   308  		GoNext:         true,
   309  		HasRequestBody: hasRequestBody,
   310  	}, nil
   311  }
   313  func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (result MatchResult, err error) {
   314  	if !this.hasOutboundRules {
   315  		return MatchResult{
   316  			GoNext: true,
   317  		}, nil
   318  	}
   319  	var hasRequestBody bool
   320  	var resp = requests.NewResponse(rawResp)
   321  	for _, group := range this.Outbound {
   322  		if !group.IsOn {
   323  			continue
   324  		}
   325  		b, hasCheckedRequestBody, set, matchErr := group.MatchResponse(req, resp)
   326  		if hasCheckedRequestBody {
   327  			hasRequestBody = true
   328  		}
   329  		if matchErr != nil {
   330  			return MatchResult{
   331  				GoNext:         true,
   332  				HasRequestBody: hasRequestBody,
   333  			}, matchErr
   334  		}
   335  		if b {
   336  			var performResult = set.PerformActions(this, group, req, writer)
   337  			if !performResult.GoNextSet {
   338  				if performResult.GoNextGroup {
   339  					continue
   340  				}
   341  				return MatchResult{
   342  					GoNext:         performResult.ContinueRequest,
   343  					HasRequestBody: hasRequestBody,
   344  					Group:          group,
   345  					Set:            set,
   346  					IsAllowed:      performResult.IsAllowed,
   347  					AllowScope:     performResult.AllowScope,
   348  				}, nil
   349  			}
   350  		}
   351  	}
   352  	return MatchResult{
   353  		GoNext:         true,
   354  		HasRequestBody: hasRequestBody,
   355  	}, nil
   356  }
   358  // Save to file path
   359  func (this *WAF) Save(path string) error {
   360  	if len(path) == 0 {
   361  		return errors.New("path should not be empty")
   362  	}
   363  	if len(this.CreatedVersion) == 0 {
   364  		this.CreatedVersion = teaconst.Version
   365  	}
   366  	data, err := yaml.Marshal(this)
   367  	if err != nil {
   368  		return err
   369  	}
   370  	return os.WriteFile(path, data, 0644)
   371  }
   373  func (this *WAF) ContainsGroupCode(code string) bool {
   374  	if len(code) == 0 {
   375  		return false
   376  	}
   377  	for _, group := range this.Inbound {
   378  		if group.Code == code {
   379  			return true
   380  		}
   381  	}
   382  	for _, group := range this.Outbound {
   383  		if group.Code == code {
   384  			return true
   385  		}
   386  	}
   387  	return false
   388  }
   390  func (this *WAF) AddAction(action ActionInterface) {
   391  	this.actionMap[action.ActionId()] = action
   392  }
   394  func (this *WAF) FindAction(actionId int64) ActionInterface {
   395  	return this.actionMap[actionId]
   396  }
   398  func (this *WAF) Copy() *WAF {
   399  	var waf = &WAF{
   400  		Id:       this.Id,
   401  		IsOn:     this.IsOn,
   402  		Name:     this.Name,
   403  		Inbound:  this.Inbound,
   404  		Outbound: this.Outbound,
   405  	}
   406  	return waf
   407  }
   409  func (this *WAF) CountInboundRuleSets() int {
   410  	count := 0
   411  	for _, group := range this.Inbound {
   412  		count += len(group.RuleSets)
   413  	}
   414  	return count
   415  }
   417  func (this *WAF) CountOutboundRuleSets() int {
   418  	count := 0
   419  	for _, group := range this.Outbound {
   420  		count += len(group.RuleSets)
   421  	}
   422  	return count
   423  }
   425  func (this *WAF) FindCheckpointInstance(prefix string) checkpoints.CheckpointInterface {
   426  	instance, ok := this.checkpointsMap[prefix]
   427  	if ok {
   428  		return instance
   429  	}
   430  	return nil
   431  }
   433  // Start
   434  func (this *WAF) Start() {
   435  	for _, checkpoint := range this.checkpointsMap {
   436  		checkpoint.Start()
   437  	}
   438  }
   440  // Stop call stop() when the waf was deleted
   441  func (this *WAF) Stop() {
   442  	for _, checkpoint := range this.checkpointsMap {
   443  		checkpoint.Stop()
   444  	}
   445  }
   447  // MergeTemplate merge with template
   448  func (this *WAF) MergeTemplate() (changedItems []string, err error) {
   449  	changedItems = []string{}
   451  	// compare versions
   452  	if !Tea.IsTesting() && this.CreatedVersion == teaconst.Version {
   453  		return
   454  	}
   455  	this.CreatedVersion = teaconst.Version
   457  	template, err := Template()
   458  	if err != nil {
   459  		return nil, err
   460  	}
   461  	groups := []*RuleGroup{}
   462  	groups = append(groups, template.Inbound...)
   463  	groups = append(groups, template.Outbound...)
   465  	var newGroupId int64 = 1_000_000_000
   467  	for _, group := range groups {
   468  		oldGroup := this.FindRuleGroupWithCode(group.Code)
   469  		if oldGroup == nil {
   470  			newGroupId++
   471  			group.Id = newGroupId
   472  			this.AddRuleGroup(group)
   473  			changedItems = append(changedItems, "+group "+group.Name)
   474  			continue
   475  		}
   477  		// check rule sets
   478  		for _, set := range group.RuleSets {
   479  			oldSet := oldGroup.FindRuleSetWithCode(set.Code)
   480  			if oldSet == nil {
   481  				oldGroup.AddRuleSet(set)
   482  				changedItems = append(changedItems, "+group "+group.Name+" rule set:"+set.Name)
   483  			} else if len(oldSet.Rules) < len(set.Rules) {
   484  				oldSet.Rules = set.Rules
   485  				changedItems = append(changedItems, "*group "+group.Name+" rule set:"+set.Name)
   486  			}
   487  		}
   488  	}
   489  	return
   490  }