github.com/TeaOSLab/EdgeNode@v1.3.8/internal/firewalls/ddos_protection.go (about) 1 // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. 2 //go:build linux 3 4 package firewalls 5 6 import ( 7 "bytes" 8 "encoding/json" 9 "errors" 10 "fmt" 11 "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" 12 "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs" 13 teaconst "github.com/TeaOSLab/EdgeNode/internal/const" 14 "github.com/TeaOSLab/EdgeNode/internal/events" 15 "github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables" 16 "github.com/TeaOSLab/EdgeNode/internal/remotelogs" 17 "github.com/TeaOSLab/EdgeNode/internal/utils" 18 executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec" 19 "github.com/TeaOSLab/EdgeNode/internal/zero" 20 "github.com/iwind/TeaGo/lists" 21 "github.com/iwind/TeaGo/types" 22 stringutil "github.com/iwind/TeaGo/utils/string" 23 "net" 24 "strings" 25 "sync" 26 "time" 27 ) 28 29 var SharedDDoSProtectionManager = NewDDoSProtectionManager() 30 31 func init() { 32 if !teaconst.IsMain { 33 return 34 } 35 36 events.On(events.EventReload, func() { 37 if nftablesInstance == nil { 38 return 39 } 40 41 nodeConfig, _ := nodeconfigs.SharedNodeConfig() 42 if nodeConfig != nil { 43 err := SharedDDoSProtectionManager.Apply(nodeConfig.DDoSProtection) 44 if err != nil { 45 remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error()) 46 } 47 } 48 }) 49 50 events.On(events.EventNFTablesReady, func() { 51 nodeConfig, _ := nodeconfigs.SharedNodeConfig() 52 if nodeConfig != nil { 53 err := SharedDDoSProtectionManager.Apply(nodeConfig.DDoSProtection) 54 if err != nil { 55 remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error()) 56 } 57 } 58 }) 59 } 60 61 // DDoSProtectionManager DDoS防护 62 type DDoSProtectionManager struct { 63 lastAllowIPList []string 64 lastConfig []byte 65 66 locker sync.Mutex 67 } 68 69 // NewDDoSProtectionManager 获取新对象 70 func NewDDoSProtectionManager() *DDoSProtectionManager { 71 return &DDoSProtectionManager{} 72 } 73 74 // Apply 应用配置 75 func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error { 76 // 加锁防止并发更改 77 if !this.locker.TryLock() { 78 return nil 79 } 80 defer this.locker.Unlock() 81 82 // 同集群节点IP白名单 83 var allowIPListChanged = false 84 nodeConfig, _ := nodeconfigs.SharedNodeConfig() 85 if nodeConfig != nil { 86 var allowIPList = nodeConfig.AllowedIPs 87 if !utils.EqualStrings(allowIPList, this.lastAllowIPList) { 88 allowIPListChanged = true 89 this.lastAllowIPList = allowIPList 90 } 91 } 92 93 // 对比配置 94 configJSON, err := json.Marshal(config) 95 if err != nil { 96 return fmt.Errorf("encode config to json failed: %w", err) 97 } 98 if !allowIPListChanged && bytes.Equal(this.lastConfig, configJSON) { 99 return nil 100 } 101 remotelogs.Println("FIREWALL", "change DDoS protection config") 102 103 if len(nftables.NftExePath()) == 0 { 104 return errors.New("can not find nft command") 105 } 106 107 if nftablesInstance == nil { 108 if config == nil || !config.IsOn() { 109 return nil 110 } 111 return errors.New("nftables instance should not be nil") 112 } 113 114 if config == nil { 115 // TCP 116 err := this.removeTCPRules() 117 if err != nil { 118 return err 119 } 120 121 // TODO other protocols 122 123 return nil 124 } 125 126 // TCP 127 if config.TCP == nil { 128 err := this.removeTCPRules() 129 if err != nil { 130 return err 131 } 132 } else { 133 // allow ip list 134 var allowIPList = []string{} 135 for _, ipConfig := range config.TCP.AllowIPList { 136 allowIPList = append(allowIPList, ipConfig.IP) 137 } 138 for _, ip := range this.lastAllowIPList { 139 if !lists.ContainsString(allowIPList, ip) { 140 allowIPList = append(allowIPList, ip) 141 } 142 } 143 err = this.updateAllowIPList(allowIPList) 144 if err != nil { 145 return err 146 } 147 148 // tcp 149 if config.TCP.IsOn { 150 err := this.addTCPRules(config.TCP) 151 if err != nil { 152 return err 153 } 154 } else { 155 err := this.removeTCPRules() 156 if err != nil { 157 return err 158 } 159 } 160 } 161 162 this.lastConfig = configJSON 163 164 return nil 165 } 166 167 // 添加TCP规则 168 func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error { 169 var nftExe = nftables.NftExePath() 170 if len(nftExe) == 0 { 171 return nil 172 } 173 174 // 检查nft版本不能小于0.9 175 if len(nftablesInstance.version) > 0 && stringutil.VersionCompare("0.9", nftablesInstance.version) > 0 { 176 return nil 177 } 178 179 var ports = []int32{} 180 for _, portConfig := range tcpConfig.Ports { 181 if !lists.ContainsInt32(ports, portConfig.Port) { 182 ports = append(ports, portConfig.Port) 183 } 184 } 185 if len(ports) == 0 { 186 ports = []int32{80, 443} 187 } 188 189 for _, filter := range nftablesFilters { 190 chain, oldRules, err := this.getRules(filter) 191 if err != nil { 192 return fmt.Errorf("get old rules failed: %w", err) 193 } 194 195 var protocol = filter.protocol() 196 197 // max connections 198 var maxConnections = tcpConfig.MaxConnections 199 if maxConnections <= 0 { 200 maxConnections = nodeconfigs.DefaultTCPMaxConnections 201 if maxConnections <= 0 { 202 maxConnections = 100000 203 } 204 } 205 206 // max connections per ip 207 var maxConnectionsPerIP = tcpConfig.MaxConnectionsPerIP 208 if maxConnectionsPerIP <= 0 { 209 maxConnectionsPerIP = nodeconfigs.DefaultTCPMaxConnectionsPerIP 210 if maxConnectionsPerIP <= 0 { 211 maxConnectionsPerIP = 100000 212 } 213 } 214 215 // new connections rate (minutely) 216 var newConnectionsMinutelyRate = tcpConfig.NewConnectionsMinutelyRate 217 if newConnectionsMinutelyRate <= 0 { 218 newConnectionsMinutelyRate = nodeconfigs.DefaultTCPNewConnectionsMinutelyRate 219 if newConnectionsMinutelyRate <= 0 { 220 newConnectionsMinutelyRate = 100000 221 } 222 } 223 var newConnectionsMinutelyRateBlockTimeout = tcpConfig.NewConnectionsMinutelyRateBlockTimeout 224 if newConnectionsMinutelyRateBlockTimeout < 0 { 225 newConnectionsMinutelyRateBlockTimeout = 0 226 } 227 228 // new connections rate (secondly) 229 var newConnectionsSecondlyRate = tcpConfig.NewConnectionsSecondlyRate 230 if newConnectionsSecondlyRate <= 0 { 231 newConnectionsSecondlyRate = nodeconfigs.DefaultTCPNewConnectionsSecondlyRate 232 if newConnectionsSecondlyRate <= 0 { 233 newConnectionsSecondlyRate = 10000 234 } 235 } 236 var newConnectionsSecondlyRateBlockTimeout = tcpConfig.NewConnectionsSecondlyRateBlockTimeout 237 if newConnectionsSecondlyRateBlockTimeout < 0 { 238 newConnectionsSecondlyRateBlockTimeout = 0 239 } 240 241 // 检查是否有变化 242 var hasChanges = false 243 for _, port := range ports { 244 if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}) { 245 hasChanges = true 246 break 247 } 248 if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}) { 249 hasChanges = true 250 break 251 } 252 if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)}) { 253 hasChanges = true 254 break 255 } 256 if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)}) { 257 hasChanges = true 258 break 259 } 260 } 261 262 if !hasChanges { 263 // 检查是否有多余的端口 264 var oldPorts = this.getTCPPorts(oldRules) 265 if !this.eqPorts(ports, oldPorts) { 266 hasChanges = true 267 } 268 } 269 270 if !hasChanges { 271 return nil 272 } 273 274 // 先清空所有相关规则 275 err = this.removeOldTCPRules(chain, oldRules) 276 if err != nil { 277 return fmt.Errorf("delete old rules failed: %w", err) 278 } 279 280 // 添加新规则 281 for _, port := range ports { 282 if maxConnections > 0 { 283 var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)})) 284 cmd.WithStderr() 285 err = cmd.Run() 286 if err != nil { 287 return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, cmd.Stderr()) 288 } 289 } 290 291 // TODO 让用户选择是drop还是reject 292 if maxConnectionsPerIP > 0 { 293 var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)})) 294 cmd.WithStderr() 295 err := cmd.Run() 296 if err != nil { 297 return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, cmd.Stderr()) 298 } 299 } 300 301 // 超过一定速率就drop或者加入黑名单(分钟) 302 // TODO 让用户选择是drop还是reject 303 if newConnectionsMinutelyRate > 0 { 304 if newConnectionsMinutelyRateBlockTimeout > 0 { 305 var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsMinutelyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)})) 306 cmd.WithStderr() 307 err := cmd.Run() 308 if err != nil { 309 return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, cmd.Stderr()) 310 } 311 } else { 312 var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", "0"})) 313 cmd.WithStderr() 314 err := cmd.Run() 315 if err != nil { 316 return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, cmd.Stderr()) 317 } 318 } 319 } 320 321 // 超过一定速率就drop或者加入黑名单(秒) 322 // TODO 让用户选择是drop还是reject 323 if newConnectionsSecondlyRate > 0 { 324 if newConnectionsSecondlyRateBlockTimeout > 0 { 325 var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsSecondlyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)})) 326 cmd.WithStderr() 327 err := cmd.Run() 328 if err != nil { 329 return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, cmd.Stderr()) 330 } 331 } else { 332 var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", "0"})) 333 cmd.WithStderr() 334 err := cmd.Run() 335 if err != nil { 336 return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, cmd.Stderr()) 337 } 338 } 339 } 340 } 341 } 342 343 return nil 344 } 345 346 // 删除TCP规则 347 func (this *DDoSProtectionManager) removeTCPRules() error { 348 for _, filter := range nftablesFilters { 349 chain, rules, err := this.getRules(filter) 350 351 // TCP 352 err = this.removeOldTCPRules(chain, rules) 353 if err != nil { 354 return err 355 } 356 } 357 358 return nil 359 } 360 361 // 组合user data 362 // 数据中不能包含字母、数字、下划线以外的数据 363 func (this *DDoSProtectionManager) encodeUserData(attrs []string) string { 364 if attrs == nil { 365 return "" 366 } 367 368 return "ZZ" + strings.Join(attrs, "_") + "ZZ" 369 } 370 371 // 解码user data 372 func (this *DDoSProtectionManager) decodeUserData(data []byte) []string { 373 if len(data) == 0 { 374 return nil 375 } 376 377 var dataCopy = make([]byte, len(data)) 378 copy(dataCopy, data) 379 380 var separatorLen = 2 381 var index1 = bytes.Index(dataCopy, []byte{'Z', 'Z'}) 382 if index1 < 0 { 383 return nil 384 } 385 386 dataCopy = dataCopy[index1+separatorLen:] 387 var index2 = bytes.LastIndex(dataCopy, []byte{'Z', 'Z'}) 388 if index2 < 0 { 389 return nil 390 } 391 392 var s = string(dataCopy[:index2]) 393 var pieces = strings.Split(s, "_") 394 for index, piece := range pieces { 395 pieces[index] = strings.TrimSpace(piece) 396 } 397 return pieces 398 } 399 400 // 清除规则 401 func (this *DDoSProtectionManager) removeOldTCPRules(chain *nftables.Chain, rules []*nftables.Rule) error { 402 for _, rule := range rules { 403 var pieces = this.decodeUserData(rule.UserData()) 404 if len(pieces) < 4 { 405 continue 406 } 407 if pieces[0] != "tcp" { 408 continue 409 } 410 switch pieces[2] { 411 case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate", "newConnectionsSecondlyRate": 412 err := chain.DeleteRule(rule) 413 if err != nil { 414 return err 415 } 416 } 417 } 418 419 return nil 420 } 421 422 // 根据参数检查规则是否存在 423 func (this *DDoSProtectionManager) existsRule(rules []*nftables.Rule, attrs []string) (exists bool) { 424 if len(attrs) == 0 { 425 return false 426 } 427 for _, oldRule := range rules { 428 var pieces = this.decodeUserData(oldRule.UserData()) 429 if len(attrs) != len(pieces) { 430 continue 431 } 432 var isSame = true 433 for index, piece := range pieces { 434 if strings.TrimSpace(piece) != attrs[index] { 435 isSame = false 436 break 437 } 438 } 439 if isSame { 440 return true 441 } 442 } 443 return false 444 } 445 446 // 获取规则中的端口号 447 func (this *DDoSProtectionManager) getTCPPorts(rules []*nftables.Rule) []int32 { 448 var ports = []int32{} 449 for _, rule := range rules { 450 var pieces = this.decodeUserData(rule.UserData()) 451 if len(pieces) != 4 { 452 continue 453 } 454 if pieces[0] != "tcp" { 455 continue 456 } 457 var port = types.Int32(pieces[1]) 458 if port > 0 && !lists.ContainsInt32(ports, port) { 459 ports = append(ports, port) 460 } 461 } 462 return ports 463 } 464 465 // 检查端口是否一样 466 func (this *DDoSProtectionManager) eqPorts(ports1 []int32, ports2 []int32) bool { 467 if len(ports1) != len(ports2) { 468 return false 469 } 470 471 var portMap = map[int32]bool{} 472 for _, port := range ports2 { 473 portMap[port] = true 474 } 475 476 for _, port := range ports1 { 477 _, ok := portMap[port] 478 if !ok { 479 return false 480 } 481 } 482 return true 483 } 484 485 // 查找Table 486 func (this *DDoSProtectionManager) getTable(filter *nftablesTableDefinition) (*nftables.Table, error) { 487 var family nftables.TableFamily 488 if filter.IsIPv4 { 489 family = nftables.TableFamilyIPv4 490 } else if filter.IsIPv6 { 491 family = nftables.TableFamilyIPv6 492 } else { 493 return nil, errors.New("table '" + filter.Name + "' should be IPv4 or IPv6") 494 } 495 return nftablesInstance.conn.GetTable(filter.Name, family) 496 } 497 498 // 查找所有规则 499 func (this *DDoSProtectionManager) getRules(filter *nftablesTableDefinition) (*nftables.Chain, []*nftables.Rule, error) { 500 table, err := this.getTable(filter) 501 if err != nil { 502 return nil, nil, fmt.Errorf("get table failed: %w", err) 503 } 504 chain, err := table.GetChain(nftablesChainName) 505 if err != nil { 506 return nil, nil, fmt.Errorf("get chain failed: %w", err) 507 } 508 rules, err := chain.GetRules() 509 return chain, rules, err 510 } 511 512 // 更新白名单 513 func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error { 514 if nftablesInstance == nil { 515 return nil 516 } 517 518 var allMap = map[string]zero.Zero{} 519 for _, ip := range allIPList { 520 allMap[ip] = zero.New() 521 } 522 523 for _, set := range []*nftables.Set{nftablesInstance.allowIPv4Set, nftablesInstance.allowIPv6Set} { 524 var isIPv4 = set == nftablesInstance.allowIPv4Set 525 var isIPv6 = !isIPv4 526 527 // 现有的 528 oldList, err := set.GetIPElements() 529 if err != nil { 530 return err 531 } 532 var oldMap = map[string]zero.Zero{} // ip=> zero 533 for _, ip := range oldList { 534 oldMap[ip] = zero.New() 535 536 if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) { 537 _, ok := allMap[ip] 538 if !ok { 539 // 不存在则删除 540 err = set.DeleteIPElement(ip) 541 if err != nil { 542 return fmt.Errorf("delete ip element '%s' failed: %w", ip, err) 543 } 544 } 545 } 546 } 547 548 // 新增的 549 for _, ip := range allIPList { 550 var ipObj = net.ParseIP(ip) 551 if ipObj == nil { 552 continue 553 } 554 if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) { 555 _, ok := oldMap[ip] 556 if !ok { 557 // 不存在则添加 558 err = set.AddIPElement(ip, nil, false) 559 if err != nil { 560 return fmt.Errorf("add ip '%s' failed: %w", ip, err) 561 } 562 } 563 } 564 } 565 } 566 567 return nil 568 }