github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/core/network/portrange.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package network
     5  
     6  import (
     7  	"fmt"
     8  	"sort"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/juju/errors"
    13  )
    14  
    15  // PortRange represents a single range of ports.
    16  type PortRange struct {
    17  	FromPort int
    18  	ToPort   int
    19  	Protocol string
    20  }
    21  
    22  // IsValid determines if the port range is valid.
    23  func (p PortRange) Validate() error {
    24  	proto := strings.ToLower(p.Protocol)
    25  	if proto != "tcp" && proto != "udp" && proto != "icmp" {
    26  		return errors.Errorf(`invalid protocol %q, expected "tcp", "udp", or "icmp"`, proto)
    27  	}
    28  	if proto == "icmp" {
    29  		if p.FromPort == p.ToPort && p.FromPort == -1 {
    30  			return nil
    31  		}
    32  		return errors.Errorf(`protocol "icmp" doesn't support any ports; got "%v"`, p.FromPort)
    33  	}
    34  	err := errors.Errorf(
    35  		"invalid port range %d-%d/%s",
    36  		p.FromPort,
    37  		p.ToPort,
    38  		p.Protocol,
    39  	)
    40  	switch {
    41  	case p.FromPort > p.ToPort:
    42  		return err
    43  	case p.FromPort < 1 || p.FromPort > 65535:
    44  		return err
    45  	case p.ToPort < 1 || p.ToPort > 65535:
    46  		return err
    47  	}
    48  	return nil
    49  }
    50  
    51  // ConflictsWith determines if the two port ranges conflict.
    52  func (a PortRange) ConflictsWith(b PortRange) bool {
    53  	if a.Protocol != b.Protocol {
    54  		return false
    55  	}
    56  	return a.ToPort >= b.FromPort && b.ToPort >= a.FromPort
    57  }
    58  
    59  func (p PortRange) String() string {
    60  	protocol := strings.ToLower(p.Protocol)
    61  	if protocol == "icmp" {
    62  		return protocol
    63  	}
    64  	if p.FromPort == p.ToPort {
    65  		return fmt.Sprintf("%d/%s", p.FromPort, protocol)
    66  	}
    67  	return fmt.Sprintf("%d-%d/%s", p.FromPort, p.ToPort, protocol)
    68  }
    69  
    70  func (p PortRange) GoString() string {
    71  	return p.String()
    72  }
    73  
    74  type portRangeSlice []PortRange
    75  
    76  func (p portRangeSlice) Len() int      { return len(p) }
    77  func (p portRangeSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
    78  func (p portRangeSlice) Less(i, j int) bool {
    79  	p1 := p[i]
    80  	p2 := p[j]
    81  	if p1.Protocol != p2.Protocol {
    82  		return p1.Protocol < p2.Protocol
    83  	}
    84  	if p1.FromPort != p2.FromPort {
    85  		return p1.FromPort < p2.FromPort
    86  	}
    87  	return p1.ToPort < p2.ToPort
    88  }
    89  
    90  // SortPortRanges sorts the given ports, first by protocol, then by number.
    91  func SortPortRanges(portRanges []PortRange) {
    92  	sort.Sort(portRangeSlice(portRanges))
    93  }
    94  
    95  // CollapsePorts collapses a slice of ports into port ranges.
    96  //
    97  // NOTE(dimitern): This is deprecated and should be removed when
    98  // possible. It still exists, because in a few places slices of Ports
    99  // are converted to PortRanges internally.
   100  func CollapsePorts(ports []Port) (result []PortRange) {
   101  	// First, convert ports to ranges, then sort them.
   102  	var portRanges []PortRange
   103  	for _, p := range ports {
   104  		portRanges = append(portRanges, PortRange{p.Number, p.Number, p.Protocol})
   105  	}
   106  	SortPortRanges(portRanges)
   107  	fromPort := 0
   108  	toPort := 0
   109  	protocol := ""
   110  	// Now merge single port ranges while preserving the order.
   111  	for _, pr := range portRanges {
   112  		if fromPort == 0 {
   113  			// new port range
   114  			fromPort = pr.FromPort
   115  			toPort = pr.ToPort
   116  			protocol = pr.Protocol
   117  		} else if pr.FromPort == toPort+1 && protocol == pr.Protocol {
   118  			// continuing port range
   119  			toPort = pr.FromPort
   120  		} else {
   121  			// break in port range
   122  			result = append(result,
   123  				PortRange{
   124  					Protocol: protocol,
   125  					FromPort: fromPort,
   126  					ToPort:   toPort,
   127  				})
   128  			fromPort = pr.FromPort
   129  			toPort = pr.ToPort
   130  			protocol = pr.Protocol
   131  		}
   132  	}
   133  	if fromPort != 0 {
   134  		result = append(result, PortRange{
   135  			Protocol: protocol,
   136  			FromPort: fromPort,
   137  			ToPort:   toPort,
   138  		})
   139  
   140  	}
   141  	return
   142  }
   143  
   144  // ParsePortRange builds a PortRange from the provided string. If the
   145  // string does not include a protocol then "tcp" is used. Validate()
   146  // gets called on the result before returning. If validation fails the
   147  // invalid PortRange is still returned.
   148  // Example strings: "80/tcp", "443", "12345-12349/udp", "icmp".
   149  func ParsePortRange(inPortRange string) (PortRange, error) {
   150  	// Extract the protocol.
   151  	protocol := "tcp"
   152  	parts := strings.SplitN(inPortRange, "/", 2)
   153  	if len(parts) == 2 {
   154  		inPortRange = parts[0]
   155  		protocol = parts[1]
   156  	}
   157  
   158  	// Parse the ports.
   159  	portRange, err := parsePortRange(inPortRange)
   160  	if err != nil {
   161  		return portRange, errors.Trace(err)
   162  	}
   163  	if portRange.FromPort == -1 {
   164  		protocol = "icmp"
   165  	}
   166  	portRange.Protocol = protocol
   167  
   168  	return portRange, portRange.Validate()
   169  }
   170  
   171  // MustParsePortRange converts a raw port-range string into a PortRange.
   172  // If the string is invalid, the function panics.
   173  func MustParsePortRange(portRange string) PortRange {
   174  	portrange, err := ParsePortRange(portRange)
   175  	if err != nil {
   176  		panic(err)
   177  	}
   178  	return portrange
   179  }
   180  
   181  func parsePortRange(portRange string) (PortRange, error) {
   182  	var result PortRange
   183  	var start, end int
   184  	parts := strings.Split(portRange, "-")
   185  	if len(parts) > 2 {
   186  		return result, errors.Errorf("invalid port range %q", portRange)
   187  	}
   188  
   189  	if len(parts) == 1 {
   190  		if parts[0] == "icmp" {
   191  			start, end = -1, -1
   192  		} else {
   193  			port, err := strconv.Atoi(parts[0])
   194  			if err != nil {
   195  				return result, errors.Annotatef(err, "invalid port %q", portRange)
   196  			}
   197  			start, end = port, port
   198  		}
   199  	} else {
   200  		var err error
   201  		if start, err = strconv.Atoi(parts[0]); err != nil {
   202  			return result, errors.Annotatef(err, "invalid port %q", parts[0])
   203  		}
   204  		if end, err = strconv.Atoi(parts[1]); err != nil {
   205  			return result, errors.Annotatef(err, "invalid port %q", parts[1])
   206  		}
   207  	}
   208  
   209  	result = PortRange{
   210  		FromPort: start,
   211  		ToPort:   end,
   212  	}
   213  	return result, nil
   214  }
   215  
   216  // CombinePortRanges groups together all port ranges according to
   217  // protocol, and then combines then into contiguous port ranges.
   218  // NOTE: Juju only allows its model to contain non-overlapping port ranges.
   219  // This method operates on that assumption.
   220  func CombinePortRanges(ranges ...PortRange) []PortRange {
   221  	SortPortRanges(ranges)
   222  	var result []PortRange
   223  	var current *PortRange
   224  	for _, pr := range ranges {
   225  		thispr := pr
   226  		if current == nil {
   227  			current = &thispr
   228  			continue
   229  		}
   230  		if pr.Protocol == current.Protocol && pr.FromPort == current.ToPort+1 {
   231  			current.ToPort = thispr.ToPort
   232  			continue
   233  		}
   234  		result = append(result, *current)
   235  		current = &thispr
   236  	}
   237  	if current != nil {
   238  		result = append(result, *current)
   239  	}
   240  	return result
   241  }