github.com/TeaOSLab/EdgeNode@v1.3.8/internal/firewalls/firewall_nftables.go (about) 1 // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. 2 //go:build linux 3 4 package firewalls 5 6 import ( 7 "errors" 8 "fmt" 9 "github.com/TeaOSLab/EdgeCommon/pkg/iputils" 10 "github.com/TeaOSLab/EdgeNode/internal/conns" 11 teaconst "github.com/TeaOSLab/EdgeNode/internal/const" 12 "github.com/TeaOSLab/EdgeNode/internal/events" 13 "github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables" 14 "github.com/TeaOSLab/EdgeNode/internal/goman" 15 "github.com/TeaOSLab/EdgeNode/internal/remotelogs" 16 executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec" 17 "github.com/google/nftables/expr" 18 "github.com/iwind/TeaGo/types" 19 "net" 20 "regexp" 21 "runtime" 22 "strings" 23 "time" 24 ) 25 26 // check nft status, if being enabled we load it automatically 27 func init() { 28 if !teaconst.IsMain { 29 return 30 } 31 32 if runtime.GOOS == "linux" { 33 var ticker = time.NewTicker(3 * time.Minute) 34 goman.New(func() { 35 for range ticker.C { 36 // if already ready, we break 37 if nftablesIsReady { 38 ticker.Stop() 39 break 40 } 41 var nftExe = nftables.NftExePath() 42 if len(nftExe) > 0 { 43 nftablesFirewall, err := NewNFTablesFirewall() 44 if err != nil { 45 continue 46 } 47 currentFirewall = nftablesFirewall 48 remotelogs.Println("FIREWALL", "nftables is ready") 49 50 // fire event 51 if nftablesFirewall.IsReady() { 52 events.Notify(events.EventNFTablesReady) 53 } 54 55 ticker.Stop() 56 break 57 } 58 } 59 }) 60 } 61 } 62 63 var nftablesInstance *NFTablesFirewall 64 var nftablesIsReady = false 65 var nftablesFilters = []*nftablesTableDefinition{ 66 // we shorten the name for table name length restriction 67 {Name: "edge_dft_v4", IsIPv4: true}, 68 {Name: "edge_dft_v6", IsIPv6: true}, 69 } 70 var nftablesChainName = "input" 71 72 type nftablesTableDefinition struct { 73 Name string 74 IsIPv4 bool 75 IsIPv6 bool 76 } 77 78 func (this *nftablesTableDefinition) protocol() string { 79 if this.IsIPv6 { 80 return "ip6" 81 } 82 return "ip" 83 } 84 85 type blockIPItem struct { 86 action string 87 ip string 88 timeoutSeconds int 89 } 90 91 func NewNFTablesFirewall() (*NFTablesFirewall, error) { 92 conn, err := nftables.NewConn() 93 if err != nil { 94 return nil, err 95 } 96 var firewall = &NFTablesFirewall{ 97 conn: conn, 98 dropIPQueue: make(chan *blockIPItem, 4096), 99 } 100 err = firewall.init() 101 if err != nil { 102 return nil, err 103 } 104 105 return firewall, nil 106 } 107 108 type NFTablesFirewall struct { 109 BaseFirewall 110 111 conn *nftables.Conn 112 isReady bool 113 version string 114 115 allowIPv4Set *nftables.Set 116 allowIPv6Set *nftables.Set 117 118 denyIPv4Sets []*nftables.Set 119 denyIPv6Sets []*nftables.Set 120 121 firewalld *Firewalld 122 123 dropIPQueue chan *blockIPItem 124 } 125 126 func (this *NFTablesFirewall) init() error { 127 // check nft 128 var nftPath = nftables.NftExePath() 129 if len(nftPath) == 0 { 130 return errors.New("'nft' not found") 131 } 132 this.version = this.readVersion(nftPath) 133 134 // table 135 for _, tableDef := range nftablesFilters { 136 var family nftables.TableFamily 137 if tableDef.IsIPv4 { 138 family = nftables.TableFamilyIPv4 139 } else if tableDef.IsIPv6 { 140 family = nftables.TableFamilyIPv6 141 } else { 142 return errors.New("invalid table family: " + types.String(tableDef)) 143 } 144 table, err := this.conn.GetTable(tableDef.Name, family) 145 if err != nil { 146 if nftables.IsNotFound(err) { 147 if tableDef.IsIPv4 { 148 table, err = this.conn.AddIPv4Table(tableDef.Name) 149 } else if tableDef.IsIPv6 { 150 table, err = this.conn.AddIPv6Table(tableDef.Name) 151 } 152 if err != nil { 153 return fmt.Errorf("create table '%s' failed: %w", tableDef.Name, err) 154 } 155 } else { 156 return fmt.Errorf("get table '%s' failed: %w", tableDef.Name, err) 157 } 158 } 159 if table == nil { 160 return errors.New("can not create table '" + tableDef.Name + "'") 161 } 162 163 // chain 164 var chainName = nftablesChainName 165 chain, err := table.GetChain(chainName) 166 if err != nil { 167 if nftables.IsNotFound(err) { 168 chain, err = table.AddAcceptChain(chainName) 169 if err != nil { 170 return fmt.Errorf("create chain '%s' failed: %w", chainName, err) 171 } 172 } else { 173 return fmt.Errorf("get chain '%s' failed: %w", chainName, err) 174 } 175 } 176 if chain == nil { 177 return errors.New("can not create chain '" + chainName + "'") 178 } 179 180 // allow lo 181 var loRuleName = []byte("lo") 182 _, err = chain.GetRuleWithUserData(loRuleName) 183 if err != nil { 184 if nftables.IsNotFound(err) { 185 _, err = chain.AddAcceptInterfaceRule("lo", loRuleName) 186 } 187 if err != nil { 188 return fmt.Errorf("add 'lo' rule failed: %w", err) 189 } 190 } 191 192 // allow set 193 // "allow" should be always first 194 for _, setAction := range []string{"allow", "deny", "deny1", "deny2", "deny3", "deny4"} { 195 var setName = setAction + "_set" 196 197 set, err := table.GetSet(setName) 198 if err != nil { 199 if nftables.IsNotFound(err) { 200 var keyType nftables.SetDataType 201 if tableDef.IsIPv4 { 202 keyType = nftables.TypeIPAddr 203 } else if tableDef.IsIPv6 { 204 keyType = nftables.TypeIP6Addr 205 } 206 set, err = table.AddSet(setName, &nftables.SetOptions{ 207 KeyType: keyType, 208 HasTimeout: true, 209 }) 210 if err != nil { 211 return fmt.Errorf("create set '%s' failed: %w", setName, err) 212 } 213 } else { 214 return fmt.Errorf("get set '%s' failed: %w", setName, err) 215 } 216 } 217 if set == nil { 218 return errors.New("can not create set '" + setName + "'") 219 } 220 if tableDef.IsIPv4 { 221 if setAction == "allow" { 222 this.allowIPv4Set = set 223 } else { 224 this.denyIPv4Sets = append(this.denyIPv4Sets, set) 225 } 226 } else if tableDef.IsIPv6 { 227 if setAction == "allow" { 228 this.allowIPv6Set = set 229 } else { 230 this.denyIPv6Sets = append(this.denyIPv6Sets, set) 231 } 232 } 233 234 // rule 235 var ruleName = []byte(setAction) 236 rule, err := chain.GetRuleWithUserData(ruleName) 237 238 // 将以前的drop规则删掉,替换成后面的reject 239 if err == nil && setAction != "allow" && rule != nil && rule.VerDict() == expr.VerdictDrop { 240 deleteErr := chain.DeleteRule(rule) 241 if deleteErr == nil { 242 err = nftables.ErrRuleNotFound 243 rule = nil 244 } 245 } 246 247 if err != nil { 248 if nftables.IsNotFound(err) { 249 if tableDef.IsIPv4 { 250 if setAction == "allow" { 251 rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName) 252 } else { 253 rule, err = chain.AddRejectIPv4SetRule(setName, ruleName) 254 } 255 } else if tableDef.IsIPv6 { 256 if setAction == "allow" { 257 rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName) 258 } else { 259 rule, err = chain.AddRejectIPv6SetRule(setName, ruleName) 260 } 261 } 262 if err != nil { 263 return fmt.Errorf("add rule failed: %w", err) 264 } 265 } else { 266 return fmt.Errorf("get rule failed: %w", err) 267 } 268 } 269 if rule == nil { 270 return errors.New("can not create rule '" + string(ruleName) + "'") 271 } 272 } 273 } 274 275 this.isReady = true 276 nftablesIsReady = true 277 nftablesInstance = this 278 279 goman.New(func() { 280 for ipItem := range this.dropIPQueue { 281 switch ipItem.action { 282 case "drop": 283 err := this.DropSourceIP(ipItem.ip, ipItem.timeoutSeconds, false) 284 if err != nil { 285 remotelogs.Warn("NFTABLES", "drop ip '"+ipItem.ip+"' failed: "+err.Error()) 286 } 287 } 288 } 289 }) 290 291 // load firewalld 292 var firewalld = NewFirewalld() 293 if firewalld.IsReady() { 294 this.firewalld = firewalld 295 } 296 297 return nil 298 } 299 300 // Name 名称 301 func (this *NFTablesFirewall) Name() string { 302 return "nftables" 303 } 304 305 // IsReady 是否已准备被调用 306 func (this *NFTablesFirewall) IsReady() bool { 307 return this.isReady 308 } 309 310 // IsMock 是否为模拟 311 func (this *NFTablesFirewall) IsMock() bool { 312 return false 313 } 314 315 // AllowPort 允许端口 316 func (this *NFTablesFirewall) AllowPort(port int, protocol string) error { 317 if this.firewalld != nil { 318 return this.firewalld.AllowPort(port, protocol) 319 } 320 return nil 321 } 322 323 // RemovePort 删除端口 324 func (this *NFTablesFirewall) RemovePort(port int, protocol string) error { 325 if this.firewalld != nil { 326 return this.firewalld.RemovePort(port, protocol) 327 } 328 return nil 329 } 330 331 // AllowSourceIP Allow把IP加入白名单 332 func (this *NFTablesFirewall) AllowSourceIP(ip string) error { 333 var data = net.ParseIP(ip) 334 if data == nil { 335 return errors.New("invalid ip '" + ip + "'") 336 } 337 338 if strings.Contains(ip, ":") { // ipv6 339 if this.allowIPv6Set == nil { 340 return errors.New("ipv6 ip set is nil") 341 } 342 return this.allowIPv6Set.AddElement(data.To16(), nil, false) 343 } 344 345 // ipv4 346 if this.allowIPv4Set == nil { 347 return errors.New("ipv4 ip set is nil") 348 } 349 return this.allowIPv4Set.AddElement(data.To4(), nil, false) 350 } 351 352 // RejectSourceIP 拒绝某个源IP连接 353 // we did not create set for drop ip, so we reuse DropSourceIP() method here 354 func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error { 355 return this.DropSourceIP(ip, timeoutSeconds, true) 356 } 357 358 // DropSourceIP 丢弃某个源IP数据 359 func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async bool) error { 360 var data = net.ParseIP(ip) 361 if data == nil { 362 return errors.New("invalid ip '" + ip + "'") 363 } 364 365 // 尝试关闭连接 366 conns.SharedMap.CloseIPConns(ip) 367 368 // 避免短时间内重复添加 369 if async && this.checkLatestIP(ip) { 370 return nil 371 } 372 373 if async { 374 select { 375 case this.dropIPQueue <- &blockIPItem{ 376 action: "drop", 377 ip: ip, 378 timeoutSeconds: timeoutSeconds, 379 }: 380 default: 381 return errors.New("drop ip queue is full") 382 } 383 return nil 384 } 385 386 // 再次尝试关闭连接 387 defer conns.SharedMap.CloseIPConns(ip) 388 389 if strings.Contains(ip, ":") { // ipv6 390 if len(this.denyIPv6Sets) == 0 { 391 return errors.New("ipv6 ip set not found") 392 } 393 var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv6Sets)) 394 return this.denyIPv6Sets[setIndex].AddElement(data.To16(), &nftables.ElementOptions{ 395 Timeout: time.Duration(timeoutSeconds) * time.Second, 396 }, false) 397 } 398 399 // ipv4 400 if len(this.denyIPv4Sets) == 0 { 401 return errors.New("ipv4 ip set not found") 402 } 403 var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv4Sets)) 404 return this.denyIPv4Sets[setIndex].AddElement(data.To4(), &nftables.ElementOptions{ 405 Timeout: time.Duration(timeoutSeconds) * time.Second, 406 }, false) 407 } 408 409 // RemoveSourceIP 删除某个源IP 410 func (this *NFTablesFirewall) RemoveSourceIP(ip string) error { 411 var data = net.ParseIP(ip) 412 if data == nil { 413 return errors.New("invalid ip '" + ip + "'") 414 } 415 416 if strings.Contains(ip, ":") { // ipv6 417 var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv6Sets)) 418 if len(this.denyIPv6Sets) > 0 { 419 err := this.denyIPv6Sets[setIndex].DeleteElement(data.To16()) 420 if err != nil { 421 return err 422 } 423 } 424 425 if this.allowIPv6Set != nil { 426 err := this.allowIPv6Set.DeleteElement(data.To16()) 427 if err != nil { 428 return err 429 } 430 } 431 432 return nil 433 } 434 435 // ipv4 436 if len(this.denyIPv4Sets) > 0 { 437 var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv4Sets)) 438 err := this.denyIPv4Sets[setIndex].DeleteElement(data.To4()) 439 if err != nil { 440 return err 441 } 442 } 443 if this.allowIPv4Set != nil { 444 err := this.allowIPv4Set.DeleteElement(data.To4()) 445 if err != nil { 446 return err 447 } 448 } 449 450 return nil 451 } 452 453 // 读取版本号 454 func (this *NFTablesFirewall) readVersion(nftPath string) string { 455 var cmd = executils.NewTimeoutCmd(10*time.Second, nftPath, "--version") 456 cmd.WithStdout() 457 err := cmd.Run() 458 if err != nil { 459 return "" 460 } 461 462 var outputString = cmd.Stdout() 463 var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString) 464 if len(versionMatches) <= 1 { 465 return "" 466 } 467 return versionMatches[1] 468 } 469 470 // 检查是否在最近添加过 471 func (this *NFTablesFirewall) existLatestIP(ip string) bool { 472 this.locker.Lock() 473 defer this.locker.Unlock() 474 475 var expiredIndex = -1 476 for index, ipTime := range this.latestIPTimes { 477 var pieces = strings.Split(ipTime, "@") 478 var oldIP = pieces[0] 479 var oldTimestamp = pieces[1] 480 if types.Int64(oldTimestamp) < time.Now().Unix()-3 /** 3秒外表示过期 **/ { 481 expiredIndex = index 482 continue 483 } 484 if oldIP == ip { 485 return true 486 } 487 } 488 489 if expiredIndex > -1 { 490 this.latestIPTimes = this.latestIPTimes[expiredIndex+1:] 491 } 492 493 this.latestIPTimes = append(this.latestIPTimes, ip+"@"+types.String(time.Now().Unix())) 494 const maxLen = 128 495 if len(this.latestIPTimes) > maxLen { 496 this.latestIPTimes = this.latestIPTimes[1:] 497 } 498 499 return false 500 }