github.com/TeaOSLab/EdgeNode@v1.3.8/internal/waf/ip_list.go (about)

     1  // Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
     2  
     3  package waf
     4  
     5  import (
     6  	"encoding/json"
     7  	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
     8  	"github.com/TeaOSLab/EdgeNode/internal/conns"
     9  	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
    10  	"github.com/TeaOSLab/EdgeNode/internal/events"
    11  	"github.com/TeaOSLab/EdgeNode/internal/firewalls"
    12  	"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
    13  	"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
    14  	"github.com/iwind/TeaGo/Tea"
    15  	"github.com/iwind/TeaGo/maps"
    16  	"github.com/iwind/TeaGo/types"
    17  	"os"
    18  	"sync"
    19  	"sync/atomic"
    20  )
    21  
    22  var SharedIPWhiteList = NewIPList(IPListTypeAllow)
    23  var SharedIPBlackList = NewIPList(IPListTypeDeny)
    24  
    25  type IPListType = string
    26  
    27  const (
    28  	IPListTypeAllow IPListType = "allow"
    29  	IPListTypeDeny  IPListType = "deny"
    30  )
    31  
    32  const IPTypeAll = "*"
    33  
    34  func init() {
    35  	if !teaconst.IsMain {
    36  		return
    37  	}
    38  
    39  	var cacheFile = Tea.Root + "/data/waf_white_list.cache"
    40  
    41  	// save
    42  	events.On(events.EventTerminated, func() {
    43  		_ = SharedIPWhiteList.Save(cacheFile)
    44  	})
    45  
    46  	// load
    47  	go func() {
    48  		if !Tea.IsTesting() {
    49  			_ = SharedIPWhiteList.Load(cacheFile)
    50  			_ = os.Remove(cacheFile)
    51  		}
    52  	}()
    53  }
    54  
    55  // IPList IP列表管理
    56  type IPList struct {
    57  	expireList *expires.List
    58  	ipMap      map[string]uint64 // ip info => id
    59  	idMap      map[uint64]string // id => ip info
    60  	listType   IPListType
    61  
    62  	id     uint64
    63  	locker sync.RWMutex
    64  
    65  	lastIP   string // 加入到 recordIPTaskChan 之前尽可能去重
    66  	lastTime int64
    67  }
    68  
    69  // NewIPList 获取新对象
    70  func NewIPList(listType IPListType) *IPList {
    71  	var list = &IPList{
    72  		ipMap:    map[string]uint64{},
    73  		idMap:    map[uint64]string{},
    74  		listType: listType,
    75  	}
    76  
    77  	var e = expires.NewList()
    78  	list.expireList = e
    79  
    80  	e.OnGC(func(itemId uint64) {
    81  		list.remove(itemId) // TODO 使用异步,防止阻塞GC
    82  	})
    83  
    84  	return list
    85  }
    86  
    87  // Add 添加IP
    88  func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64) {
    89  	switch scope {
    90  	case firewallconfigs.FirewallScopeGlobal:
    91  		ip = "*@" + ip + "@" + ipType
    92  	case firewallconfigs.FirewallScopeServer:
    93  		ip = types.String(serverId) + "@" + ip + "@" + ipType
    94  	default:
    95  		ip = "*@" + ip + "@" + ipType
    96  	}
    97  
    98  	var id = this.nextId()
    99  	this.expireList.Add(id, expiresAt)
   100  	this.locker.Lock()
   101  
   102  	// 删除以前
   103  	oldId, ok := this.ipMap[ip]
   104  	if ok {
   105  		delete(this.idMap, oldId)
   106  		this.expireList.Remove(oldId)
   107  	}
   108  
   109  	this.ipMap[ip] = id
   110  	this.idMap[id] = ip
   111  	this.locker.Unlock()
   112  }
   113  
   114  // RecordIP 记录IP
   115  func (this *IPList) RecordIP(ipType string,
   116  	scope firewallconfigs.FirewallScope,
   117  	serverId int64,
   118  	ip string,
   119  	expiresAt int64,
   120  	policyId int64,
   121  	useLocalFirewall bool,
   122  	groupId int64,
   123  	setId int64,
   124  	reason string) {
   125  	this.Add(ipType, scope, serverId, ip, expiresAt)
   126  
   127  	if this.listType == IPListTypeDeny {
   128  		// 作用域
   129  		var scopeServerId int64
   130  		if scope == firewallconfigs.FirewallScopeServer {
   131  			scopeServerId = serverId
   132  		}
   133  
   134  		// 加入队列等待上传
   135  		if this.lastIP != ip || fasttime.Now().Unix()-this.lastTime > 3 /** 3秒外才允许重复添加 **/ {
   136  			select {
   137  			case recordIPTaskChan <- &recordIPTask{
   138  				ip:                            ip,
   139  				listId:                        firewallconfigs.GlobalListId,
   140  				expiresAt:                     expiresAt,
   141  				level:                         firewallconfigs.DefaultEventLevel,
   142  				serverId:                      scopeServerId,
   143  				sourceServerId:                serverId,
   144  				sourceHTTPFirewallPolicyId:    policyId,
   145  				sourceHTTPFirewallRuleGroupId: groupId,
   146  				sourceHTTPFirewallRuleSetId:   setId,
   147  				reason:                        reason,
   148  			}:
   149  				this.lastIP = ip
   150  				this.lastTime = fasttime.Now().Unix()
   151  			default:
   152  			}
   153  
   154  			// 使用本地防火墙
   155  			if useLocalFirewall {
   156  				firewalls.DropTemporaryTo(ip, expiresAt)
   157  			}
   158  		}
   159  
   160  		// 关闭此IP相关连接
   161  		conns.SharedMap.CloseIPConns(ip)
   162  	}
   163  }
   164  
   165  // Contains 判断是否有某个IP
   166  func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) bool {
   167  	switch scope {
   168  	case firewallconfigs.FirewallScopeGlobal:
   169  		ip = "*@" + ip + "@" + ipType
   170  	case firewallconfigs.FirewallScopeServer:
   171  		ip = types.String(serverId) + "@" + ip + "@" + ipType
   172  	default:
   173  		ip = "*@" + ip + "@" + ipType
   174  	}
   175  
   176  	this.locker.RLock()
   177  	_, ok := this.ipMap[ip]
   178  	this.locker.RUnlock()
   179  	return ok
   180  }
   181  
   182  // ContainsExpires 判断是否有某个IP,并返回过期时间
   183  func (this *IPList) ContainsExpires(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) (expiresAt int64, ok bool) {
   184  	switch scope {
   185  	case firewallconfigs.FirewallScopeGlobal:
   186  		ip = "*@" + ip + "@" + ipType
   187  	case firewallconfigs.FirewallScopeServer:
   188  		ip = types.String(serverId) + "@" + ip + "@" + ipType
   189  	default:
   190  		ip = "*@" + ip + "@" + ipType
   191  	}
   192  
   193  	this.locker.RLock()
   194  	id, ok := this.ipMap[ip]
   195  	if ok {
   196  		expiresAt = this.expireList.ExpiresAt(id)
   197  	}
   198  	this.locker.RUnlock()
   199  	return expiresAt, ok
   200  }
   201  
   202  // RemoveIP 删除IP
   203  func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) {
   204  	this.locker.Lock()
   205  
   206  	{
   207  		var key = "*@" + ip + "@" + IPTypeAll
   208  		id, ok := this.ipMap[key]
   209  		if ok {
   210  			delete(this.ipMap, key)
   211  			delete(this.idMap, id)
   212  
   213  			this.expireList.Remove(id)
   214  		}
   215  	}
   216  
   217  	if serverId > 0 {
   218  		var key = types.String(serverId) + "@" + ip + "@" + IPTypeAll
   219  		id, ok := this.ipMap[key]
   220  		if ok {
   221  			delete(this.ipMap, key)
   222  			delete(this.idMap, id)
   223  
   224  			this.expireList.Remove(id)
   225  		}
   226  	}
   227  
   228  	this.locker.Unlock()
   229  
   230  	// 从本地防火墙中删除
   231  	if shouldExecute {
   232  		_ = firewalls.Firewall().RemoveSourceIP(ip)
   233  	}
   234  }
   235  
   236  // Save to local file
   237  func (this *IPList) Save(path string) error {
   238  	var itemMaps = []maps.Map{} // [ {ip info, expiresAt }, ... ]
   239  	this.locker.Lock()
   240  	defer this.locker.Unlock()
   241  
   242  	// prevent too many items
   243  	if len(this.ipMap) > 100_000 {
   244  		return nil
   245  	}
   246  
   247  	for ipInfo, id := range this.ipMap {
   248  		var expiresAt = this.expireList.ExpiresAt(id)
   249  		if expiresAt <= 0 {
   250  			continue
   251  		}
   252  		itemMaps = append(itemMaps, maps.Map{
   253  			"ip":        ipInfo,
   254  			"expiresAt": expiresAt,
   255  		})
   256  	}
   257  
   258  	itemMapsJSON, err := json.Marshal(itemMaps)
   259  	if err != nil {
   260  		return err
   261  	}
   262  	return os.WriteFile(path, itemMapsJSON, 0666)
   263  }
   264  
   265  // Load from local file
   266  func (this *IPList) Load(path string) error {
   267  	data, err := os.ReadFile(path)
   268  	if err != nil {
   269  		return err
   270  	}
   271  	if len(data) == 0 {
   272  		return nil
   273  	}
   274  
   275  	var itemMaps = []maps.Map{}
   276  	err = json.Unmarshal(data, &itemMaps)
   277  	if err != nil {
   278  		return err
   279  	}
   280  
   281  	this.locker.Lock()
   282  	defer this.locker.Unlock()
   283  
   284  	for _, itemMap := range itemMaps {
   285  		var ip = itemMap.GetString("ip")
   286  		var expiresAt = itemMap.GetInt64("expiresAt")
   287  		if len(ip) == 0 || expiresAt < fasttime.Now().Unix()+10 /** seconds **/ {
   288  			continue
   289  		}
   290  
   291  		var id = this.nextId()
   292  		this.expireList.Add(id, expiresAt)
   293  
   294  		this.ipMap[ip] = id
   295  		this.idMap[id] = ip
   296  	}
   297  
   298  	return nil
   299  }
   300  
   301  // IPMap get ipMap
   302  func (this *IPList) IPMap() map[string]uint64 {
   303  	this.locker.RLock()
   304  	defer this.locker.RUnlock()
   305  	return this.ipMap
   306  }
   307  
   308  // IdMap get idMap
   309  func (this *IPList) IdMap() map[uint64]string {
   310  	this.locker.RLock()
   311  	defer this.locker.RUnlock()
   312  	return this.idMap
   313  }
   314  
   315  func (this *IPList) remove(id uint64) {
   316  	this.locker.Lock()
   317  	ip, ok := this.idMap[id]
   318  	if ok {
   319  		ipId, ok := this.ipMap[ip]
   320  		if ok && ipId == id {
   321  			delete(this.ipMap, ip)
   322  		}
   323  		delete(this.idMap, id)
   324  	}
   325  	this.locker.Unlock()
   326  }
   327  
   328  func (this *IPList) nextId() uint64 {
   329  	return atomic.AddUint64(&this.id, 1)
   330  }