github.com/TeaOSLab/EdgeNode@v1.3.8/internal/waf/checkpoints/request_referer_block.go (about)

     1  // Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
     2  
     3  package checkpoints
     4  
import (
	"net/url"
	"strings"

	"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
	"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
	"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
	"github.com/iwind/TeaGo/maps"
	"github.com/iwind/TeaGo/types"
)
    13  
    14  // RequestRefererBlockCheckpoint 防盗链
    15  type RequestRefererBlockCheckpoint struct {
    16  	Checkpoint
    17  }
    18  
    19  // RequestValue 计算checkpoint值
    20  // 选项:allowEmpty, allowSameDomain, allowDomains
    21  func (this *RequestRefererBlockCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
    22  	var checkOrigin = options.GetBool("checkOrigin")
    23  	var referer = req.WAFRaw().Referer()
    24  	if len(referer) == 0 && checkOrigin {
    25  		var origin = req.WAFRaw().Header.Get("Origin")
    26  		if len(origin) > 0 && origin != "null" {
    27  			referer = "https://" + origin // 因为Origin都只有域名部分,所以为了下面的URL 分析需要加上https://
    28  		}
    29  	}
    30  
    31  	if len(referer) == 0 {
    32  		if options.GetBool("allowEmpty") {
    33  			value = 1
    34  			return
    35  		}
    36  		value = 0
    37  		return
    38  	}
    39  
    40  	u, err := url.Parse(referer)
    41  	if err != nil {
    42  		value = 0
    43  		return
    44  	}
    45  	var host = u.Host
    46  
    47  	if options.GetBool("allowSameDomain") && host == req.WAFRaw().Host {
    48  		value = 1
    49  		return
    50  	}
    51  
    52  	// allow domains
    53  	var allowDomains = options.GetSlice("allowDomains")
    54  	var allowDomainStrings = []string{}
    55  	for _, domain := range allowDomains {
    56  		allowDomainStrings = append(allowDomainStrings, types.String(domain))
    57  	}
    58  
    59  	// deny domains
    60  	var denyDomains = options.GetSlice("denyDomains")
    61  	var denyDomainStrings = []string{}
    62  	for _, domain := range denyDomains {
    63  		denyDomainStrings = append(denyDomainStrings, types.String(domain))
    64  	}
    65  
    66  	if len(allowDomainStrings) == 0 {
    67  		if len(denyDomainStrings) > 0 {
    68  			if configutils.MatchDomains(denyDomainStrings, host) {
    69  				value = 0
    70  			} else {
    71  				value = 1
    72  			}
    73  			return
    74  		}
    75  
    76  		value = 0
    77  		return
    78  	}
    79  
    80  	if configutils.MatchDomains(allowDomainStrings, host) {
    81  		if len(denyDomainStrings) > 0 {
    82  			if configutils.MatchDomains(denyDomainStrings, host) {
    83  				value = 0
    84  			} else {
    85  				value = 1
    86  			}
    87  			return
    88  		}
    89  		value = 1
    90  		return
    91  	} else {
    92  		value = 0
    93  	}
    94  
    95  	return
    96  }
    97  
    98  func (this *RequestRefererBlockCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
    99  	return
   100  }
   101  
   102  func (this *RequestRefererBlockCheckpoint) CacheLife() utils.CacheLife {
   103  	return utils.CacheLongLife
   104  }