github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/applicationproxy/servicecache/servicecache.go (about)

     1  package servicecache
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"sync"
     7  
     8  	"go.aporeto.io/enforcerd/trireme-lib/common"
     9  	"go.aporeto.io/enforcerd/trireme-lib/utils/ipprefix"
    10  	"go.aporeto.io/enforcerd/trireme-lib/utils/portspec"
    11  )
    12  
// entry is one cached service record. The same *entry may be referenced from
// the port list, the host index and the IP index of a ServiceCache.
type entry struct {
	id    string             // identifier passed to Add; matched by DeleteByID and FindListeningServicesForPU
	ports *portspec.PortSpec // ports of the service; checked with Overlaps on insert and IsIncluded on lookup
	data  interface{}        // opaque caller payload handed back by the lookup methods
}
    18  
    19  type entryList []*entry
    20  
    21  func (e entryList) Delete(i int) entryList {
    22  	if i >= len(e) || i < 0 {
    23  		return e
    24  	}
    25  	return append(e[:i], e[i+1:]...)
    26  }
    27  
// ServiceCache indexes services by IP prefix, by host name and by port, keeping
// separate indexes for local and remote services. The exported methods take the
// embedded RWMutex before touching any field.
type ServiceCache struct {
	// local and remote are the IP indexes: each stored value is the entryList
	// of services registered for one (IP, prefix length) pair.
	local  ipprefix.IPcache
	remote ipprefix.IPcache
	// remoteHosts and localHosts are the host indexes: host name -> entryList
	// of services registered for that host.
	remoteHosts map[string]entryList
	localHosts  map[string]entryList
	// remotePorts and localPorts are flat lists of entries used to look a
	// service up without an address. Within this file only localPorts is ever
	// appended to (addPorts ignores remote services).
	remotePorts entryList
	localPorts  entryList
	// RWMutex serializes access to all of the indexes above.
	sync.RWMutex
}
    41  
    42  // NewTable creates a new table
    43  func NewTable() *ServiceCache {
    44  
    45  	return &ServiceCache{
    46  		local:       ipprefix.NewIPCache(),
    47  		remote:      ipprefix.NewIPCache(),
    48  		remoteHosts: map[string]entryList{},
    49  		localHosts:  map[string]entryList{},
    50  	}
    51  }
    52  
    53  // Add adds a service into the cache. Returns error of if any overlap has been detected.
    54  func (s *ServiceCache) Add(e *common.Service, id string, data interface{}, local bool) error {
    55  	s.Lock()
    56  	defer s.Unlock()
    57  
    58  	record := &entry{
    59  		ports: e.Ports,
    60  		data:  data,
    61  		id:    id,
    62  	}
    63  	if err := s.addPorts(e, record, local); err != nil {
    64  		return err
    65  	}
    66  
    67  	if err := s.addHostService(e, record, local); err != nil {
    68  		return err
    69  	}
    70  
    71  	return s.addIPService(e, record, local)
    72  }
    73  
    74  // Find searches for a matching service, given an IP and port. Caller must specify
    75  // the local or remote context.
    76  func (s *ServiceCache) Find(ip net.IP, port int, host string, local bool) interface{} {
    77  	s.RLock()
    78  	defer s.RUnlock()
    79  
    80  	if host != "" {
    81  		if data := s.findHost(host, port, local); data != nil {
    82  			return data
    83  		}
    84  	}
    85  
    86  	return s.findIP(ip, port, local)
    87  }
    88  
    89  // FindListeningServicesForPU returns a service that is found and the associated
    90  // portSpecifications that refer to this service.
    91  func (s *ServiceCache) FindListeningServicesForPU(id string) (interface{}, *portspec.PortSpec) {
    92  	s.RLock()
    93  	defer s.RUnlock()
    94  
    95  	for _, spec := range s.localPorts {
    96  		if spec.id == id {
    97  			return spec.data, spec.ports
    98  		}
    99  	}
   100  	return nil, nil
   101  }
   102  
   103  // DeleteByID will delete all entries related to this ID from all references.
   104  func (s *ServiceCache) DeleteByID(id string, local bool) {
   105  	s.Lock()
   106  	defer s.Unlock()
   107  
   108  	hosts := s.remoteHosts
   109  	cache := s.remote
   110  	if local {
   111  		hosts = s.localHosts
   112  		cache = s.local
   113  	}
   114  
   115  	if local {
   116  		s.localPorts = deleteMatchingPorts(s.localPorts, id)
   117  	} else {
   118  		s.remotePorts = deleteMatchingPorts(s.remotePorts, id)
   119  	}
   120  
   121  	for host, ports := range hosts {
   122  		hosts[host] = deleteMatchingPorts(ports, id)
   123  		if len(hosts[host]) == 0 {
   124  			delete(hosts, host)
   125  		}
   126  	}
   127  
   128  	deleteMatching := func(val interface{}) interface{} {
   129  		if val == nil {
   130  			return nil
   131  		}
   132  
   133  		entryL := val.(entryList)
   134  		r := deleteMatchingPorts(entryL, id)
   135  		if len(r) == 0 {
   136  			return nil
   137  		}
   138  
   139  		return r
   140  	}
   141  
   142  	cache.RunFuncOnVals(deleteMatching)
   143  }
   144  
   145  func deleteMatchingPorts(list entryList, id string) entryList {
   146  	remainingPorts := entryList{}
   147  	for _, spec := range list {
   148  		if spec.id != id {
   149  			remainingPorts = append(remainingPorts, spec)
   150  		}
   151  	}
   152  	return remainingPorts
   153  }
   154  
   155  func (s *ServiceCache) addIPService(e *common.Service, record *entry, local bool) error {
   156  
   157  	cache := s.remote
   158  	if local {
   159  		cache = s.local
   160  	}
   161  
   162  	addresses := e.Addresses
   163  
   164  	// If addresses are nil, I only care about ports.
   165  	if len(e.Addresses) == 0 && len(e.FQDNs) == 0 {
   166  		addresses = map[string]struct{}{}
   167  		addresses["0.0.0.0/0"] = struct{}{}
   168  		addresses["::/0"] = struct{}{}
   169  	}
   170  
   171  	for addrS := range addresses {
   172  		var records entryList
   173  		var addr *net.IPNet
   174  		var err error
   175  
   176  		if _, addr, err = net.ParseCIDR(addrS); err != nil {
   177  			continue
   178  		}
   179  
   180  		mask, _ := addr.Mask.Size()
   181  		v, exists := cache.Get(addr.IP, mask)
   182  
   183  		if !exists {
   184  			records = entryList{}
   185  		} else {
   186  			records = v.(entryList)
   187  			for _, spec := range records {
   188  				if spec.ports.Overlaps(e.Ports) {
   189  					return fmt.Errorf("service port overlap for a given IP not allowed: ip %s, port %s", addr.String(), e.Ports.String())
   190  				}
   191  			}
   192  		}
   193  
   194  		records = append(records, record)
   195  		cache.Put(addr.IP, mask, records)
   196  	}
   197  
   198  	return nil
   199  }
   200  
   201  func (s *ServiceCache) addHostService(e *common.Service, record *entry, local bool) error {
   202  	hostCache := s.remoteHosts
   203  	if local {
   204  		hostCache = s.localHosts
   205  	}
   206  
   207  	// If addresses are nil, I only care about ports.
   208  	if len(e.FQDNs) == 0 {
   209  		return nil
   210  	}
   211  
   212  	for _, host := range e.FQDNs {
   213  		if _, ok := hostCache[host]; !ok {
   214  			hostCache[host] = entryList{}
   215  		}
   216  		for _, spec := range hostCache[host] {
   217  			if spec.ports.Overlaps(e.Ports) {
   218  				return fmt.Errorf("service port overlap for a given host not allowed: host %s, port %s", host, e.Ports.String())
   219  			}
   220  		}
   221  		hostCache[host] = append(hostCache[host], record)
   222  	}
   223  	return nil
   224  }
   225  
   226  // findIP searches for a matching service, given an IP and port
   227  func (s *ServiceCache) findIP(ip net.IP, port int, local bool) interface{} {
   228  
   229  	cache := s.remote
   230  	if local {
   231  		cache = s.local
   232  	}
   233  
   234  	if ip == nil {
   235  		return nil
   236  	}
   237  
   238  	var data interface{}
   239  
   240  	findMatch := func(val interface{}) bool {
   241  		if val != nil {
   242  			records := val.(entryList)
   243  			for _, e := range records {
   244  				if e.ports.IsIncluded(port) {
   245  					data = e.data
   246  					return true
   247  				}
   248  			}
   249  		}
   250  		return false
   251  	}
   252  
   253  	cache.RunFuncOnLpmIP(ip, findMatch)
   254  	return data
   255  }
   256  
   257  // findIP searches for a matching service, given an IP and port
   258  func (s *ServiceCache) findHost(host string, port int, local bool) interface{} {
   259  	hostCache := s.remoteHosts
   260  	if local {
   261  		hostCache = s.localHosts
   262  	}
   263  
   264  	entries, ok := hostCache[host]
   265  	if !ok {
   266  		return nil
   267  	}
   268  	for _, e := range entries {
   269  		if e.ports.IsIncluded(port) {
   270  			return e.data
   271  		}
   272  	}
   273  
   274  	return nil
   275  }
   276  
   277  // addPorts will only work for local ports.
   278  func (s *ServiceCache) addPorts(e *common.Service, record *entry, local bool) error {
   279  	if !local {
   280  		return nil
   281  	}
   282  
   283  	for _, spec := range s.localPorts {
   284  		if spec.ports.Overlaps(e.Ports) {
   285  			return fmt.Errorf("service port overlap in the global port list: %+v %s", e.Addresses, e.Ports.String())
   286  		}
   287  	}
   288  
   289  	s.localPorts = append(s.localPorts, record)
   290  
   291  	return nil
   292  }