github.com/TeaOSLab/EdgeNode@v1.3.8/internal/waf/action_block.go (about)

     1  package waf
     2  
     3  import (
     4  	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
     5  	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
     6  	"github.com/TeaOSLab/EdgeNode/internal/utils"
     7  	"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
     8  	"github.com/iwind/TeaGo/Tea"
     9  	"github.com/iwind/TeaGo/logs"
    10  	"github.com/iwind/TeaGo/rands"
    11  	"io"
    12  	"net/http"
    13  	"os"
    14  	"path/filepath"
    15  	"regexp"
    16  	"time"
    17  )
    18  
    19  // url client configure
    20  var urlPrefixReg = regexp.MustCompile("^(?i)(http|https)://")
    21  var httpClient = utils.SharedHttpClient(5 * time.Second)
    22  
    23  type BlockAction struct {
    24  	BaseAction
    25  
    26  	StatusCode int    `yaml:"statusCode" json:"statusCode"`
    27  	Body       string `yaml:"body" json:"body"` // supports HTML
    28  	URL        string `yaml:"url" json:"url"`
    29  	Timeout    int32  `yaml:"timeout" json:"timeout"`
    30  	TimeoutMax int32  `yaml:"timeoutMax" json:"timeoutMax"`
    31  	Scope      string `yaml:"scope" json:"scope"`
    32  
    33  	FailBlockScopeAll bool `yaml:"failBlockScopeAll" json:"failBlockScopeAll"`
    34  }
    35  
    36  func (this *BlockAction) Init(waf *WAF) error {
    37  	if waf.DefaultBlockAction != nil {
    38  		if this.StatusCode <= 0 {
    39  			this.StatusCode = waf.DefaultBlockAction.StatusCode
    40  		}
    41  		if len(this.Body) == 0 {
    42  			this.Body = waf.DefaultBlockAction.Body
    43  		}
    44  		if len(this.URL) == 0 {
    45  			this.URL = waf.DefaultBlockAction.URL
    46  		}
    47  		if this.Timeout <= 0 {
    48  			this.Timeout = waf.DefaultBlockAction.Timeout
    49  			this.TimeoutMax = waf.DefaultBlockAction.TimeoutMax // 只有没有填写封锁时长的时候才会使用默认的封锁时长最大值
    50  		}
    51  
    52  		this.FailBlockScopeAll = waf.DefaultBlockAction.FailBlockScopeAll
    53  	}
    54  
    55  	return nil
    56  }
    57  
    58  func (this *BlockAction) Code() string {
    59  	return ActionBlock
    60  }
    61  
    62  func (this *BlockAction) IsAttack() bool {
    63  	return true
    64  }
    65  
    66  func (this *BlockAction) WillChange() bool {
    67  	return true
    68  }
    69  
    70  func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
    71  	// 加入到黑名单
    72  	var timeout = this.Timeout
    73  	if timeout <= 0 {
    74  		timeout = 300 // 默认封锁300秒
    75  	}
    76  
    77  	// 随机时长
    78  	var timeoutMax = this.TimeoutMax
    79  	if timeoutMax > timeout {
    80  		timeout = timeout + int32(rands.Int64()%int64(timeoutMax-timeout+1))
    81  	}
    82  
    83  	SharedIPBlackList.RecordIP(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout), waf.Id, waf.UseLocalFirewall && (this.FailBlockScopeAll || this.Scope == firewallconfigs.FirewallScopeGlobal), group.Id, set.Id, "")
    84  
    85  	if writer != nil {
    86  		// close the connection
    87  		defer request.WAFClose()
    88  
    89  		// output response
    90  		if this.StatusCode > 0 {
    91  			request.ProcessResponseHeaders(writer.Header(), this.StatusCode)
    92  			writer.WriteHeader(this.StatusCode)
    93  		} else {
    94  			request.ProcessResponseHeaders(writer.Header(), http.StatusForbidden)
    95  			writer.WriteHeader(http.StatusForbidden)
    96  		}
    97  		if len(this.URL) > 0 {
    98  			if urlPrefixReg.MatchString(this.URL) {
    99  				req, err := http.NewRequest(http.MethodGet, this.URL, nil)
   100  				if err != nil {
   101  					logs.Error(err)
   102  					return PerformResult{}
   103  				}
   104  				req.Header.Set("User-Agent", teaconst.GlobalProductName+"/"+teaconst.Version)
   105  
   106  				resp, err := httpClient.Do(req)
   107  				if err != nil {
   108  					logs.Error(err)
   109  					return PerformResult{}
   110  				}
   111  				defer func() {
   112  					_ = resp.Body.Close()
   113  				}()
   114  
   115  				for k, v := range resp.Header {
   116  					for _, v1 := range v {
   117  						writer.Header().Add(k, v1)
   118  					}
   119  				}
   120  
   121  				var buf = utils.BytePool1k.Get()
   122  				_, _ = io.CopyBuffer(writer, resp.Body, buf.Bytes)
   123  				utils.BytePool1k.Put(buf)
   124  			} else {
   125  				var path = this.URL
   126  				if !filepath.IsAbs(this.URL) {
   127  					path = Tea.Root + string(os.PathSeparator) + path
   128  				}
   129  
   130  				data, err := os.ReadFile(path)
   131  				if err != nil {
   132  					logs.Error(err)
   133  					return PerformResult{}
   134  				}
   135  				_, _ = writer.Write(data)
   136  			}
   137  			return PerformResult{}
   138  		}
   139  		if len(this.Body) > 0 {
   140  			_, _ = writer.Write([]byte(this.Body))
   141  		} else {
   142  			_, _ = writer.Write([]byte("The request is blocked by " + teaconst.ProductName))
   143  		}
   144  	}
   145  
   146  	return PerformResult{}
   147  }