github.com/xmidt-org/webpa-common@v1.11.9/device/devicegate/filter.go (about)

     1  package devicegate
     2  
import (
	"encoding/json"
	"reflect"
	"sync"

	"github.com/xmidt-org/webpa-common/device"
)
     9  
// Location names that metadataMatch records in device.MatchResult.Location to
// tell where the matching value was found.
const (
	// metadataMapLocation marks a value obtained from Metadata.Load.
	metadataMapLocation = "metadata_map"
	// claimsLocation marks a value obtained from the Metadata.Claims map.
	claimsLocation      = "claims"
)
    14  
// Interface is a gate interface specifically for filtering devices.
// It embeds device.Filter and adds methods for inspecting and managing the
// filter keys and the sets of values stored under them.
type Interface interface {
	device.Filter

	// VisitAll applies the given visitor function to each set of filter values.
	// The visitor receives the filter key and the Set stored under it. In the
	// FilterGate implementation, iteration stops once the visitor returns false
	// and the result is the number of filters visited.
	//
	// No methods on this Interface should be called from within the visitor function, or
	// a deadlock will likely occur.
	VisitAll(visit func(string, Set) bool) int

	// GetFilter returns the set of filter values associated with a filter key and a bool
	// that is true if the key was found, false if it doesn't exist.
	GetFilter(key string) (Set, bool)

	// SetFilter saves the filter values under the given filter key. It returns a Set of the old values and a
	// bool that is true if the filter key did not previously exist and false if the filter key had existed beforehand.
	SetFilter(key string, values []interface{}) (Set, bool)

	// DeleteFilter deletes a filter key. This completely removes all filter values associated with that key as well.
	// Returns true if the key had existed and its values were actually deleted, and false if the key was not found.
	DeleteFilter(key string) bool

	// GetAllowedFilters returns the set of filters that devices are allowed to be filtered by. Also returns a
	// bool that is true if there are allowed filters set, and false if there aren't (meaning that all filters are allowed).
	GetAllowedFilters() (Set, bool)
}
    41  
// Set is an interface that represents a read-only hashset.
// Implementations must also be able to render themselves as JSON.
type Set interface {
	json.Marshaler
	// Has returns true if a value exists in the set, false if it doesn't.
	Has(interface{}) bool

	// VisitAll applies the visitor function to every value in the set.
	// No iteration order is specified by this interface.
	VisitAll(func(interface{}))
}
    51  
// FilterStore can be used to store filters in the Interface.
// It maps a filter key to the Set of values that key is filtered by.
type FilterStore map[string]Set
    54  
// FilterSet is a concrete type that implements the Set interface.
// Because it contains a mutex, it is used through a pointer (all of its
// methods have pointer receivers).
type FilterSet struct {
	// Set holds the members as map keys; Has returns the bool stored for a key.
	Set  map[interface{}]bool
	// lock is read-locked by Has, VisitAll and MarshalJSON while they access Set.
	lock sync.RWMutex
}
    60  
// FilterGate is a concrete implementation of the Interface.
// AllowConnection rejects a device when any stored filter matches a value in
// the device's metadata.
type FilterGate struct {
	// FilterStore maps each filter key to the set of values that are filtered.
	FilterStore    FilterStore `json:"filters"`
	// AllowedFilters is the set returned by GetAllowedFilters; nil is reported
	// there as "no allowed filters set".
	AllowedFilters Set         `json:"allowedFilters"`

	// lock guards FilterStore. GetAllowedFilters reads AllowedFilters without it.
	lock sync.RWMutex
}
    68  
