github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/core/network/portrange.go (about) 1 // Copyright 2014 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package network 5 6 import ( 7 "fmt" 8 "sort" 9 "strconv" 10 "strings" 11 12 "github.com/juju/errors" 13 ) 14 15 // GroupedPortRanges represents a list of PortRange instances grouped by a 16 // particular feature. 17 type GroupedPortRanges map[string][]PortRange 18 19 // MergePendingOpenPortRanges will merge this group's port ranges with the 20 // provided *open* ports. If the provided range already exists in this group 21 // then this method returns false and the group is not modified. 22 func (grp GroupedPortRanges) MergePendingOpenPortRanges(pendingOpenRanges GroupedPortRanges) bool { 23 var modified bool 24 for endpointName, pendingRanges := range pendingOpenRanges { 25 for _, pendingRange := range pendingRanges { 26 if grp.rangeExistsForEndpoint(endpointName, pendingRange) { 27 // Exists, no op for opening. 28 continue 29 } 30 grp[endpointName] = append(grp[endpointName], pendingRange) 31 modified = true 32 } 33 } 34 return modified 35 } 36 37 // MergePendingClosePortRanges will merge this group's port ranges with the 38 // provided *closed* ports. If the provided range does not exists in this group 39 // then this method returns false and the group is not modified. 40 func (grp GroupedPortRanges) MergePendingClosePortRanges(pendingCloseRanges GroupedPortRanges) bool { 41 var modified bool 42 for endpointName, pendingRanges := range pendingCloseRanges { 43 for _, pendingRange := range pendingRanges { 44 if !grp.rangeExistsForEndpoint(endpointName, pendingRange) { 45 // Not exists, no op for closing. 46 continue 47 } 48 modified = grp.removePortRange(endpointName, pendingRange) 49 } 50 } 51 return modified 52 } 53 54 func (grp GroupedPortRanges) removePortRange(endpointName string, portRange PortRange) bool { 55 var modified bool 56 existingRanges := grp[endpointName] 57 for i, v := range existingRanges { 58 if v != portRange { 59 continue 60 } 61 existingRanges = append(existingRanges[:i], existingRanges[i+1:]...) 62 if len(existingRanges) == 0 { 63 delete(grp, endpointName) 64 } else { 65 grp[endpointName] = existingRanges 66 } 67 modified = true 68 } 69 return modified 70 } 71 72 func (grp GroupedPortRanges) rangeExistsForEndpoint(endpointName string, portRange PortRange) bool { 73 if len(grp[endpointName]) == 0 { 74 return false 75 } 76 77 for _, existingRange := range grp[endpointName] { 78 if existingRange == portRange { 79 return true 80 } 81 } 82 return false 83 } 84 85 // UniquePortRanges returns the unique set of PortRanges in this group. 86 func (grp GroupedPortRanges) UniquePortRanges() []PortRange { 87 var allPorts []PortRange 88 for _, portRanges := range grp { 89 allPorts = append(allPorts, portRanges...) 90 } 91 uniquePortRanges := UniquePortRanges(allPorts) 92 SortPortRanges(uniquePortRanges) 93 return uniquePortRanges 94 } 95 96 // Clone returns a copy of this port range grouping. 97 func (grp GroupedPortRanges) Clone() GroupedPortRanges { 98 if len(grp) == 0 { 99 return nil 100 } 101 102 grpCopy := make(GroupedPortRanges, len(grp)) 103 for k, v := range grp { 104 grpCopy[k] = append([]PortRange(nil), v...) 105 } 106 return grpCopy 107 } 108 109 // EqualTo returns true if this set of grouped port ranges are equal to other. 110 func (grp GroupedPortRanges) EqualTo(other GroupedPortRanges) bool { 111 if len(grp) != len(other) { 112 return false 113 } 114 115 for groupKey, portRanges := range grp { 116 otherPortRanges, found := other[groupKey] 117 if !found || len(portRanges) != len(otherPortRanges) { 118 return false 119 } 120 121 SortPortRanges(portRanges) 122 SortPortRanges(otherPortRanges) 123 for i, pr := range portRanges { 124 if pr != otherPortRanges[i] { 125 return false 126 } 127 } 128 } 129 130 return true 131 } 132 133 // PortRange represents a single range of ports on a particular subnet. 134 type PortRange struct { 135 FromPort int 136 ToPort int 137 Protocol string 138 } 139 140 // IsValid determines if the port range is valid. 141 func (p PortRange) Validate() error { 142 proto := strings.ToLower(p.Protocol) 143 if proto != "tcp" && proto != "udp" && proto != "icmp" { 144 return errors.Errorf(`invalid protocol %q, expected "tcp", "udp", or "icmp"`, proto) 145 } 146 if proto == "icmp" { 147 if p.FromPort == p.ToPort && p.FromPort == -1 { 148 return nil 149 } 150 return errors.Errorf(`protocol "icmp" doesn't support any ports; got "%v"`, p.FromPort) 151 } 152 if p.FromPort > p.ToPort { 153 return errors.Errorf("invalid port range %s", p) 154 } else if p.FromPort < 0 || p.FromPort > 65535 || p.ToPort < 0 || p.ToPort > 65535 { 155 return errors.Errorf("port range bounds must be between 0 and 65535, got %d-%d", p.FromPort, p.ToPort) 156 } 157 return nil 158 } 159 160 // Length returns the number of ports in the range. If the range is not valid, 161 // it returns 0. If this range uses ICMP as the protocol then a -1 is returned 162 // instead. 163 func (p PortRange) Length() int { 164 if err := p.Validate(); err != nil { 165 return 0 166 } 167 return (p.ToPort - p.FromPort) + 1 168 } 169 170 // ConflictsWith determines if the two port ranges conflict. 171 func (p PortRange) ConflictsWith(other PortRange) bool { 172 if p.Protocol != other.Protocol { 173 return false 174 } 175 return p.ToPort >= other.FromPort && other.ToPort >= p.FromPort 176 } 177 178 // SanitizeBounds returns a copy of the port range, which is guaranteed to have 179 // FromPort >= ToPort and both FromPort and ToPort fit into the valid range 180 // from 1 to 65535, inclusive. 181 func (p PortRange) SanitizeBounds() PortRange { 182 res := p 183 if res.Protocol == "icmp" { 184 return res 185 } 186 if res.FromPort > res.ToPort { 187 res.FromPort, res.ToPort = res.ToPort, res.FromPort 188 } 189 for _, bound := range []*int{&res.FromPort, &res.ToPort} { 190 switch { 191 case *bound <= 0: 192 *bound = 1 193 case *bound > 65535: 194 *bound = 65535 195 } 196 } 197 return res 198 } 199 200 // String returns a formatted representation of this port range. 201 func (p PortRange) String() string { 202 protocol := strings.ToLower(p.Protocol) 203 if protocol == "icmp" { 204 return protocol 205 } 206 if p.FromPort == p.ToPort { 207 return fmt.Sprintf("%d/%s", p.FromPort, protocol) 208 } 209 return fmt.Sprintf("%d-%d/%s", p.FromPort, p.ToPort, protocol) 210 } 211 212 func (p PortRange) GoString() string { 213 return p.String() 214 } 215 216 // LessThan returns true if other should appear after p when sorting a port 217 // range list. 218 func (p PortRange) LessThan(other PortRange) bool { 219 if p.Protocol != other.Protocol { 220 return p.Protocol < other.Protocol 221 } 222 if p.FromPort != other.FromPort { 223 return p.FromPort < other.FromPort 224 } 225 return p.ToPort < other.ToPort 226 } 227 228 // SortPortRanges sorts the given ports, first by protocol, then by number. 229 func SortPortRanges(portRanges []PortRange) { 230 sort.Slice(portRanges, func(i, j int) bool { 231 return portRanges[i].LessThan(portRanges[j]) 232 }) 233 } 234 235 // UniquePortRanges removes any duplicate port ranges from the input and 236 // returns de-dupped list back. 237 func UniquePortRanges(portRanges []PortRange) []PortRange { 238 var ( 239 res []PortRange 240 processed = make(map[PortRange]struct{}) 241 ) 242 243 for _, pr := range portRanges { 244 if _, seen := processed[pr]; seen { 245 continue 246 } 247 248 res = append(res, pr) 249 processed[pr] = struct{}{} 250 } 251 return res 252 } 253 254 // ParsePortRange builds a PortRange from the provided string. If the 255 // string does not include a protocol then "tcp" is used. Validate() 256 // gets called on the result before returning. If validation fails the 257 // invalid PortRange is still returned. 258 // Example strings: "80/tcp", "443", "12345-12349/udp", "icmp". 259 func ParsePortRange(inPortRange string) (PortRange, error) { 260 // Extract the protocol. 261 protocol := "tcp" 262 parts := strings.SplitN(inPortRange, "/", 2) 263 if len(parts) == 2 { 264 inPortRange = parts[0] 265 protocol = parts[1] 266 } 267 268 // Parse the ports. 269 portRange, err := parsePortRange(inPortRange) 270 if err != nil { 271 return portRange, errors.Trace(err) 272 } 273 if portRange.FromPort == -1 { 274 protocol = "icmp" 275 } 276 portRange.Protocol = protocol 277 278 return portRange, portRange.Validate() 279 } 280 281 // MustParsePortRange converts a raw port-range string into a PortRange. 282 // If the string is invalid, the function panics. 283 func MustParsePortRange(portRange string) PortRange { 284 portrange, err := ParsePortRange(portRange) 285 if err != nil { 286 panic(err) 287 } 288 return portrange 289 } 290 291 func parsePortRange(portRange string) (PortRange, error) { 292 var result PortRange 293 var start, end int 294 parts := strings.Split(portRange, "-") 295 if len(parts) > 2 { 296 return result, errors.Errorf("invalid port range %q", portRange) 297 } 298 299 if len(parts) == 1 { 300 if parts[0] == "icmp" { 301 start, end = -1, -1 302 } else { 303 port, err := strconv.Atoi(parts[0]) 304 if err != nil { 305 return result, errors.Annotatef(err, "invalid port %q", portRange) 306 } 307 start, end = port, port 308 } 309 } else { 310 var err error 311 if start, err = strconv.Atoi(parts[0]); err != nil { 312 return result, errors.Annotatef(err, "invalid port %q", parts[0]) 313 } 314 if end, err = strconv.Atoi(parts[1]); err != nil { 315 return result, errors.Annotatef(err, "invalid port %q", parts[1]) 316 } 317 } 318 319 result = PortRange{ 320 FromPort: start, 321 ToPort: end, 322 } 323 return result, nil 324 } 325 326 // CombinePortRanges groups together all port ranges according to 327 // protocol, and then combines then into contiguous port ranges. 328 // NOTE: Juju only allows its model to contain non-overlapping port ranges. 329 // This method operates on that assumption. 330 func CombinePortRanges(ranges ...PortRange) []PortRange { 331 SortPortRanges(ranges) 332 var result []PortRange 333 var current *PortRange 334 for _, pr := range ranges { 335 thispr := pr 336 if current == nil { 337 current = &thispr 338 continue 339 } 340 if pr.Protocol == current.Protocol && pr.FromPort == current.ToPort+1 { 341 current.ToPort = thispr.ToPort 342 continue 343 } 344 result = append(result, *current) 345 current = &thispr 346 } 347 if current != nil { 348 result = append(result, *current) 349 } 350 return result 351 }