github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/portset/portset.go (about) 1 package portset 2 3 import ( 4 "errors" 5 "fmt" 6 "math/bits" 7 "strconv" 8 "strings" 9 ) 10 11 const blockBits = bits.UintSize 12 13 var ErrZeroPort = errors.New("port number cannot be zero") 14 15 // PortSet is a bit set for ports. 16 type PortSet struct { 17 blocks [65536 / blockBits]uint 18 } 19 20 // Count returns the number of ports in the set. 21 func (s *PortSet) Count() (count uint) { 22 for i := range s.blocks { 23 count += uint(bits.OnesCount(s.blocks[i])) 24 } 25 return 26 } 27 28 // First returns the first port in the set. 29 func (s *PortSet) First() uint16 { 30 for i, block := range s.blocks { 31 if block == 0 { 32 continue 33 } 34 return uint16(i*blockBits + bits.TrailingZeros(block)) 35 } 36 return 0 37 } 38 39 // RangeCount returns the number of port ranges in the set. 40 func (s *PortSet) RangeCount() (count uint) { 41 var inRange bool 42 for _, block := range s.blocks { 43 bitsRemaining := uint(blockBits) 44 for { 45 trailingZeros := uint(bits.TrailingZeros(block)) 46 if trailingZeros != 0 { 47 if inRange { 48 inRange = false 49 } 50 if trailingZeros >= bitsRemaining { 51 break 52 } 53 block >>= trailingZeros 54 bitsRemaining -= trailingZeros 55 } 56 57 trailingOnes := uint(bits.TrailingZeros(^block)) 58 if !inRange { 59 inRange = true 60 count++ 61 } 62 if trailingOnes == bitsRemaining { 63 break 64 } 65 block >>= trailingOnes 66 bitsRemaining -= trailingOnes 67 } 68 } 69 return 70 } 71 72 // RangeSet returns the ports in the set as a port range set. 73 func (s *PortSet) RangeSet() PortRangeSet { 74 // [PortRange] is a small struct, so we can afford to preallocate the slice. 75 // Use 16 as the initial capacity, which corresponds to a 64-byte backing array, 76 // which happens to be the most common cache line size on modern CPUs. 77 ranges := make([]PortRange, 0, 16) 78 79 var ( 80 inRange bool 81 from uint16 82 ) 83 84 for i, block := range s.blocks { 85 bitsRemaining := uint(blockBits) 86 for { 87 trailingZeros := uint(bits.TrailingZeros(block)) 88 if trailingZeros != 0 { 89 if inRange { 90 inRange = false 91 ranges = append(ranges, PortRange{From: from, To: uint16((uint(i)+1)*blockBits - bitsRemaining - 1)}) 92 } 93 if trailingZeros >= bitsRemaining { 94 break 95 } 96 block >>= trailingZeros 97 bitsRemaining -= trailingZeros 98 } 99 100 trailingOnes := uint(bits.TrailingZeros(^block)) 101 if !inRange { 102 inRange = true 103 from = uint16((uint(i)+1)*blockBits - bitsRemaining) 104 } 105 if trailingOnes == bitsRemaining { 106 break 107 } 108 block >>= trailingOnes 109 bitsRemaining -= trailingOnes 110 } 111 } 112 113 if inRange { 114 ranges = append(ranges, PortRange{From: from, To: 65535}) 115 } 116 117 return PortRangeSet{ranges: ranges} 118 } 119 120 func panicOnZeroPort(port uint) { 121 if port == 0 { 122 panic(ErrZeroPort) 123 } 124 } 125 126 func (s *PortSet) blockIndex(port uint) uint { 127 return port / blockBits 128 } 129 130 func (s *PortSet) bitIndex(port uint) uint { 131 return port % blockBits 132 } 133 134 // Contains returns whether the given port is in the set. 135 func (s *PortSet) Contains(port uint16) bool { 136 p := uint(port) 137 panicOnZeroPort(p) 138 return s.blocks[s.blockIndex(p)]&(1<<s.bitIndex(p)) != 0 139 } 140 141 func (s *PortSet) add(port uint) { 142 s.blocks[s.blockIndex(port)] |= 1 << s.bitIndex(port) 143 } 144 145 // Add adds the given port to the set. 146 func (s *PortSet) Add(port uint16) { 147 p := uint(port) 148 panicOnZeroPort(p) 149 s.add(p) 150 } 151 152 func (s *PortSet) addRange(fromInclusive, toExclusive uint) { 153 fromBlockIndex := s.blockIndex(fromInclusive) 154 fromBitIndex := s.bitIndex(fromInclusive) 155 toBlockIndex := s.blockIndex(toExclusive) 156 toBitIndex := s.bitIndex(toExclusive) 157 158 fromBlockMask := ^uint(0) << fromBitIndex 159 toBlockMask := ^(^uint(0) << toBitIndex) 160 161 if fromBlockIndex == toBlockIndex { 162 s.blocks[fromBlockIndex] |= fromBlockMask & toBlockMask 163 return 164 } 165 166 s.blocks[fromBlockIndex] |= fromBlockMask 167 for i := fromBlockIndex + 1; i < toBlockIndex; i++ { 168 s.blocks[i] = ^uint(0) 169 } 170 if toBlockIndex < uint(len(s.blocks)) { 171 s.blocks[toBlockIndex] |= toBlockMask 172 } 173 } 174 175 // AddRange adds the given port range to the set. 176 func (s *PortSet) AddRange(from, to uint16) { 177 fromPort := uint(from) 178 toPort := uint(to) 179 panicOnZeroPort(fromPort) 180 if fromPort >= toPort { 181 panic(fmt.Sprintf("invalid port range: %d >= %d", fromPort, toPort)) 182 } 183 toPort++ // Make toPort exclusive. 184 s.addRange(fromPort, toPort) 185 } 186 187 // Parse parses the given string as a comma-separated list of ports and port ranges, 188 // and adds them to the set on success, or returns an error. 189 func (s *PortSet) Parse(portSetString string) error { 190 for len(portSetString) > 0 { 191 var portRangeString string 192 193 commaIndex := strings.IndexByte(portSetString, ',') 194 if commaIndex == -1 { 195 portRangeString = portSetString 196 portSetString = "" 197 } else { 198 portRangeString = portSetString[:commaIndex] 199 portSetString = portSetString[commaIndex+1:] 200 } 201 202 dashIndex := strings.IndexByte(portRangeString, '-') 203 if dashIndex == -1 { 204 port, err := strconv.ParseUint(portRangeString, 10, 16) 205 if err != nil { 206 return fmt.Errorf("invalid port %q: %w", portRangeString, err) 207 } 208 if port == 0 { 209 return fmt.Errorf("invalid port %q: %w", portRangeString, ErrZeroPort) 210 } 211 s.add(uint(port)) 212 } else { 213 fromPort, err := strconv.ParseUint(portRangeString[:dashIndex], 10, 16) 214 if err != nil { 215 return fmt.Errorf("invalid port range %q: %w", portRangeString, err) 216 } 217 if fromPort == 0 { 218 return fmt.Errorf("invalid port range %q: %w", portRangeString, ErrZeroPort) 219 } 220 toPort, err := strconv.ParseUint(portRangeString[dashIndex+1:], 10, 16) 221 if err != nil { 222 return fmt.Errorf("invalid port range %q: %w", portRangeString, err) 223 } 224 if fromPort >= toPort { 225 return fmt.Errorf("invalid port range %q: %d >= %d", portRangeString, fromPort, toPort) 226 } 227 toPort++ // Make toPort exclusive. 228 s.addRange(uint(fromPort), uint(toPort)) 229 } 230 } 231 return nil 232 }