// FilterRequest pairs a filter key with a list of values and carries JSON tags
// for both fields. Its fields mirror the parameters of SetFilter.
// NOTE(review): presumably the JSON payload used to add or update a filter;
// confirm against the handler that decodes it.
type FilterRequest struct {
	// Key is the filter key.
	Key    string        `json:"key"`
	// Values are the values associated with Key.
	Values []interface{} `json:"values"`
}
    73  
    74  func (f *FilterGate) VisitAll(visit func(string, Set) bool) int {
    75  	f.lock.RLock()
    76  	defer f.lock.RUnlock()
    77  
    78  	visited := 0
    79  	for key, filterValues := range f.FilterStore {
    80  		visited++
    81  		if !visit(key, filterValues) {
    82  			break
    83  		}
    84  	}
    85  
    86  	return visited
    87  }
    88  
    89  func (f *FilterGate) GetFilter(key string) (Set, bool) {
    90  	f.lock.RLock()
    91  	defer f.lock.RUnlock()
    92  
    93  	v, ok := f.FilterStore[key]
    94  	return v, ok
    95  
    96  }
    97  
    98  func (f *FilterGate) SetFilter(key string, values []interface{}) (Set, bool) {
    99  	f.lock.Lock()
   100  	defer f.lock.Unlock()
   101  
   102  	oldValues := f.FilterStore[key]
   103  	newValues := make(map[interface{}]bool)
   104  
   105  	for _, v := range values {
   106  		newValues[v] = true
   107  	}
   108  
   109  	f.FilterStore[key] = &FilterSet{
   110  		Set: newValues,
   111  	}
   112  
   113  	if oldValues == nil {
   114  		return oldValues, true
   115  	}
   116  
   117  	return oldValues, false
   118  
   119  }
   120  
   121  func (f *FilterGate) DeleteFilter(key string) bool {
   122  	f.lock.Lock()
   123  	defer f.lock.Unlock()
   124  
   125  	_, ok := f.FilterStore[key]
   126  
   127  	if ok {
   128  		delete(f.FilterStore, key)
   129  		return true
   130  	}
   131  
   132  	return false
   133  }
   134  
   135  func (f *FilterGate) AllowConnection(d device.Interface) (bool, device.MatchResult) {
   136  	f.lock.RLock()
   137  	defer f.lock.RUnlock()
   138  
   139  	for filterKey, filterValues := range f.FilterStore {
   140  		// check for filter match
   141  		if found, result := f.FilterStore.metadataMatch(filterKey, filterValues, d.Metadata()); found {
   142  			return false, result
   143  		}
   144  	}
   145  
   146  	return true, device.MatchResult{}
   147  }
   148  
   149  func (f *FilterGate) GetAllowedFilters() (Set, bool) {
   150  	if f.AllowedFilters == nil {
   151  		return f.AllowedFilters, false
   152  	}
   153  
   154  	return f.AllowedFilters, true
   155  }
   156  
   157  func (s *FilterSet) Has(key interface{}) bool {
   158  	if s.Set != nil {
   159  		s.lock.RLock()
   160  		defer s.lock.RUnlock()
   161  		return s.Set[key]
   162  	}
   163  
   164  	return false
   165  }
   166  
   167  func (s *FilterSet) VisitAll(f func(interface{})) {
   168  	s.lock.RLock()
   169  	defer s.lock.RUnlock()
   170  	for key := range s.Set {
   171  		f(key)
   172  	}
   173  }
   174  
   175  func (s *FilterSet) MarshalJSON() ([]byte, error) {
   176  	s.lock.RLock()
   177  	defer s.lock.RUnlock()
   178  	temp := make([]interface{}, 0, len(s.Set))
   179  	for key := range s.Set {
   180  		temp = append(temp, key)
   181  	}
   182  
   183  	return json.Marshal(temp)
   184  }
   185  
   186  func (f *FilterStore) metadataMatch(keyToCheck string, filterValues Set, m *device.Metadata) (bool, device.MatchResult) {
   187  	var val interface{}
   188  	result := device.MatchResult{
   189  		Key: keyToCheck,
   190  	}
   191  	if metadataVal := m.Load(keyToCheck); metadataVal != nil {
   192  		val = metadataVal
   193  		result.Location = metadataMapLocation
   194  	} else if claimsVal, found := m.Claims()[keyToCheck]; found {
   195  		val = claimsVal
   196  		result.Location = claimsLocation
   197  	}
   198  
   199  	if val != nil {
   200  		switch t := val.(type) {
   201  		case []interface{}:
   202  			if filterMatch(filterValues, t...) {
   203  				return true, result
   204  			}
   205  		case interface{}:
   206  			if filterMatch(filterValues, t) {
   207  				return true, result
   208  			}
   209  		}
   210  	}
   211  
   212  	return false, device.MatchResult{}
   213  }
   214  
   215  // function to check if any params are in a set
   216  func filterMatch(filterValues Set, paramsToCheck ...interface{}) bool {
   217  	for _, param := range paramsToCheck {
   218  		if filterValues.Has(param) {
   219  			return true
   220  		}
   221  	}
   222  
   223  	return false
   224  
   225  }