github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/core/network/portrange.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package network
     5  
     6  import (
     7  	"fmt"
     8  	"sort"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/juju/errors"
    13  )
    14  
// GroupedPortRanges represents a list of PortRange instances grouped by a
// particular feature. The map key identifies the group; the merge helpers
// defined on this type treat that key as an endpoint name.
type GroupedPortRanges map[string][]PortRange
    18  
    19  // MergePendingOpenPortRanges will merge this group's port ranges with the
    20  // provided *open* ports. If the provided range already exists in this group
    21  // then this method returns false and the group is not modified.
    22  func (grp GroupedPortRanges) MergePendingOpenPortRanges(pendingOpenRanges GroupedPortRanges) bool {
    23  	var modified bool
    24  	for endpointName, pendingRanges := range pendingOpenRanges {
    25  		for _, pendingRange := range pendingRanges {
    26  			if grp.rangeExistsForEndpoint(endpointName, pendingRange) {
    27  				// Exists, no op for opening.
    28  				continue
    29  			}
    30  			grp[endpointName] = append(grp[endpointName], pendingRange)
    31  			modified = true
    32  		}
    33  	}
    34  	return modified
    35  }
    36  
    37  // MergePendingClosePortRanges will merge this group's port ranges with the
    38  // provided *closed* ports. If the provided range does not exists in this group
    39  // then this method returns false and the group is not modified.
    40  func (grp GroupedPortRanges) MergePendingClosePortRanges(pendingCloseRanges GroupedPortRanges) bool {
    41  	var modified bool
    42  	for endpointName, pendingRanges := range pendingCloseRanges {
    43  		for _, pendingRange := range pendingRanges {
    44  			if !grp.rangeExistsForEndpoint(endpointName, pendingRange) {
    45  				// Not exists, no op for closing.
    46  				continue
    47  			}
    48  			modified = grp.removePortRange(endpointName, pendingRange)
    49  		}
    50  	}
    51  	return modified
    52  }
    53  
    54  func (grp GroupedPortRanges) removePortRange(endpointName string, portRange PortRange) bool {
    55  	var modified bool
    56  	existingRanges := grp[endpointName]
    57  	for i, v := range existingRanges {
    58  		if v != portRange {
    59  			continue
    60  		}
    61  		existingRanges = append(existingRanges[:i], existingRanges[i+1:]...)
    62  		if len(existingRanges) == 0 {
    63  			delete(grp, endpointName)
    64  		} else {
    65  			grp[endpointName] = existingRanges
    66  		}
    67  		modified = true
    68  	}
    69  	return modified
    70  }
    71  
    72  func (grp GroupedPortRanges) rangeExistsForEndpoint(endpointName string, portRange PortRange) bool {
    73  	if len(grp[endpointName]) == 0 {
    74  		return false
    75  	}
    76  
    77  	for _, existingRange := range grp[endpointName] {
    78  		if existingRange == portRange {
    79  			return true
    80  		}
    81  	}
    82  	return false
    83  }
    84  
    85  // UniquePortRanges returns the unique set of PortRanges in this group.
    86  func (grp GroupedPortRanges) UniquePortRanges() []PortRange {
    87  	var allPorts []PortRange
    88  	for _, portRanges := range grp {
    89  		allPorts = append(allPorts, portRanges...)
    90  	}
    91  	uniquePortRanges := UniquePortRanges(allPorts)
    92  	SortPortRanges(uniquePortRanges)
    93  	return uniquePortRanges
    94  }
    95  
    96  // Clone returns a copy of this port range grouping.
    97  func (grp GroupedPortRanges) Clone() GroupedPortRanges {
    98  	if len(grp) == 0 {
    99  		return nil
   100  	}
   101  
   102  	grpCopy := make(GroupedPortRanges, len(grp))
   103  	for k, v := range grp {
   104  		grpCopy[k] = append([]PortRange(nil), v...)
   105  	}
   106  	return grpCopy
   107  }
   108  
   109  // EqualTo returns true if this set of grouped port ranges are equal to other.
   110  func (grp GroupedPortRanges) EqualTo(other GroupedPortRanges) bool {
   111  	if len(grp) != len(other) {
   112  		return false
   113  	}
   114  
   115  	for groupKey, portRanges := range grp {
   116  		otherPortRanges, found := other[groupKey]
   117  		if !found || len(portRanges) != len(otherPortRanges) {
   118  			return false
   119  		}
   120  
   121  		SortPortRanges(portRanges)
   122  		SortPortRanges(otherPortRanges)
   123  		for i, pr := range portRanges {
   124  			if pr != otherPortRanges[i] {
   125  				return false
   126  			}
   127  		}
   128  	}
   129  
   130  	return true
   131  }
   132  
