github.com/kelleygo/clashcore@v1.0.2/component/trie/ipcidr_trie.go (about)

     1  package trie
     2  
     3  import (
     4  	"net"
     5  
     6  	"github.com/kelleygo/clashcore/log"
     7  )
     8  
     9  type IPV6 bool
    10  
    11  const (
    12  	ipv4GroupMaxValue = 0xFF
    13  	ipv6GroupMaxValue = 0xFFFF
    14  )
    15  
    16  type IpCidrTrie struct {
    17  	ipv4Trie *IpCidrNode
    18  	ipv6Trie *IpCidrNode
    19  }
    20  
    21  func NewIpCidrTrie() *IpCidrTrie {
    22  	return &IpCidrTrie{
    23  		ipv4Trie: NewIpCidrNode(false, ipv4GroupMaxValue),
    24  		ipv6Trie: NewIpCidrNode(false, ipv6GroupMaxValue),
    25  	}
    26  }
    27  
    28  func (trie *IpCidrTrie) AddIpCidr(ipCidr *net.IPNet) error {
    29  	subIpCidr, subCidr, isIpv4, err := ipCidrToSubIpCidr(ipCidr)
    30  	if err != nil {
    31  		return err
    32  	}
    33  
    34  	for _, sub := range subIpCidr {
    35  		addIpCidr(trie, isIpv4, sub, subCidr/8)
    36  	}
    37  
    38  	return nil
    39  }
    40  
    41  func (trie *IpCidrTrie) AddIpCidrForString(ipCidr string) error {
    42  	_, ipNet, err := net.ParseCIDR(ipCidr)
    43  	if err != nil {
    44  		return err
    45  	}
    46  
    47  	return trie.AddIpCidr(ipNet)
    48  }
    49  
    50  func (trie *IpCidrTrie) IsContain(ip net.IP) bool {
    51  	if ip == nil {
    52  		return false
    53  	}
    54  	isIpv4 := len(ip) == net.IPv4len
    55  	var groupValues []uint32
    56  	var ipCidrNode *IpCidrNode
    57  
    58  	if isIpv4 {
    59  		ipCidrNode = trie.ipv4Trie
    60  		for _, group := range ip {
    61  			groupValues = append(groupValues, uint32(group))
    62  		}
    63  	} else {
    64  		ipCidrNode = trie.ipv6Trie
    65  		for i := 0; i < len(ip); i += 2 {
    66  			groupValues = append(groupValues, getIpv6GroupValue(ip[i], ip[i+1]))
    67  		}
    68  	}
    69  
    70  	return search(ipCidrNode, groupValues) != nil
    71  }
    72  
    73  func (trie *IpCidrTrie) IsContainForString(ipString string) bool {
    74  	ip := net.ParseIP(ipString)
    75  	// deal with 4in6
    76  	actualIp := ip.To4()
    77  	if actualIp == nil {
    78  		actualIp = ip
    79  	}
    80  	return trie.IsContain(actualIp)
    81  }
    82  
    83  func ipCidrToSubIpCidr(ipNet *net.IPNet) ([]net.IP, int, bool, error) {
    84  	maskSize, _ := ipNet.Mask.Size()
    85  	var (
    86  		ipList      []net.IP
    87  		newMaskSize int
    88  		isIpv4      bool
    89  		err         error
    90  	)
    91  	isIpv4 = len(ipNet.IP) == net.IPv4len
    92  	ipList, newMaskSize, err = subIpCidr(ipNet.IP, maskSize, isIpv4)
    93  
    94  	return ipList, newMaskSize, isIpv4, err
    95  }
    96  
    97  func subIpCidr(ip net.IP, maskSize int, isIpv4 bool) ([]net.IP, int, error) {
    98  	var subIpCidrList []net.IP
    99  	groupSize := 8
   100  	if !isIpv4 {
   101  		groupSize = 16
   102  	}
   103  
   104  	if maskSize%groupSize == 0 {
   105  		return append(subIpCidrList, ip), maskSize, nil
   106  	}
   107  
   108  	lastByteMaskSize := maskSize % 8
   109  	lastByteMaskIndex := maskSize / 8
   110  	subIpCidrNum := 0xFF >> lastByteMaskSize
   111  	for i := 0; i <= subIpCidrNum; i++ {
   112  		subIpCidr := make([]byte, len(ip))
   113  		copy(subIpCidr, ip)
   114  		subIpCidr[lastByteMaskIndex] += byte(i)
   115  		subIpCidrList = append(subIpCidrList, subIpCidr)
   116  	}
   117  
   118  	newMaskSize := (lastByteMaskIndex + 1) * 8
   119  	if !isIpv4 {
   120  		newMaskSize = (lastByteMaskIndex/2 + 1) * 16
   121  	}
   122  
   123  	return subIpCidrList, newMaskSize, nil
   124  }
   125  
   126  func addIpCidr(trie *IpCidrTrie, isIpv4 bool, ip net.IP, groupSize int) {
   127  	if isIpv4 {
   128  		addIpv4Cidr(trie, ip, groupSize)
   129  	} else {
   130  		addIpv6Cidr(trie, ip, groupSize)
   131  	}
   132  }
   133  
   134  func addIpv4Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) {
   135  	preNode := trie.ipv4Trie
   136  	node := preNode.getChild(uint32(ip[0]))
   137  	if node == nil {
   138  		err := preNode.addChild(uint32(ip[0]))
   139  		if err != nil {
   140  			return
   141  		}
   142  
   143  		node = preNode.getChild(uint32(ip[0]))
   144  	}
   145  
   146  	for i := 1; i < groupSize; i++ {
   147  		if node.Mark {
   148  			return
   149  		}
   150  
   151  		groupValue := uint32(ip[i])
   152  		if !node.hasChild(groupValue) {
   153  			err := node.addChild(groupValue)
   154  			if err != nil {
   155  				log.Errorln(err.Error())
   156  			}
   157  		}
   158  
   159  		preNode = node
   160  		node = node.getChild(groupValue)
   161  		if node == nil {
   162  			err := preNode.addChild(uint32(ip[i-1]))
   163  			if err != nil {
   164  				return
   165  			}
   166  
   167  			node = preNode.getChild(uint32(ip[i-1]))
   168  		}
   169  	}
   170  
   171  	node.Mark = true
   172  	cleanChild(node)
   173  }
   174  
   175  func addIpv6Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) {
   176  	preNode := trie.ipv6Trie
   177  	node := preNode.getChild(getIpv6GroupValue(ip[0], ip[1]))
   178  	if node == nil {
   179  		err := preNode.addChild(getIpv6GroupValue(ip[0], ip[1]))
   180  		if err != nil {
   181  			return
   182  		}
   183  
   184  		node = preNode.getChild(getIpv6GroupValue(ip[0], ip[1]))
   185  	}
   186  
   187  	for i := 2; i < groupSize; i += 2 {
   188  		if ip[i] == 0 && ip[i+1] == 0 {
   189  			node.Mark = true
   190  		}
   191  
   192  		if node.Mark {
   193  			return
   194  		}
   195  
   196  		groupValue := getIpv6GroupValue(ip[i], ip[i+1])
   197  		if !node.hasChild(groupValue) {
   198  			err := node.addChild(groupValue)
   199  			if err != nil {
   200  				log.Errorln(err.Error())
   201  			}
   202  		}
   203  
   204  		preNode = node
   205  		node = node.getChild(groupValue)
   206  		if node == nil {
   207  			err := preNode.addChild(getIpv6GroupValue(ip[i-2], ip[i-1]))
   208  			if err != nil {
   209  				return
   210  			}
   211  
   212  			node = preNode.getChild(getIpv6GroupValue(ip[i-2], ip[i-1]))
   213  		}
   214  	}
   215  
   216  	node.Mark = true
   217  	cleanChild(node)
   218  }
   219  
   220  func getIpv6GroupValue(high, low byte) uint32 {
   221  	return (uint32(high) << 8) | uint32(low)
   222  }
   223  
   224  func cleanChild(node *IpCidrNode) {
   225  	for i := uint32(0); i < uint32(len(node.child)); i++ {
   226  		delete(node.child, i)
   227  	}
   228  }
   229  
   230  func search(root *IpCidrNode, groupValues []uint32) *IpCidrNode {
   231  	node := root.getChild(groupValues[0])
   232  	if node == nil || node.Mark {
   233  		return node
   234  	}
   235  
   236  	for _, value := range groupValues[1:] {
   237  		if !node.hasChild(value) {
   238  			return nil
   239  		}
   240  
   241  		node = node.getChild(value)
   242  
   243  		if node == nil || node.Mark {
   244  			return node
   245  		}
   246  	}
   247  
   248  	return nil
   249  }