github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/utils/portcache/portcache.go (about)

     1  package portcache
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  
     7  	"go.aporeto.io/enforcerd/trireme-lib/utils/cache"
     8  	"go.aporeto.io/enforcerd/trireme-lib/utils/portspec"
     9  )
    10  
    11  // PortCache is a generic cache of port pairs or exact ports. It can store
    12  // and do lookups of ports on exact matches or ranges. It returns the stored
    13  // values
    14  type PortCache struct {
    15  	ports  cache.DataStore
    16  	ranges []*portspec.PortSpec
    17  	sync.Mutex
    18  }
    19  
    20  // NewPortCache creates a new port cache
    21  func NewPortCache(name string) *PortCache {
    22  	return &PortCache{
    23  		ports:  cache.NewCache(name),
    24  		ranges: []*portspec.PortSpec{},
    25  	}
    26  }
    27  
    28  // AddPortSpec adds a port spec into the cache
    29  func (p *PortCache) AddPortSpec(s *portspec.PortSpec) {
    30  	if s.Min == s.Max {
    31  		p.ports.AddOrUpdate(s.Min, s)
    32  	} else {
    33  		// Remove the range if it exists
    34  		p.Remove(s) // nolint
    35  		// Insert the portspec
    36  		p.Lock()
    37  		p.ranges = append([]*portspec.PortSpec{s}, p.ranges...)
    38  		p.Unlock()
    39  	}
    40  }
    41  
    42  // AddPortSpecToEnd adds a range at the end of the cache
    43  func (p *PortCache) AddPortSpecToEnd(s *portspec.PortSpec) {
    44  
    45  	// Remove the range if it exists
    46  	p.Remove(s) // nolint
    47  
    48  	p.Lock()
    49  	p.ranges = append(p.ranges, s)
    50  	p.Unlock()
    51  
    52  }
    53  
    54  // AddUnique adds a port spec into the cache and makes sure its unique
    55  func (p *PortCache) AddUnique(s *portspec.PortSpec) error {
    56  	p.Lock()
    57  	defer p.Unlock()
    58  
    59  	if s.Min == s.Max {
    60  		if err, _ := p.ports.Get(s.Min); err != nil {
    61  			return fmt.Errorf("Port already exists: %s", err)
    62  		}
    63  	}
    64  
    65  	for _, r := range p.ranges {
    66  		if r.Max <= s.Min || r.Min >= s.Max {
    67  			continue
    68  		}
    69  		return fmt.Errorf("Overlap detected: %d %d", r.Max, r.Min)
    70  	}
    71  
    72  	if s.Min == s.Max {
    73  		return p.ports.Add(s.Min, s)
    74  	}
    75  
    76  	p.ranges = append(p.ranges, s)
    77  	return nil
    78  }
    79  
    80  // GetSpecValueFromPort searches the cache for a match based on a port
    81  // It will return the first match found on exact ports or on the ranges
    82  // of ports. If there are multiple intervals that match it will randomly
    83  // return one of them.
    84  func (p *PortCache) GetSpecValueFromPort(port uint16) (interface{}, error) {
    85  	if spec, err := p.ports.Get(port); err == nil {
    86  		return spec.(*portspec.PortSpec).Value(), nil
    87  	}
    88  
    89  	p.Lock()
    90  	defer p.Unlock()
    91  	for _, s := range p.ranges {
    92  		if s.Min <= port && port <= s.Max {
    93  			return s.Value(), nil
    94  		}
    95  	}
    96  
    97  	return nil, fmt.Errorf("No match for port %d", port)
    98  }
    99  
   100  // GetAllSpecValueFromPort will return all the specs that potentially match. This
   101  // will allow for overlapping ranges
   102  func (p *PortCache) GetAllSpecValueFromPort(port uint16) ([]interface{}, error) {
   103  	var allMatches []interface{}
   104  
   105  	if spec, err := p.ports.Get(port); err == nil {
   106  		allMatches = append(allMatches, spec.(*portspec.PortSpec).Value())
   107  	}
   108  
   109  	p.Lock()
   110  	defer p.Unlock()
   111  	for _, s := range p.ranges {
   112  		if s.Min <= port && port < s.Max {
   113  			allMatches = append(allMatches, s.Value())
   114  		}
   115  	}
   116  
   117  	if len(allMatches) == 0 {
   118  		return nil, fmt.Errorf("No match for port %d", port)
   119  	}
   120  	return allMatches, nil
   121  }
   122  
   123  // Remove will remove a port from the cache
   124  func (p *PortCache) Remove(s *portspec.PortSpec) error {
   125  
   126  	if s.Min == s.Max {
   127  		return p.ports.Remove(s.Min)
   128  	}
   129  
   130  	p.Lock()
   131  	defer p.Unlock()
   132  	for i, r := range p.ranges {
   133  		if r.Min == s.Min && r.Max == s.Max {
   134  			left := p.ranges[:i]
   135  			right := p.ranges[i+1:]
   136  			p.ranges = append(left, right...)
   137  			return nil
   138  		}
   139  	}
   140  
   141  	return fmt.Errorf("port not found")
   142  }
   143  
   144  // RemoveStringPorts will remove a port from the cache
   145  func (p *PortCache) RemoveStringPorts(ports string) error {
   146  
   147  	s, err := portspec.NewPortSpecFromString(ports, nil)
   148  	if err != nil {
   149  		return err
   150  	}
   151  
   152  	return p.Remove(s)
   153  }