github.com/TeaOSLab/EdgeNode@v1.3.8/internal/firewalls/firewall_nftables.go (about)

     1  // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
     2  //go:build linux
     3  
     4  package firewalls
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
    10  	"github.com/TeaOSLab/EdgeNode/internal/conns"
    11  	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
    12  	"github.com/TeaOSLab/EdgeNode/internal/events"
    13  	"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
    14  	"github.com/TeaOSLab/EdgeNode/internal/goman"
    15  	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
    16  	executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
    17  	"github.com/google/nftables/expr"
    18  	"github.com/iwind/TeaGo/types"
    19  	"net"
    20  	"regexp"
    21  	"runtime"
    22  	"strings"
    23  	"time"
    24  )
    25  
    26  // check nft status, if being enabled we load it automatically
    27  func init() {
    28  	if !teaconst.IsMain {
    29  		return
    30  	}
    31  
    32  	if runtime.GOOS == "linux" {
    33  		var ticker = time.NewTicker(3 * time.Minute)
    34  		goman.New(func() {
    35  			for range ticker.C {
    36  				// if already ready, we break
    37  				if nftablesIsReady {
    38  					ticker.Stop()
    39  					break
    40  				}
    41  				var nftExe = nftables.NftExePath()
    42  				if len(nftExe) > 0 {
    43  					nftablesFirewall, err := NewNFTablesFirewall()
    44  					if err != nil {
    45  						continue
    46  					}
    47  					currentFirewall = nftablesFirewall
    48  					remotelogs.Println("FIREWALL", "nftables is ready")
    49  
    50  					// fire event
    51  					if nftablesFirewall.IsReady() {
    52  						events.Notify(events.EventNFTablesReady)
    53  					}
    54  
    55  					ticker.Stop()
    56  					break
    57  				}
    58  			}
    59  		})
    60  	}
    61  }
    62  
    63  var nftablesInstance *NFTablesFirewall
    64  var nftablesIsReady = false
    65  var nftablesFilters = []*nftablesTableDefinition{
    66  	// we shorten the name for table name length restriction
    67  	{Name: "edge_dft_v4", IsIPv4: true},
    68  	{Name: "edge_dft_v6", IsIPv6: true},
    69  }
    70  var nftablesChainName = "input"
    71  
    72  type nftablesTableDefinition struct {
    73  	Name   string
    74  	IsIPv4 bool
    75  	IsIPv6 bool
    76  }
    77  
    78  func (this *nftablesTableDefinition) protocol() string {
    79  	if this.IsIPv6 {
    80  		return "ip6"
    81  	}
    82  	return "ip"
    83  }
    84  
    85  type blockIPItem struct {
    86  	action         string
    87  	ip             string
    88  	timeoutSeconds int
    89  }
    90  
    91  func NewNFTablesFirewall() (*NFTablesFirewall, error) {
    92  	conn, err := nftables.NewConn()
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	var firewall = &NFTablesFirewall{
    97  		conn:        conn,
    98  		dropIPQueue: make(chan *blockIPItem, 4096),
    99  	}
   100  	err = firewall.init()
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	return firewall, nil
   106  }
   107  
   108  type NFTablesFirewall struct {
   109  	BaseFirewall
   110  
   111  	conn    *nftables.Conn
   112  	isReady bool
   113  	version string
   114  
   115  	allowIPv4Set *nftables.Set
   116  	allowIPv6Set *nftables.Set
   117  
   118  	denyIPv4Sets []*nftables.Set
   119  	denyIPv6Sets []*nftables.Set
   120  
   121  	firewalld *Firewalld
   122  
   123  	dropIPQueue chan *blockIPItem
   124  }
   125  