// PortRange represents a single range of ports for one protocol. Both bounds
// are inclusive. ICMP ranges carry no ports and use -1 for both bounds.
type PortRange struct {
	FromPort int    // first port of the range
	ToPort   int    // last port of the range
	Protocol string // Validate accepts "tcp", "udp" or "icmp" in any letter case
}
   139  
   140  // IsValid determines if the port range is valid.
   141  func (p PortRange) Validate() error {
   142  	proto := strings.ToLower(p.Protocol)
   143  	if proto != "tcp" && proto != "udp" && proto != "icmp" {
   144  		return errors.Errorf(`invalid protocol %q, expected "tcp", "udp", or "icmp"`, proto)
   145  	}
   146  	if proto == "icmp" {
   147  		if p.FromPort == p.ToPort && p.FromPort == -1 {
   148  			return nil
   149  		}
   150  		return errors.Errorf(`protocol "icmp" doesn't support any ports; got "%v"`, p.FromPort)
   151  	}
   152  	if p.FromPort > p.ToPort {
   153  		return errors.Errorf("invalid port range %s", p)
   154  	} else if p.FromPort < 0 || p.FromPort > 65535 || p.ToPort < 0 || p.ToPort > 65535 {
   155  		return errors.Errorf("port range bounds must be between 0 and 65535, got %d-%d", p.FromPort, p.ToPort)
   156  	}
   157  	return nil
   158  }
   159  
   160  // Length returns the number of ports in the range.  If the range is not valid,
   161  // it returns 0. If this range uses ICMP as the protocol then a -1 is returned
   162  // instead.
   163  func (p PortRange) Length() int {
   164  	if err := p.Validate(); err != nil {
   165  		return 0
   166  	}
   167  	return (p.ToPort - p.FromPort) + 1
   168  }
   169  
   170  // ConflictsWith determines if the two port ranges conflict.
   171  func (p PortRange) ConflictsWith(other PortRange) bool {
   172  	if p.Protocol != other.Protocol {
   173  		return false
   174  	}
   175  	return p.ToPort >= other.FromPort && other.ToPort >= p.FromPort
   176  }
   177  
   178  // SanitizeBounds returns a copy of the port range, which is guaranteed to have
   179  // FromPort >= ToPort and both FromPort and ToPort fit into the valid range
   180  // from 1 to 65535, inclusive.
   181  func (p PortRange) SanitizeBounds() PortRange {
   182  	res := p
   183  	if res.Protocol == "icmp" {
   184  		return res
   185  	}
   186  	if res.FromPort > res.ToPort {
   187  		res.FromPort, res.ToPort = res.ToPort, res.FromPort
   188  	}
   189  	for _, bound := range []*int{&res.FromPort, &res.ToPort} {
   190  		switch {
   191  		case *bound <= 0:
   192  			*bound = 1
   193  		case *bound > 65535:
   194  			*bound = 65535
   195  		}
   196  	}
   197  	return res
   198  }
   199  
   200  // String returns a formatted representation of this port range.
   201  func (p PortRange) String() string {
   202  	protocol := strings.ToLower(p.Protocol)
   203  	if protocol == "icmp" {
   204  		return protocol
   205  	}
   206  	if p.FromPort == p.ToPort {
   207  		return fmt.Sprintf("%d/%s", p.FromPort, protocol)
   208  	}
   209  	return fmt.Sprintf("%d-%d/%s", p.FromPort, p.ToPort, protocol)
   210  }
   211  
