github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/core/network/portrange.go (about) 1 // Copyright 2014 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package network 5 6 import ( 7 "fmt" 8 "sort" 9 "strconv" 10 "strings" 11 12 "github.com/juju/errors" 13 ) 14 15 // PortRange represents a single range of ports. 16 type PortRange struct { 17 FromPort int 18 ToPort int 19 Protocol string 20 } 21 22 // IsValid determines if the port range is valid. 23 func (p PortRange) Validate() error { 24 proto := strings.ToLower(p.Protocol) 25 if proto != "tcp" && proto != "udp" && proto != "icmp" { 26 return errors.Errorf(`invalid protocol %q, expected "tcp", "udp", or "icmp"`, proto) 27 } 28 if proto == "icmp" { 29 if p.FromPort == p.ToPort && p.FromPort == -1 { 30 return nil 31 } 32 return errors.Errorf(`protocol "icmp" doesn't support any ports; got "%v"`, p.FromPort) 33 } 34 err := errors.Errorf( 35 "invalid port range %d-%d/%s", 36 p.FromPort, 37 p.ToPort, 38 p.Protocol, 39 ) 40 switch { 41 case p.FromPort > p.ToPort: 42 return err 43 case p.FromPort < 1 || p.FromPort > 65535: 44 return err 45 case p.ToPort < 1 || p.ToPort > 65535: 46 return err 47 } 48 return nil 49 } 50 51 // ConflictsWith determines if the two port ranges conflict. 52 func (a PortRange) ConflictsWith(b PortRange) bool { 53 if a.Protocol != b.Protocol { 54 return false 55 } 56 return a.ToPort >= b.FromPort && b.ToPort >= a.FromPort 57 } 58 59 func (p PortRange) String() string { 60 protocol := strings.ToLower(p.Protocol) 61 if protocol == "icmp" { 62 return protocol 63 } 64 if p.FromPort == p.ToPort { 65 return fmt.Sprintf("%d/%s", p.FromPort, protocol) 66 } 67 return fmt.Sprintf("%d-%d/%s", p.FromPort, p.ToPort, protocol) 68 } 69 70 func (p PortRange) GoString() string { 71 return p.String() 72 } 73 74 type portRangeSlice []PortRange 75 76 func (p portRangeSlice) Len() int { return len(p) } 77 func (p portRangeSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } 78 func (p portRangeSlice) Less(i, j int) bool { 79 p1 := p[i] 80 p2 := p[j] 81 if p1.Protocol != p2.Protocol { 82 return p1.Protocol < p2.Protocol 83 } 84 if p1.FromPort != p2.FromPort { 85 return p1.FromPort < p2.FromPort 86 } 87 return p1.ToPort < p2.ToPort 88 } 89 90 // SortPortRanges sorts the given ports, first by protocol, then by number. 91 func SortPortRanges(portRanges []PortRange) { 92 sort.Sort(portRangeSlice(portRanges)) 93 } 94 95 // CollapsePorts collapses a slice of ports into port ranges. 96 // 97 // NOTE(dimitern): This is deprecated and should be removed when 98 // possible. It still exists, because in a few places slices of Ports 99 // are converted to PortRanges internally. 100 func CollapsePorts(ports []Port) (result []PortRange) { 101 // First, convert ports to ranges, then sort them. 102 var portRanges []PortRange 103 for _, p := range ports { 104 portRanges = append(portRanges, PortRange{p.Number, p.Number, p.Protocol}) 105 } 106 SortPortRanges(portRanges) 107 fromPort := 0 108 toPort := 0 109 protocol := "" 110 // Now merge single port ranges while preserving the order. 111 for _, pr := range portRanges { 112 if fromPort == 0 { 113 // new port range 114 fromPort = pr.FromPort 115 toPort = pr.ToPort 116 protocol = pr.Protocol 117 } else if pr.FromPort == toPort+1 && protocol == pr.Protocol { 118 // continuing port range 119 toPort = pr.FromPort 120 } else { 121 // break in port range 122 result = append(result, 123 PortRange{ 124 Protocol: protocol, 125 FromPort: fromPort, 126 ToPort: toPort, 127 }) 128 fromPort = pr.FromPort 129 toPort = pr.ToPort 130 protocol = pr.Protocol 131 } 132 } 133 if fromPort != 0 { 134 result = append(result, PortRange{ 135 Protocol: protocol, 136 FromPort: fromPort, 137 ToPort: toPort, 138 }) 139 140 } 141 return 142 } 143 144 // ParsePortRange builds a PortRange from the provided string. If the 145 // string does not include a protocol then "tcp" is used. Validate() 146 // gets called on the result before returning. If validation fails the 147 // invalid PortRange is still returned. 148 // Example strings: "80/tcp", "443", "12345-12349/udp", "icmp". 149 func ParsePortRange(inPortRange string) (PortRange, error) { 150 // Extract the protocol. 151 protocol := "tcp" 152 parts := strings.SplitN(inPortRange, "/", 2) 153 if len(parts) == 2 { 154 inPortRange = parts[0] 155 protocol = parts[1] 156 } 157 158 // Parse the ports. 159 portRange, err := parsePortRange(inPortRange) 160 if err != nil { 161 return portRange, errors.Trace(err) 162 } 163 if portRange.FromPort == -1 { 164 protocol = "icmp" 165 } 166 portRange.Protocol = protocol 167 168 return portRange, portRange.Validate() 169 } 170 171 // MustParsePortRange converts a raw port-range string into a PortRange. 172 // If the string is invalid, the function panics. 173 func MustParsePortRange(portRange string) PortRange { 174 portrange, err := ParsePortRange(portRange) 175 if err != nil { 176 panic(err) 177 } 178 return portrange 179 } 180 181 func parsePortRange(portRange string) (PortRange, error) { 182 var result PortRange 183 var start, end int 184 parts := strings.Split(portRange, "-") 185 if len(parts) > 2 { 186 return result, errors.Errorf("invalid port range %q", portRange) 187 } 188 189 if len(parts) == 1 { 190 if parts[0] == "icmp" { 191 start, end = -1, -1 192 } else { 193 port, err := strconv.Atoi(parts[0]) 194 if err != nil { 195 return result, errors.Annotatef(err, "invalid port %q", portRange) 196 } 197 start, end = port, port 198 } 199 } else { 200 var err error 201 if start, err = strconv.Atoi(parts[0]); err != nil { 202 return result, errors.Annotatef(err, "invalid port %q", parts[0]) 203 } 204 if end, err = strconv.Atoi(parts[1]); err != nil { 205 return result, errors.Annotatef(err, "invalid port %q", parts[1]) 206 } 207 } 208 209 result = PortRange{ 210 FromPort: start, 211 ToPort: end, 212 } 213 return result, nil 214 } 215 216 // CombinePortRanges groups together all port ranges according to 217 // protocol, and then combines then into contiguous port ranges. 218 // NOTE: Juju only allows its model to contain non-overlapping port ranges. 219 // This method operates on that assumption. 220 func CombinePortRanges(ranges ...PortRange) []PortRange { 221 SortPortRanges(ranges) 222 var result []PortRange 223 var current *PortRange 224 for _, pr := range ranges { 225 thispr := pr 226 if current == nil { 227 current = &thispr 228 continue 229 } 230 if pr.Protocol == current.Protocol && pr.FromPort == current.ToPort+1 { 231 current.ToPort = thispr.ToPort 232 continue 233 } 234 result = append(result, *current) 235 current = &thispr 236 } 237 if current != nil { 238 result = append(result, *current) 239 } 240 return result 241 }