// init prepares everything the firewall needs and then publishes it.
//
// For each table in nftablesFilters it gets or creates, in this order:
//   - the table (IPv4 or IPv6 family),
//   - the nftablesChainName chain,
//   - a rule accepting traffic on the "lo" interface,
//   - the sets "allow_set", "deny_set" and "deny1_set".."deny4_set", keyed by
//     IPv4/IPv6 address and created with timeout support, each together with
//     a rule matching it (accept for "allow", reject for the deny sets).
//
// Rules are looked up by their user data (the set action name), so running
// against already existing tables reuses what is there. Legacy deny rules
// whose verdict is drop are deleted and recreated as reject rules.
//
// On success it marks the firewall ready, stores it in nftablesInstance,
// starts the goroutine that consumes dropIPQueue, and attaches firewalld if it
// reports ready. Any failure returns an error before isReady is set (the
// allow/deny set fields may already be partially populated at that point).
func (this *NFTablesFirewall) init() error {
	// check nft
	var nftPath = nftables.NftExePath()
	if len(nftPath) == 0 {
		return errors.New("'nft' not found")
	}
	// informational only; readVersion returns "" when it cannot tell
	this.version = this.readVersion(nftPath)

	// table
	for _, tableDef := range nftablesFilters {
		var family nftables.TableFamily
		if tableDef.IsIPv4 {
			family = nftables.TableFamilyIPv4
		} else if tableDef.IsIPv6 {
			family = nftables.TableFamilyIPv6
		} else {
			return errors.New("invalid table family: " + types.String(tableDef))
		}
		// reuse the table if it exists, otherwise create it
		table, err := this.conn.GetTable(tableDef.Name, family)
		if err != nil {
			if nftables.IsNotFound(err) {
				if tableDef.IsIPv4 {
					table, err = this.conn.AddIPv4Table(tableDef.Name)
				} else if tableDef.IsIPv6 {
					table, err = this.conn.AddIPv6Table(tableDef.Name)
				}
				if err != nil {
					return fmt.Errorf("create table '%s' failed: %w", tableDef.Name, err)
				}
			} else {
				return fmt.Errorf("get table '%s' failed: %w", tableDef.Name, err)
			}
		}
		if table == nil {
			return errors.New("can not create table '" + tableDef.Name + "'")
		}

		// chain
		var chainName = nftablesChainName
		chain, err := table.GetChain(chainName)
		if err != nil {
			if nftables.IsNotFound(err) {
				chain, err = table.AddAcceptChain(chainName)
				if err != nil {
					return fmt.Errorf("create chain '%s' failed: %w", chainName, err)
				}
			} else {
				return fmt.Errorf("get chain '%s' failed: %w", chainName, err)
			}
		}
		if chain == nil {
			return errors.New("can not create chain '" + chainName + "'")
		}

		// allow lo
		// The rule is identified by the user data "lo". Note that a lookup
		// error other than "not found" is also reported as "add 'lo' rule failed".
		var loRuleName = []byte("lo")
		_, err = chain.GetRuleWithUserData(loRuleName)
		if err != nil {
			if nftables.IsNotFound(err) {
				_, err = chain.AddAcceptInterfaceRule("lo", loRuleName)
			}
			if err != nil {
				return fmt.Errorf("add 'lo' rule failed: %w", err)
			}
		}

		// allow set
		// "allow" should be always first, so that its accept rule is created
		// before the reject rules of the deny sets.
		// NOTE(review): this relies on Add*SetRule appending rules in creation
		// order — confirm against the nftables wrapper.
		for _, setAction := range []string{"allow", "deny", "deny1", "deny2", "deny3", "deny4"} {
			var setName = setAction + "_set"

			// reuse the set if it exists, otherwise create it with timeout support
			set, err := table.GetSet(setName)
			if err != nil {
				if nftables.IsNotFound(err) {
					var keyType nftables.SetDataType
					if tableDef.IsIPv4 {
						keyType = nftables.TypeIPAddr
					} else if tableDef.IsIPv6 {
						keyType = nftables.TypeIP6Addr
					}
					set, err = table.AddSet(setName, &nftables.SetOptions{
						KeyType:    keyType,
						HasTimeout: true,
					})
					if err != nil {
						return fmt.Errorf("create set '%s' failed: %w", setName, err)
					}
				} else {
					return fmt.Errorf("get set '%s' failed: %w", setName, err)
				}
			}
			if set == nil {
				return errors.New("can not create set '" + setName + "'")
			}
			// remember the set: one allow set and a list of deny sets per family
			if tableDef.IsIPv4 {
				if setAction == "allow" {
					this.allowIPv4Set = set
				} else {
					this.denyIPv4Sets = append(this.denyIPv4Sets, set)
				}
			} else if tableDef.IsIPv6 {
				if setAction == "allow" {
					this.allowIPv6Set = set
				} else {
					this.denyIPv6Sets = append(this.denyIPv6Sets, set)
				}
			}

			// rule
			// The rule matching this set is identified by the user data setAction.
			var ruleName = []byte(setAction)
			rule, err := chain.GetRuleWithUserData(ruleName)

			// Delete a previously created drop rule so that it is replaced by
			// the reject rule added below. Faking "not found" routes execution
			// into the creation branch; if the delete fails, the existing drop
			// rule is kept as is.
			if err == nil && setAction != "allow" && rule != nil && rule.VerDict() == expr.VerdictDrop {
				deleteErr := chain.DeleteRule(rule)
				if deleteErr == nil {
					err = nftables.ErrRuleNotFound
					rule = nil
				}
			}

			if err != nil {
				if nftables.IsNotFound(err) {
					// allow -> accept rule, deny* -> reject rule
					if tableDef.IsIPv4 {
						if setAction == "allow" {
							rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
						} else {
							rule, err = chain.AddRejectIPv4SetRule(setName, ruleName)
						}
					} else if tableDef.IsIPv6 {
						if setAction == "allow" {
							rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
						} else {
							rule, err = chain.AddRejectIPv6SetRule(setName, ruleName)
						}
					}
					if err != nil {
						return fmt.Errorf("add rule failed: %w", err)
					}
				} else {
					return fmt.Errorf("get rule failed: %w", err)
				}
			}
			if rule == nil {
				return errors.New("can not create rule '" + string(ruleName) + "'")
			}
		}
	}

	// publish the initialized firewall
	this.isReady = true
	nftablesIsReady = true
	nftablesInstance = this

	// Worker that applies queued drop requests synchronously.
	// NOTE(review): dropIPQueue is never closed in this file, so this
	// goroutine presumably lives for the rest of the process.
	goman.New(func() {
		for ipItem := range this.dropIPQueue {
			switch ipItem.action {
			case "drop":
				err := this.DropSourceIP(ipItem.ip, ipItem.timeoutSeconds, false)
				if err != nil {
					remotelogs.Warn("NFTABLES", "drop ip '"+ipItem.ip+"' failed: "+err.Error())
				}
			}
		}
	})

	// load firewalld (used for AllowPort/RemovePort when available)
	var firewalld = NewFirewalld()
	if firewalld.IsReady() {
		this.firewalld = firewalld
	}

	return nil
}
   299  
   300  // Name 名称
   301  func (this *NFTablesFirewall) Name() string {
   302  	return "nftables"
   303  }
   304  
   305  // IsReady 是否已准备被调用
   306  func (this *NFTablesFirewall) IsReady() bool {
   307  	return this.isReady
   308  }
   309  
   310  // IsMock 是否为模拟
   311  func (this *NFTablesFirewall) IsMock() bool {
   312  	return false
   313  }
   314  
   315  // AllowPort 允许端口
   316  func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
   317  	if this.firewalld != nil {
   318  		return this.firewalld.AllowPort(port, protocol)
   319  	}
   320  	return nil
   321  }
   322  
   323  // RemovePort 删除端口
   324  func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
   325  	if this.firewalld != nil {
   326  		return this.firewalld.RemovePort(port, protocol)
   327  	}
   328  	return nil
   329  }
   330  
   331  // AllowSourceIP Allow把IP加入白名单
   332  func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
   333  	var data = net.ParseIP(ip)
   334  	if data == nil {
   335  		return errors.New("invalid ip '" + ip + "'")
   336  	}
   337  
   338  	if strings.Contains(ip, ":") { // ipv6
   339  		if this.allowIPv6Set == nil {
   340  			return errors.New("ipv6 ip set is nil")
   341  		}
   342  		return this.allowIPv6Set.AddElement(data.To16(), nil, false)
   343  	}
   344  
   345  	// ipv4
   346  	if this.allowIPv4Set == nil {
   347  		return errors.New("ipv4 ip set is nil")
   348  	}
   349  	return this.allowIPv4Set.AddElement(data.To4(), nil, false)
   350  }
   351  
   352  // RejectSourceIP 拒绝某个源IP连接
   353  // we did not create set for drop ip, so we reuse DropSourceIP() method here
   354  func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
   355  	return this.DropSourceIP(ip, timeoutSeconds, true)
   356  }
   357  
   358  // DropSourceIP 丢弃某个源IP数据
   359  func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async bool) error {
   360  	var data = net.ParseIP(ip)
   361  	if data == nil {
   362  		return errors.New("invalid ip '" + ip + "'")
   363  	}
   364  
   365  	// 尝试关闭连接
   366  	conns.SharedMap.CloseIPConns(ip)
   367  
   368  	// 避免短时间内重复添加
   369  	if async && this.checkLatestIP(ip) {
   370  		return nil
   371  	}
   372  
   373  	if async {
   374  		select {
   375  		case this.dropIPQueue <- &blockIPItem{
   376  			action:         "drop",
   377  			ip:             ip,
   378  			timeoutSeconds: timeoutSeconds,
   379  		}:
   380  		default:
   381  			return errors.New("drop ip queue is full")
   382  		}
   383  		return nil
   384  	}
   385  
   386  	// 再次尝试关闭连接
   387  	defer conns.SharedMap.CloseIPConns(ip)
   388  
   389  	if strings.Contains(ip, ":") { // ipv6
   390  		if len(this.denyIPv6Sets) == 0 {
   391  			return errors.New("ipv6 ip set not found")
   392  		}
   393  		var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv6Sets))
   394  		return this.denyIPv6Sets[setIndex].AddElement(data.To16(), &nftables.ElementOptions{
   395  			Timeout: time.Duration(timeoutSeconds) * time.Second,
   396  		}, false)
   397  	}
   398  
   399  	// ipv4
   400  	if len(this.denyIPv4Sets) == 0 {
   401  		return errors.New("ipv4 ip set not found")
   402  	}
   403  	var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv4Sets))
   404  	return this.denyIPv4Sets[setIndex].AddElement(data.To4(), &nftables.ElementOptions{
   405  		Timeout: time.Duration(timeoutSeconds) * time.Second,
   406  	}, false)
   407  }
   408  
   409  // RemoveSourceIP 删除某个源IP
   410  func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
   411  	var data = net.ParseIP(ip)
   412  	if data == nil {
   413  		return errors.New("invalid ip '" + ip + "'")
   414  	}
   415  
   416  	if strings.Contains(ip, ":") { // ipv6
   417  		var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv6Sets))
   418  		if len(this.denyIPv6Sets) > 0 {
   419  			err := this.denyIPv6Sets[setIndex].DeleteElement(data.To16())
   420  			if err != nil {
   421  				return err
   422  			}
   423  		}
   424  
   425  		if this.allowIPv6Set != nil {
   426  			err := this.allowIPv6Set.DeleteElement(data.To16())
   427  			if err != nil {
   428  				return err
   429  			}
   430  		}
   431  
   432  		return nil
   433  	}
   434  
   435  	// ipv4
   436  	if len(this.denyIPv4Sets) > 0 {
   437  		var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv4Sets))
   438  		err := this.denyIPv4Sets[setIndex].DeleteElement(data.To4())
   439  		if err != nil {
   440  			return err
   441  		}
   442  	}
   443  	if this.allowIPv4Set != nil {
   444  		err := this.allowIPv4Set.DeleteElement(data.To4())
   445  		if err != nil {
   446  			return err
   447  		}
   448  	}
   449  
   450  	return nil
   451  }
   452  
   453  // 读取版本号
   454  func (this *NFTablesFirewall) readVersion(nftPath string) string {
   455  	var cmd = executils.NewTimeoutCmd(10*time.Second, nftPath, "--version")
   456  	cmd.WithStdout()
   457  	err := cmd.Run()
   458  	if err != nil {
   459  		return ""
   460  	}
   461  
   462  	var outputString = cmd.Stdout()
   463  	var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
   464  	if len(versionMatches) <= 1 {
   465  		return ""
   466  	}
   467  	return versionMatches[1]
   468  }
   469  
   470  // 检查是否在最近添加过
   471  func (this *NFTablesFirewall) existLatestIP(ip string) bool {
   472  	this.locker.Lock()
   473  	defer this.locker.Unlock()
   474  
   475  	var expiredIndex = -1
   476  	for index, ipTime := range this.latestIPTimes {
   477  		var pieces = strings.Split(ipTime, "@")
   478  		var oldIP = pieces[0]
   479  		var oldTimestamp = pieces[1]
   480  		if types.Int64(oldTimestamp) < time.Now().Unix()-3 /** 3秒外表示过期 **/ {
   481  			expiredIndex = index
   482  			continue
   483  		}
   484  		if oldIP == ip {
   485  			return true
   486  		}
   487  	}
   488  
   489  	if expiredIndex > -1 {
   490  		this.latestIPTimes = this.latestIPTimes[expiredIndex+1:]
   491  	}
   492  
   493  	this.latestIPTimes = append(this.latestIPTimes, ip+"@"+types.String(time.Now().Unix()))
   494  	const maxLen = 128
   495  	if len(this.latestIPTimes) > maxLen {
   496  		this.latestIPTimes = this.latestIPTimes[1:]
   497  	}
   498  
   499  	return false
   500  }