github.com/TeaOSLab/EdgeNode@v1.3.8/internal/firewalls/nftables/set.go (about)

     1  // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
     2  //go:build linux
     3  
     4  package nftables
     5  
     6  import (
     7  	"errors"
     8  	"github.com/TeaOSLab/EdgeNode/internal/utils"
     9  	nft "github.com/google/nftables"
    10  	"net"
    11  	"strings"
    12  	"time"
    13  )
    14  
    15  const MaxSetNameLength = 15
    16  
    17  type SetOptions struct {
    18  	Id         uint32
    19  	HasTimeout bool
    20  	Timeout    time.Duration
    21  	KeyType    SetDataType
    22  	DataType   SetDataType
    23  	Constant   bool
    24  	Interval   bool
    25  	Anonymous  bool
    26  	IsMap      bool
    27  }
    28  
    29  type ElementOptions struct {
    30  	Timeout time.Duration
    31  }
    32  
    33  type Set struct {
    34  	conn   *Conn
    35  	rawSet *nft.Set
    36  	batch  *SetBatch
    37  
    38  	expiration *Expiration
    39  }
    40  
    41  func NewSet(conn *Conn, rawSet *nft.Set) *Set {
    42  	var set = &Set{
    43  		conn:       conn,
    44  		rawSet:     rawSet,
    45  		expiration: nil,
    46  		batch: &SetBatch{
    47  			conn:   conn,
    48  			rawSet: rawSet,
    49  		},
    50  	}
    51  
    52  	// retrieve set elements to improve "delete" speed
    53  	set.initElements()
    54  
    55  	return set
    56  }
    57  
    58  func (this *Set) Raw() *nft.Set {
    59  	return this.rawSet
    60  }
    61  
    62  func (this *Set) Name() string {
    63  	return this.rawSet.Name
    64  }
    65  
    66  func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) error {
    67  	// check if already exists
    68  	if this.expiration != nil && !overwrite && this.expiration.Contains(key) {
    69  		return nil
    70  	}
    71  
    72  	var expiresTime = time.Time{}
    73  	var rawElement = nft.SetElement{
    74  		Key: key,
    75  	}
    76  	if options != nil {
    77  		rawElement.Timeout = options.Timeout
    78  
    79  		if options.Timeout > 0 {
    80  			expiresTime = time.UnixMilli(time.Now().UnixMilli() + options.Timeout.Milliseconds())
    81  		}
    82  	}
    83  	err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
    84  		rawElement,
    85  	})
    86  	if err != nil {
    87  		return err
    88  	}
    89  
    90  	err = this.conn.Commit()
    91  	if err == nil {
    92  		if this.expiration != nil {
    93  			this.expiration.Add(key, expiresTime)
    94  		}
    95  	} else {
    96  		var isFileExistsErr = strings.Contains(err.Error(), "file exists")
    97  		if !overwrite && isFileExistsErr {
    98  			// ignore file exists error
    99  			return nil
   100  		}
   101  
   102  		// retry if exists
   103  		if overwrite && isFileExistsErr {
   104  			deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
   105  				{
   106  					Key: key,
   107  				},
   108  			})
   109  			if deleteErr == nil {
   110  				err = this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
   111  					rawElement,
   112  				})
   113  				if err == nil {
   114  					err = this.conn.Commit()
   115  					if err == nil {
   116  						if this.expiration != nil {
   117  							this.expiration.Add(key, expiresTime)
   118  						}
   119  					}
   120  				}
   121  			}
   122  		}
   123  	}
   124  
   125  	return err
   126  }
   127  
   128  func (this *Set) AddIPElement(ip string, options *ElementOptions, overwrite bool) error {
   129  	var ipObj = net.ParseIP(ip)
   130  	if ipObj == nil {
   131  		return errors.New("invalid ip '" + ip + "'")
   132  	}
   133  
   134  	if utils.IsIPv4(ip) {
   135  		return this.AddElement(ipObj.To4(), options, overwrite)
   136  	} else {
   137  		return this.AddElement(ipObj.To16(), options, overwrite)
   138  	}
   139  }
   140  
   141  func (this *Set) DeleteElement(key []byte) error {
   142  	// if set element does not exist, we return immediately
   143  	if this.expiration != nil && !this.expiration.Contains(key) {
   144  		return nil
   145  	}
   146  
   147  	err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
   148  		{
   149  			Key: key,
   150  		},
   151  	})
   152  	if err != nil {
   153  		return err
   154  	}
   155  	err = this.conn.Commit()
   156  	if err == nil {
   157  		if this.expiration != nil {
   158  			this.expiration.Remove(key)
   159  		}
   160  	} else {
   161  		if strings.Contains(err.Error(), "no such file or directory") {
   162  			err = nil
   163  
   164  			if this.expiration != nil {
   165  				this.expiration.Remove(key)
   166  			}
   167  		}
   168  	}
   169  	return err
   170  }
   171  
   172  func (this *Set) DeleteIPElement(ip string) error {
   173  	var ipObj = net.ParseIP(ip)
   174  	if ipObj == nil {
   175  		return errors.New("invalid ip '" + ip + "'")
   176  	}
   177  
   178  	if utils.IsIPv4(ip) {
   179  		return this.DeleteElement(ipObj.To4())
   180  	} else {
   181  		return this.DeleteElement(ipObj.To16())
   182  	}
   183  }
   184  
   185  func (this *Set) Batch() *SetBatch {
   186  	return this.batch
   187  }
   188  
   189  func (this *Set) GetIPElements() ([]string, error) {
   190  	elements, err := this.conn.Raw().GetSetElements(this.rawSet)
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  
   195  	var result = []string{}
   196  	for _, element := range elements {
   197  		result = append(result, net.IP(element.Key).String())
   198  	}
   199  	return result, nil
   200  }
   201  
// Flush is commented out because it does not currently work.
   203  /**func (this *Set) Flush() error {
   204  	this.conn.Raw().FlushSet(this.rawSet)
   205  	return this.conn.Commit()
   206  }**/