// GoString returns the same representation as String. Its signature matches
// fmt.GoStringer, which the fmt package consults for the %#v verb.
func (p PortRange) GoString() string {
	return p.String()
}
   215  
   216  // LessThan returns true if other should appear after p when sorting a port
   217  // range list.
   218  func (p PortRange) LessThan(other PortRange) bool {
   219  	if p.Protocol != other.Protocol {
   220  		return p.Protocol < other.Protocol
   221  	}
   222  	if p.FromPort != other.FromPort {
   223  		return p.FromPort < other.FromPort
   224  	}
   225  	return p.ToPort < other.ToPort
   226  }
   227  
   228  // SortPortRanges sorts the given ports, first by protocol, then by number.
   229  func SortPortRanges(portRanges []PortRange) {
   230  	sort.Slice(portRanges, func(i, j int) bool {
   231  		return portRanges[i].LessThan(portRanges[j])
   232  	})
   233  }
   234  
   235  // UniquePortRanges removes any duplicate port ranges from the input and
   236  // returns de-dupped list back.
   237  func UniquePortRanges(portRanges []PortRange) []PortRange {
   238  	var (
   239  		res       []PortRange
   240  		processed = make(map[PortRange]struct{})
   241  	)
   242  
   243  	for _, pr := range portRanges {
   244  		if _, seen := processed[pr]; seen {
   245  			continue
   246  		}
   247  
   248  		res = append(res, pr)
   249  		processed[pr] = struct{}{}
   250  	}
   251  	return res
   252  }
   253  
   254  // ParsePortRange builds a PortRange from the provided string. If the
   255  // string does not include a protocol then "tcp" is used. Validate()
   256  // gets called on the result before returning. If validation fails the
   257  // invalid PortRange is still returned.
   258  // Example strings: "80/tcp", "443", "12345-12349/udp", "icmp".
   259  func ParsePortRange(inPortRange string) (PortRange, error) {
   260  	// Extract the protocol.
   261  	protocol := "tcp"
   262  	parts := strings.SplitN(inPortRange, "/", 2)
   263  	if len(parts) == 2 {
   264  		inPortRange = parts[0]
   265  		protocol = parts[1]
   266  	}
   267  
   268  	// Parse the ports.
   269  	portRange, err := parsePortRange(inPortRange)
   270  	if err != nil {
   271  		return portRange, errors.Trace(err)
   272  	}
   273  	if portRange.FromPort == -1 {
   274  		protocol = "icmp"
   275  	}
   276  	portRange.Protocol = protocol
   277  
   278  	return portRange, portRange.Validate()
   279  }
   280  
   281  // MustParsePortRange converts a raw port-range string into a PortRange.
   282  // If the string is invalid, the function panics.
   283  func MustParsePortRange(portRange string) PortRange {
   284  	portrange, err := ParsePortRange(portRange)
   285  	if err != nil {
   286  		panic(err)
   287  	}
   288  	return portrange
   289  }
   290  
   291  func parsePortRange(portRange string) (PortRange, error) {
   292  	var result PortRange
   293  	var start, end int
   294  	parts := strings.Split(portRange, "-")
   295  	if len(parts) > 2 {
   296  		return result, errors.Errorf("invalid port range %q", portRange)
   297  	}
   298  
   299  	if len(parts) == 1 {
   300  		if parts[0] == "icmp" {
   301  			start, end = -1, -1
   302  		} else {
   303  			port, err := strconv.Atoi(parts[0])
   304  			if err != nil {
   305  				return result, errors.Annotatef(err, "invalid port %q", portRange)
   306  			}
   307  			start, end = port, port
   308  		}
   309  	} else {
   310  		var err error
   311  		if start, err = strconv.Atoi(parts[0]); err != nil {
   312  			return result, errors.Annotatef(err, "invalid port %q", parts[0])
   313  		}
   314  		if end, err = strconv.Atoi(parts[1]); err != nil {
   315  			return result, errors.Annotatef(err, "invalid port %q", parts[1])
   316  		}
   317  	}
   318  
   319  	result = PortRange{
   320  		FromPort: start,
   321  		ToPort:   end,
   322  	}
   323  	return result, nil
   324  }
   325  
   326  // CombinePortRanges groups together all port ranges according to
   327  // protocol, and then combines then into contiguous port ranges.
   328  // NOTE: Juju only allows its model to contain non-overlapping port ranges.
   329  // This method operates on that assumption.
   330  func CombinePortRanges(ranges ...PortRange) []PortRange {
   331  	SortPortRanges(ranges)
   332  	var result []PortRange
   333  	var current *PortRange
   334  	for _, pr := range ranges {
   335  		thispr := pr
   336  		if current == nil {
   337  			current = &thispr
   338  			continue
   339  		}
   340  		if pr.Protocol == current.Protocol && pr.FromPort == current.ToPort+1 {
   341  			current.ToPort = thispr.ToPort
   342  			continue
   343  		}
   344  		result = append(result, *current)
   345  		current = &thispr
   346  	}
   347  	if current != nil {
   348  		result = append(result, *current)
   349  	}
   350  	return result
   351  }