github.com/axw/juju@v0.0.0-20161005053422-4bd6544d08d4/network/portset.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package network
     5  
     6  import (
     7  	"sort"
     8  	"strconv"
     9  
    10  	"github.com/juju/utils/set"
    11  )
    12  
    13  // PortSet is a set-like container of Port values.
    14  type PortSet struct {
    15  	values map[string]set.Ints
    16  }
    17  
    18  // NewPortSet creates a map of protocols to sets of stringified port numbers.
    19  func NewPortSet(portRanges ...PortRange) PortSet {
    20  	var result PortSet
    21  	result.values = make(map[string]set.Ints)
    22  	result.AddRanges(portRanges...)
    23  	return result
    24  }
    25  
    26  // Size returns the number of ports in the set.
    27  func (ps PortSet) Size() int {
    28  	size := 0
    29  	for _, ports := range ps.values {
    30  		size += len(ports)
    31  	}
    32  	return size
    33  }
    34  
    35  // IsEmpty returns true if the PortSet is empty.
    36  func (ps PortSet) IsEmpty() bool {
    37  	return len(ps.values) == 0
    38  }
    39  
    40  // Values returns a list of all the ports in the set.
    41  func (ps PortSet) Values() []Port {
    42  	return ps.Ports()
    43  }
    44  
    45  // Protocols returns a list of protocols known to the PortSet.
    46  func (ps PortSet) Protocols() []string {
    47  	var result []string
    48  	for key := range ps.values {
    49  		result = append(result, key)
    50  	}
    51  	return result
    52  }
    53  
    54  // PortRanges returns a list of all the port ranges in the set for the
    55  // given protocols. If no protocols are provided all known protocols in
    56  // the set are used.
    57  func (ps PortSet) PortRanges(protocols ...string) []PortRange {
    58  	if len(protocols) == 0 {
    59  		protocols = ps.Protocols()
    60  	}
    61  
    62  	var result []PortRange
    63  	for _, protocol := range protocols {
    64  		ranges := collapsePorts(protocol, ps.PortNumbers(protocol)...)
    65  		result = append(result, ranges...)
    66  	}
    67  	return result
    68  }
    69  
    70  func collapsePorts(protocol string, ports ...int) (result []PortRange) {
    71  	if len(ports) == 0 {
    72  		return nil
    73  	}
    74  
    75  	sort.Ints(ports)
    76  
    77  	fromPort := 0
    78  	toPort := 0
    79  	for _, port := range ports {
    80  		if fromPort == 0 {
    81  			// new port range
    82  			fromPort = port
    83  			toPort = port
    84  		} else if port == toPort+1 {
    85  			// continuing port range
    86  			toPort = port
    87  		} else {
    88  			// break in port range
    89  			result = append(result, PortRange{
    90  				Protocol: protocol,
    91  				FromPort: fromPort,
    92  				ToPort:   toPort,
    93  			})
    94  			fromPort = port
    95  			toPort = port
    96  		}
    97  	}
    98  	result = append(result, PortRange{
    99  		Protocol: protocol,
   100  		FromPort: fromPort,
   101  		ToPort:   toPort,
   102  	})
   103  	return
   104  }
   105  
   106  // PortNumbers returns a list of all the port numbers in the set for
   107  // the given protocols. If no protocols are provided then all known
   108  // protocols in the set are used.
   109  func (ps PortSet) Ports(protocols ...string) []Port {
   110  	if len(protocols) == 0 {
   111  		protocols = ps.Protocols()
   112  	}
   113  
   114  	var results []Port
   115  	for _, portRange := range ps.PortRanges(protocols...) {
   116  		for p := portRange.FromPort; p <= portRange.ToPort; p++ {
   117  			results = append(results, Port{portRange.Protocol, p})
   118  		}
   119  	}
   120  	return results
   121  }
   122  
   123  // PortNumbers returns a list of all the port numbers in the set for
   124  // the given protocol.
   125  func (ps PortSet) PortNumbers(protocol string) []int {
   126  	ports, ok := ps.values[protocol]
   127  	if !ok {
   128  		return nil
   129  	}
   130  	return ports.Values()
   131  }
   132  
   133  // PortStrings returns a list of stringified ports in the set
   134  // for the given protocol. This is strictly a convenience method
   135  // for situations where another API requires a list of strings.
   136  func (ps PortSet) PortStrings(protocol string) []string {
   137  	ports, ok := ps.values[protocol]
   138  	if !ok {
   139  		return nil
   140  	}
   141  	var result []string
   142  	for _, port := range ports.Values() {
   143  		portStr := strconv.Itoa(port)
   144  		result = append(result, portStr)
   145  	}
   146  	return result
   147  }
   148  
   149  // Add adds a port to the PortSet.
   150  func (ps *PortSet) Add(protocol string, port int) {
   151  	if ps.values == nil {
   152  		panic("uninitalised set")
   153  	}
   154  	ports, ok := ps.values[protocol]
   155  	if !ok {
   156  		ps.values[protocol] = set.NewInts(port)
   157  	} else {
   158  		ports.Add(port)
   159  	}
   160  }
   161  
   162  // AddRanges adds port ranges to the PortSet.
   163  func (ps *PortSet) AddRanges(portRanges ...PortRange) {
   164  	for _, portRange := range portRanges {
   165  		for p := portRange.FromPort; p <= portRange.ToPort; p++ {
   166  			ps.Add(portRange.Protocol, p)
   167  		}
   168  	}
   169  }
   170  
   171  // Remove removes the given port from the set.
   172  func (ps *PortSet) Remove(protocol string, port int) {
   173  	ports, ok := ps.values[protocol]
   174  	if ok {
   175  		ports.Remove(port)
   176  	}
   177  }
   178  
   179  // RemoveRanges removes all ports in the given PortRange values
   180  // from the set.
   181  func (ps *PortSet) RemoveRanges(portRanges ...PortRange) {
   182  	for _, portRange := range portRanges {
   183  		_, ok := ps.values[portRange.Protocol]
   184  		if ok {
   185  			for p := portRange.FromPort; p <= portRange.ToPort; p++ {
   186  				ps.Remove(portRange.Protocol, p)
   187  			}
   188  		}
   189  	}
   190  }
   191  
   192  // Contains returns true if the provided port is in the set.
   193  func (ps *PortSet) Contains(protocol string, port int) bool {
   194  	ports, ok := ps.values[protocol]
   195  	if !ok {
   196  		return false
   197  	}
   198  	return ports.Contains(port)
   199  }
   200  
   201  // ContainsRanges returns true if the provided port ranges are
   202  // in the set.
   203  func (ps *PortSet) ContainsRanges(portRanges ...PortRange) bool {
   204  	for _, portRange := range portRanges {
   205  		ports, ok := ps.values[portRange.Protocol]
   206  		if !ok {
   207  			return false
   208  		}
   209  		for p := portRange.FromPort; p <= portRange.ToPort; p++ {
   210  			if !ports.Contains(p) {
   211  				return false
   212  			}
   213  		}
   214  	}
   215  	return true
   216  }
   217  
   218  // Union returns a new PortSet of the shared values
   219  // that are common between both PortSets.
   220  func (ps PortSet) Union(other PortSet) PortSet {
   221  	result := NewPortSet()
   222  	for protocol, value := range ps.values {
   223  		result.values[protocol] = value.Union(nil)
   224  	}
   225  	for protocol, value := range other.values {
   226  		ports, ok := result.values[protocol]
   227  		if !ok {
   228  			value = nil
   229  		}
   230  		result.values[protocol] = ports.Union(value)
   231  	}
   232  	return result
   233  }
   234  
   235  // Intersection returns a new PortSet of the values that are in both
   236  // this set and the other, but not in just one of either.
   237  func (ps PortSet) Intersection(other PortSet) PortSet {
   238  	result := NewPortSet()
   239  	for protocol, value := range ps.values {
   240  		if ports, ok := other.values[protocol]; ok {
   241  			// For PortSet, a protocol without any associated ports
   242  			// doesn't make a lot of sense. It's also a waste of space.
   243  			// Consequently, if the intersection for a protocol is empty
   244  			// then we simply skip it.
   245  			if newValue := value.Intersection(ports); !newValue.IsEmpty() {
   246  				result.values[protocol] = newValue
   247  			}
   248  		}
   249  	}
   250  	return result
   251  }
   252  
   253  // Difference returns a new PortSet of the values
   254  // that are not in the other PortSet.
   255  func (ps PortSet) Difference(other PortSet) PortSet {
   256  	result := NewPortSet()
   257  	for protocol, value := range ps.values {
   258  		if ports, ok := other.values[protocol]; ok {
   259  			// For PortSet, a protocol without any associated ports
   260  			// doesn't make a lot of sense. It's also a waste of space.
   261  			// Consequently, if the difference for a protocol is empty
   262  			// then we simply skip it.
   263  			if newValue := value.Difference(ports); !newValue.IsEmpty() {
   264  				result.values[protocol] = newValue
   265  			}
   266  		} else {
   267  			result.values[protocol] = value
   268  		}
   269  	}
   270  	return result
   271  }