github.com/TeaOSLab/EdgeNode@v1.3.8/internal/waf/ip_list.go (about) 1 // Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. 2 3 package waf 4 5 import ( 6 "encoding/json" 7 "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" 8 "github.com/TeaOSLab/EdgeNode/internal/conns" 9 teaconst "github.com/TeaOSLab/EdgeNode/internal/const" 10 "github.com/TeaOSLab/EdgeNode/internal/events" 11 "github.com/TeaOSLab/EdgeNode/internal/firewalls" 12 "github.com/TeaOSLab/EdgeNode/internal/utils/expires" 13 "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" 14 "github.com/iwind/TeaGo/Tea" 15 "github.com/iwind/TeaGo/maps" 16 "github.com/iwind/TeaGo/types" 17 "os" 18 "sync" 19 "sync/atomic" 20 ) 21 22 var SharedIPWhiteList = NewIPList(IPListTypeAllow) 23 var SharedIPBlackList = NewIPList(IPListTypeDeny) 24 25 type IPListType = string 26 27 const ( 28 IPListTypeAllow IPListType = "allow" 29 IPListTypeDeny IPListType = "deny" 30 ) 31 32 const IPTypeAll = "*" 33 34 func init() { 35 if !teaconst.IsMain { 36 return 37 } 38 39 var cacheFile = Tea.Root + "/data/waf_white_list.cache" 40 41 // save 42 events.On(events.EventTerminated, func() { 43 _ = SharedIPWhiteList.Save(cacheFile) 44 }) 45 46 // load 47 go func() { 48 if !Tea.IsTesting() { 49 _ = SharedIPWhiteList.Load(cacheFile) 50 _ = os.Remove(cacheFile) 51 } 52 }() 53 } 54 55 // IPList IP列表管理 56 type IPList struct { 57 expireList *expires.List 58 ipMap map[string]uint64 // ip info => id 59 idMap map[uint64]string // id => ip info 60 listType IPListType 61 62 id uint64 63 locker sync.RWMutex 64 65 lastIP string // 加入到 recordIPTaskChan 之前尽可能去重 66 lastTime int64 67 } 68 69 // NewIPList 获取新对象 70 func NewIPList(listType IPListType) *IPList { 71 var list = &IPList{ 72 ipMap: map[string]uint64{}, 73 idMap: map[uint64]string{}, 74 listType: listType, 75 } 76 77 var e = expires.NewList() 78 list.expireList = e 79 80 e.OnGC(func(itemId uint64) { 81 list.remove(itemId) // TODO 使用异步,防止阻塞GC 82 }) 83 84 return list 85 } 86 87 // Add 添加IP 88 func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64) { 89 switch scope { 90 case firewallconfigs.FirewallScopeGlobal: 91 ip = "*@" + ip + "@" + ipType 92 case firewallconfigs.FirewallScopeServer: 93 ip = types.String(serverId) + "@" + ip + "@" + ipType 94 default: 95 ip = "*@" + ip + "@" + ipType 96 } 97 98 var id = this.nextId() 99 this.expireList.Add(id, expiresAt) 100 this.locker.Lock() 101 102 // 删除以前 103 oldId, ok := this.ipMap[ip] 104 if ok { 105 delete(this.idMap, oldId) 106 this.expireList.Remove(oldId) 107 } 108 109 this.ipMap[ip] = id 110 this.idMap[id] = ip 111 this.locker.Unlock() 112 } 113 114 // RecordIP 记录IP 115 func (this *IPList) RecordIP(ipType string, 116 scope firewallconfigs.FirewallScope, 117 serverId int64, 118 ip string, 119 expiresAt int64, 120 policyId int64, 121 useLocalFirewall bool, 122 groupId int64, 123 setId int64, 124 reason string) { 125 this.Add(ipType, scope, serverId, ip, expiresAt) 126 127 if this.listType == IPListTypeDeny { 128 // 作用域 129 var scopeServerId int64 130 if scope == firewallconfigs.FirewallScopeServer { 131 scopeServerId = serverId 132 } 133 134 // 加入队列等待上传 135 if this.lastIP != ip || fasttime.Now().Unix()-this.lastTime > 3 /** 3秒外才允许重复添加 **/ { 136 select { 137 case recordIPTaskChan <- &recordIPTask{ 138 ip: ip, 139 listId: firewallconfigs.GlobalListId, 140 expiresAt: expiresAt, 141 level: firewallconfigs.DefaultEventLevel, 142 serverId: scopeServerId, 143 sourceServerId: serverId, 144 sourceHTTPFirewallPolicyId: policyId, 145 sourceHTTPFirewallRuleGroupId: groupId, 146 sourceHTTPFirewallRuleSetId: setId, 147 reason: reason, 148 }: 149 this.lastIP = ip 150 this.lastTime = fasttime.Now().Unix() 151 default: 152 } 153 154 // 使用本地防火墙 155 if useLocalFirewall { 156 firewalls.DropTemporaryTo(ip, expiresAt) 157 } 158 } 159 160 // 关闭此IP相关连接 161 conns.SharedMap.CloseIPConns(ip) 162 } 163 } 164 165 // Contains 判断是否有某个IP 166 func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) bool { 167 switch scope { 168 case firewallconfigs.FirewallScopeGlobal: 169 ip = "*@" + ip + "@" + ipType 170 case firewallconfigs.FirewallScopeServer: 171 ip = types.String(serverId) + "@" + ip + "@" + ipType 172 default: 173 ip = "*@" + ip + "@" + ipType 174 } 175 176 this.locker.RLock() 177 _, ok := this.ipMap[ip] 178 this.locker.RUnlock() 179 return ok 180 } 181 182 // ContainsExpires 判断是否有某个IP,并返回过期时间 183 func (this *IPList) ContainsExpires(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) (expiresAt int64, ok bool) { 184 switch scope { 185 case firewallconfigs.FirewallScopeGlobal: 186 ip = "*@" + ip + "@" + ipType 187 case firewallconfigs.FirewallScopeServer: 188 ip = types.String(serverId) + "@" + ip + "@" + ipType 189 default: 190 ip = "*@" + ip + "@" + ipType 191 } 192 193 this.locker.RLock() 194 id, ok := this.ipMap[ip] 195 if ok { 196 expiresAt = this.expireList.ExpiresAt(id) 197 } 198 this.locker.RUnlock() 199 return expiresAt, ok 200 } 201 202 // RemoveIP 删除IP 203 func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) { 204 this.locker.Lock() 205 206 { 207 var key = "*@" + ip + "@" + IPTypeAll 208 id, ok := this.ipMap[key] 209 if ok { 210 delete(this.ipMap, key) 211 delete(this.idMap, id) 212 213 this.expireList.Remove(id) 214 } 215 } 216 217 if serverId > 0 { 218 var key = types.String(serverId) + "@" + ip + "@" + IPTypeAll 219 id, ok := this.ipMap[key] 220 if ok { 221 delete(this.ipMap, key) 222 delete(this.idMap, id) 223 224 this.expireList.Remove(id) 225 } 226 } 227 228 this.locker.Unlock() 229 230 // 从本地防火墙中删除 231 if shouldExecute { 232 _ = firewalls.Firewall().RemoveSourceIP(ip) 233 } 234 } 235 236 // Save to local file 237 func (this *IPList) Save(path string) error { 238 var itemMaps = []maps.Map{} // [ {ip info, expiresAt }, ... ] 239 this.locker.Lock() 240 defer this.locker.Unlock() 241 242 // prevent too many items 243 if len(this.ipMap) > 100_000 { 244 return nil 245 } 246 247 for ipInfo, id := range this.ipMap { 248 var expiresAt = this.expireList.ExpiresAt(id) 249 if expiresAt <= 0 { 250 continue 251 } 252 itemMaps = append(itemMaps, maps.Map{ 253 "ip": ipInfo, 254 "expiresAt": expiresAt, 255 }) 256 } 257 258 itemMapsJSON, err := json.Marshal(itemMaps) 259 if err != nil { 260 return err 261 } 262 return os.WriteFile(path, itemMapsJSON, 0666) 263 } 264 265 // Load from local file 266 func (this *IPList) Load(path string) error { 267 data, err := os.ReadFile(path) 268 if err != nil { 269 return err 270 } 271 if len(data) == 0 { 272 return nil 273 } 274 275 var itemMaps = []maps.Map{} 276 err = json.Unmarshal(data, &itemMaps) 277 if err != nil { 278 return err 279 } 280 281 this.locker.Lock() 282 defer this.locker.Unlock() 283 284 for _, itemMap := range itemMaps { 285 var ip = itemMap.GetString("ip") 286 var expiresAt = itemMap.GetInt64("expiresAt") 287 if len(ip) == 0 || expiresAt < fasttime.Now().Unix()+10 /** seconds **/ { 288 continue 289 } 290 291 var id = this.nextId() 292 this.expireList.Add(id, expiresAt) 293 294 this.ipMap[ip] = id 295 this.idMap[id] = ip 296 } 297 298 return nil 299 } 300 301 // IPMap get ipMap 302 func (this *IPList) IPMap() map[string]uint64 { 303 this.locker.RLock() 304 defer this.locker.RUnlock() 305 return this.ipMap 306 } 307 308 // IdMap get idMap 309 func (this *IPList) IdMap() map[uint64]string { 310 this.locker.RLock() 311 defer this.locker.RUnlock() 312 return this.idMap 313 } 314 315 func (this *IPList) remove(id uint64) { 316 this.locker.Lock() 317 ip, ok := this.idMap[id] 318 if ok { 319 ipId, ok := this.ipMap[ip] 320 if ok && ipId == id { 321 delete(this.ipMap, ip) 322 } 323 delete(this.idMap, id) 324 } 325 this.locker.Unlock() 326 } 327 328 func (this *IPList) nextId() uint64 { 329 return atomic.AddUint64(&this.id, 1) 330 }