github.com/TeaOSLab/EdgeNode@v1.3.8/internal/firewalls/nftables/set.go (about) 1 // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. 2 //go:build linux 3 4 package nftables 5 6 import ( 7 "errors" 8 "github.com/TeaOSLab/EdgeNode/internal/utils" 9 nft "github.com/google/nftables" 10 "net" 11 "strings" 12 "time" 13 ) 14 15 const MaxSetNameLength = 15 16 17 type SetOptions struct { 18 Id uint32 19 HasTimeout bool 20 Timeout time.Duration 21 KeyType SetDataType 22 DataType SetDataType 23 Constant bool 24 Interval bool 25 Anonymous bool 26 IsMap bool 27 } 28 29 type ElementOptions struct { 30 Timeout time.Duration 31 } 32 33 type Set struct { 34 conn *Conn 35 rawSet *nft.Set 36 batch *SetBatch 37 38 expiration *Expiration 39 } 40 41 func NewSet(conn *Conn, rawSet *nft.Set) *Set { 42 var set = &Set{ 43 conn: conn, 44 rawSet: rawSet, 45 expiration: nil, 46 batch: &SetBatch{ 47 conn: conn, 48 rawSet: rawSet, 49 }, 50 } 51 52 // retrieve set elements to improve "delete" speed 53 set.initElements() 54 55 return set 56 } 57 58 func (this *Set) Raw() *nft.Set { 59 return this.rawSet 60 } 61 62 func (this *Set) Name() string { 63 return this.rawSet.Name 64 } 65 66 func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) error { 67 // check if already exists 68 if this.expiration != nil && !overwrite && this.expiration.Contains(key) { 69 return nil 70 } 71 72 var expiresTime = time.Time{} 73 var rawElement = nft.SetElement{ 74 Key: key, 75 } 76 if options != nil { 77 rawElement.Timeout = options.Timeout 78 79 if options.Timeout > 0 { 80 expiresTime = time.UnixMilli(time.Now().UnixMilli() + options.Timeout.Milliseconds()) 81 } 82 } 83 err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{ 84 rawElement, 85 }) 86 if err != nil { 87 return err 88 } 89 90 err = this.conn.Commit() 91 if err == nil { 92 if this.expiration != nil { 93 this.expiration.Add(key, expiresTime) 94 } 95 } else { 96 var isFileExistsErr = strings.Contains(err.Error(), "file exists") 97 if !overwrite && isFileExistsErr { 98 // ignore file exists error 99 return nil 100 } 101 102 // retry if exists 103 if overwrite && isFileExistsErr { 104 deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{ 105 { 106 Key: key, 107 }, 108 }) 109 if deleteErr == nil { 110 err = this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{ 111 rawElement, 112 }) 113 if err == nil { 114 err = this.conn.Commit() 115 if err == nil { 116 if this.expiration != nil { 117 this.expiration.Add(key, expiresTime) 118 } 119 } 120 } 121 } 122 } 123 } 124 125 return err 126 } 127 128 func (this *Set) AddIPElement(ip string, options *ElementOptions, overwrite bool) error { 129 var ipObj = net.ParseIP(ip) 130 if ipObj == nil { 131 return errors.New("invalid ip '" + ip + "'") 132 } 133 134 if utils.IsIPv4(ip) { 135 return this.AddElement(ipObj.To4(), options, overwrite) 136 } else { 137 return this.AddElement(ipObj.To16(), options, overwrite) 138 } 139 } 140 141 func (this *Set) DeleteElement(key []byte) error { 142 // if set element does not exist, we return immediately 143 if this.expiration != nil && !this.expiration.Contains(key) { 144 return nil 145 } 146 147 err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{ 148 { 149 Key: key, 150 }, 151 }) 152 if err != nil { 153 return err 154 } 155 err = this.conn.Commit() 156 if err == nil { 157 if this.expiration != nil { 158 this.expiration.Remove(key) 159 } 160 } else { 161 if strings.Contains(err.Error(), "no such file or directory") { 162 err = nil 163 164 if this.expiration != nil { 165 this.expiration.Remove(key) 166 } 167 } 168 } 169 return err 170 } 171 172 func (this *Set) DeleteIPElement(ip string) error { 173 var ipObj = net.ParseIP(ip) 174 if ipObj == nil { 175 return errors.New("invalid ip '" + ip + "'") 176 } 177 178 if utils.IsIPv4(ip) { 179 return this.DeleteElement(ipObj.To4()) 180 } else { 181 return this.DeleteElement(ipObj.To16()) 182 } 183 } 184 185 func (this *Set) Batch() *SetBatch { 186 return this.batch 187 } 188 189 func (this *Set) GetIPElements() ([]string, error) { 190 elements, err := this.conn.Raw().GetSetElements(this.rawSet) 191 if err != nil { 192 return nil, err 193 } 194 195 var result = []string{} 196 for _, element := range elements { 197 result = append(result, net.IP(element.Key).String()) 198 } 199 return result, nil 200 } 201 202 // not work current time 203 /**func (this *Set) Flush() error { 204 this.conn.Raw().FlushSet(this.rawSet) 205 return this.conn.Commit() 206 }**/