github.com/kelleygo/clashcore@v1.0.2/component/trie/ipcidr_trie.go (about) 1 package trie 2 3 import ( 4 "net" 5 6 "github.com/kelleygo/clashcore/log" 7 ) 8 9 type IPV6 bool 10 11 const ( 12 ipv4GroupMaxValue = 0xFF 13 ipv6GroupMaxValue = 0xFFFF 14 ) 15 16 type IpCidrTrie struct { 17 ipv4Trie *IpCidrNode 18 ipv6Trie *IpCidrNode 19 } 20 21 func NewIpCidrTrie() *IpCidrTrie { 22 return &IpCidrTrie{ 23 ipv4Trie: NewIpCidrNode(false, ipv4GroupMaxValue), 24 ipv6Trie: NewIpCidrNode(false, ipv6GroupMaxValue), 25 } 26 } 27 28 func (trie *IpCidrTrie) AddIpCidr(ipCidr *net.IPNet) error { 29 subIpCidr, subCidr, isIpv4, err := ipCidrToSubIpCidr(ipCidr) 30 if err != nil { 31 return err 32 } 33 34 for _, sub := range subIpCidr { 35 addIpCidr(trie, isIpv4, sub, subCidr/8) 36 } 37 38 return nil 39 } 40 41 func (trie *IpCidrTrie) AddIpCidrForString(ipCidr string) error { 42 _, ipNet, err := net.ParseCIDR(ipCidr) 43 if err != nil { 44 return err 45 } 46 47 return trie.AddIpCidr(ipNet) 48 } 49 50 func (trie *IpCidrTrie) IsContain(ip net.IP) bool { 51 if ip == nil { 52 return false 53 } 54 isIpv4 := len(ip) == net.IPv4len 55 var groupValues []uint32 56 var ipCidrNode *IpCidrNode 57 58 if isIpv4 { 59 ipCidrNode = trie.ipv4Trie 60 for _, group := range ip { 61 groupValues = append(groupValues, uint32(group)) 62 } 63 } else { 64 ipCidrNode = trie.ipv6Trie 65 for i := 0; i < len(ip); i += 2 { 66 groupValues = append(groupValues, getIpv6GroupValue(ip[i], ip[i+1])) 67 } 68 } 69 70 return search(ipCidrNode, groupValues) != nil 71 } 72 73 func (trie *IpCidrTrie) IsContainForString(ipString string) bool { 74 ip := net.ParseIP(ipString) 75 // deal with 4in6 76 actualIp := ip.To4() 77 if actualIp == nil { 78 actualIp = ip 79 } 80 return trie.IsContain(actualIp) 81 } 82 83 func ipCidrToSubIpCidr(ipNet *net.IPNet) ([]net.IP, int, bool, error) { 84 maskSize, _ := ipNet.Mask.Size() 85 var ( 86 ipList []net.IP 87 newMaskSize int 88 isIpv4 bool 89 err error 90 ) 91 isIpv4 = len(ipNet.IP) == net.IPv4len 92 ipList, newMaskSize, err = subIpCidr(ipNet.IP, maskSize, isIpv4) 93 94 return ipList, newMaskSize, isIpv4, err 95 } 96 97 func subIpCidr(ip net.IP, maskSize int, isIpv4 bool) ([]net.IP, int, error) { 98 var subIpCidrList []net.IP 99 groupSize := 8 100 if !isIpv4 { 101 groupSize = 16 102 } 103 104 if maskSize%groupSize == 0 { 105 return append(subIpCidrList, ip), maskSize, nil 106 } 107 108 lastByteMaskSize := maskSize % 8 109 lastByteMaskIndex := maskSize / 8 110 subIpCidrNum := 0xFF >> lastByteMaskSize 111 for i := 0; i <= subIpCidrNum; i++ { 112 subIpCidr := make([]byte, len(ip)) 113 copy(subIpCidr, ip) 114 subIpCidr[lastByteMaskIndex] += byte(i) 115 subIpCidrList = append(subIpCidrList, subIpCidr) 116 } 117 118 newMaskSize := (lastByteMaskIndex + 1) * 8 119 if !isIpv4 { 120 newMaskSize = (lastByteMaskIndex/2 + 1) * 16 121 } 122 123 return subIpCidrList, newMaskSize, nil 124 } 125 126 func addIpCidr(trie *IpCidrTrie, isIpv4 bool, ip net.IP, groupSize int) { 127 if isIpv4 { 128 addIpv4Cidr(trie, ip, groupSize) 129 } else { 130 addIpv6Cidr(trie, ip, groupSize) 131 } 132 } 133 134 func addIpv4Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) { 135 preNode := trie.ipv4Trie 136 node := preNode.getChild(uint32(ip[0])) 137 if node == nil { 138 err := preNode.addChild(uint32(ip[0])) 139 if err != nil { 140 return 141 } 142 143 node = preNode.getChild(uint32(ip[0])) 144 } 145 146 for i := 1; i < groupSize; i++ { 147 if node.Mark { 148 return 149 } 150 151 groupValue := uint32(ip[i]) 152 if !node.hasChild(groupValue) { 153 err := node.addChild(groupValue) 154 if err != nil { 155 log.Errorln(err.Error()) 156 } 157 } 158 159 preNode = node 160 node = node.getChild(groupValue) 161 if node == nil { 162 err := preNode.addChild(uint32(ip[i-1])) 163 if err != nil { 164 return 165 } 166 167 node = preNode.getChild(uint32(ip[i-1])) 168 } 169 } 170 171 node.Mark = true 172 cleanChild(node) 173 } 174 175 func addIpv6Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) { 176 preNode := trie.ipv6Trie 177 node := preNode.getChild(getIpv6GroupValue(ip[0], ip[1])) 178 if node == nil { 179 err := preNode.addChild(getIpv6GroupValue(ip[0], ip[1])) 180 if err != nil { 181 return 182 } 183 184 node = preNode.getChild(getIpv6GroupValue(ip[0], ip[1])) 185 } 186 187 for i := 2; i < groupSize; i += 2 { 188 if ip[i] == 0 && ip[i+1] == 0 { 189 node.Mark = true 190 } 191 192 if node.Mark { 193 return 194 } 195 196 groupValue := getIpv6GroupValue(ip[i], ip[i+1]) 197 if !node.hasChild(groupValue) { 198 err := node.addChild(groupValue) 199 if err != nil { 200 log.Errorln(err.Error()) 201 } 202 } 203 204 preNode = node 205 node = node.getChild(groupValue) 206 if node == nil { 207 err := preNode.addChild(getIpv6GroupValue(ip[i-2], ip[i-1])) 208 if err != nil { 209 return 210 } 211 212 node = preNode.getChild(getIpv6GroupValue(ip[i-2], ip[i-1])) 213 } 214 } 215 216 node.Mark = true 217 cleanChild(node) 218 } 219 220 func getIpv6GroupValue(high, low byte) uint32 { 221 return (uint32(high) << 8) | uint32(low) 222 } 223 224 func cleanChild(node *IpCidrNode) { 225 for i := uint32(0); i < uint32(len(node.child)); i++ { 226 delete(node.child, i) 227 } 228 } 229 230 func search(root *IpCidrNode, groupValues []uint32) *IpCidrNode { 231 node := root.getChild(groupValues[0]) 232 if node == nil || node.Mark { 233 return node 234 } 235 236 for _, value := range groupValues[1:] { 237 if !node.hasChild(value) { 238 return nil 239 } 240 241 node = node.getChild(value) 242 243 if node == nil || node.Mark { 244 return node 245 } 246 } 247 248 return nil 249 }