github.com/safing/portbase@v0.19.5/runtime/registry.go (about)

     1  package runtime
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"strings"
     7  	"sync"
     8  
     9  	"github.com/armon/go-radix"
    10  	"golang.org/x/sync/errgroup"
    11  
    12  	"github.com/safing/portbase/database"
    13  	"github.com/safing/portbase/database/iterator"
    14  	"github.com/safing/portbase/database/query"
    15  	"github.com/safing/portbase/database/record"
    16  	"github.com/safing/portbase/database/storage"
    17  	"github.com/safing/portbase/log"
    18  )
    19  
    20  var (
    21  	// ErrKeyTaken is returned when trying to register
    22  	// a value provider at database key or prefix that
    23  	// is already occupied by another provider.
    24  	ErrKeyTaken = errors.New("runtime key or prefix already used")
    25  	// ErrKeyUnmanaged is returned when a Put operation
    26  	// on an unmanaged key is performed.
    27  	ErrKeyUnmanaged = errors.New("runtime key not managed by any provider")
    28  	// ErrInjected is returned by Registry.InjectAsDatabase
    29  	// if the registry has already been injected.
    30  	ErrInjected = errors.New("registry already injected")
    31  )
    32  
    33  // Registry keeps track of registered runtime
    34  // value providers and exposes them via an
    35  // injected database. Users normally just need
    36  // to use the defaul registry provided by this
    37  // package but may consider creating a dedicated
    38  // runtime registry on their own. Registry uses
    39  // a radix tree for value providers and their
    40  // chosen database key/prefix.
    41  type Registry struct {
    42  	l            sync.RWMutex
    43  	providers    *radix.Tree
    44  	dbController *database.Controller
    45  	dbName       string
    46  }
    47  
    48  // keyedValueProvider simply wraps a value provider with it's
    49  // registration prefix.
    50  type keyedValueProvider struct {
    51  	ValueProvider
    52  	key string
    53  }
    54  
    55  // NewRegistry returns a new registry.
    56  func NewRegistry() *Registry {
    57  	return &Registry{
    58  		providers: radix.New(),
    59  	}
    60  }
    61  
    62  func isPrefixKey(key string) bool {
    63  	return strings.HasSuffix(key, "/")
    64  }
    65  
    66  // DatabaseName returns the name of the database where the
    67  // registry has been injected. It returns an empty string
    68  // if InjectAsDatabase has not been called.
    69  func (r *Registry) DatabaseName() string {
    70  	r.l.RLock()
    71  	defer r.l.RUnlock()
    72  
    73  	return r.dbName
    74  }
    75  
    76  // InjectAsDatabase injects the registry as the storage
    77  // database for name.
    78  func (r *Registry) InjectAsDatabase(name string) error {
    79  	r.l.Lock()
    80  	defer r.l.Unlock()
    81  
    82  	if r.dbController != nil {
    83  		return ErrInjected
    84  	}
    85  
    86  	ctrl, err := database.InjectDatabase(name, r.asStorage())
    87  	if err != nil {
    88  		return err
    89  	}
    90  
    91  	r.dbName = name
    92  	r.dbController = ctrl
    93  
    94  	return nil
    95  }
    96  
    97  // Register registers a new value provider p under keyOrPrefix. The
    98  // returned PushFunc can be used to send update notitifcations to
    99  // database subscribers. Note that keyOrPrefix must end in '/' to be
   100  // accepted as a prefix.
   101  func (r *Registry) Register(keyOrPrefix string, p ValueProvider) (PushFunc, error) {
   102  	r.l.Lock()
   103  	defer r.l.Unlock()
   104  
   105  	// search if there's a provider registered for a prefix
   106  	// that matches or is equal to keyOrPrefix.
   107  	key, _, ok := r.providers.LongestPrefix(keyOrPrefix)
   108  	if ok && (isPrefixKey(key) || key == keyOrPrefix) {
   109  		return nil, fmt.Errorf("%w: found provider on %s", ErrKeyTaken, key)
   110  	}
   111  
   112  	// if keyOrPrefix is a prefix there must not be any provider
   113  	// registered for a key that matches keyOrPrefix.
   114  	if isPrefixKey(keyOrPrefix) {
   115  		foundProvider := ""
   116  		r.providers.WalkPrefix(keyOrPrefix, func(s string, _ interface{}) bool {
   117  			foundProvider = s
   118  			return true
   119  		})
   120  		if foundProvider != "" {
   121  			return nil, fmt.Errorf("%w: found provider on %s", ErrKeyTaken, foundProvider)
   122  		}
   123  	}
   124  
   125  	r.providers.Insert(keyOrPrefix, &keyedValueProvider{
   126  		ValueProvider: TraceProvider(p),
   127  		key:           keyOrPrefix,
   128  	})
   129  
   130  	log.Tracef("runtime: registered new provider at %s", keyOrPrefix)
   131  
   132  	return func(records ...record.Record) {
   133  		r.l.RLock()
   134  		defer r.l.RUnlock()
   135  
   136  		if r.dbController == nil {
   137  			return
   138  		}
   139  
   140  		for _, rec := range records {
   141  			r.dbController.PushUpdate(rec)
   142  		}
   143  	}, nil
   144  }
   145  
   146  // Get returns the runtime value that is identified by key.
   147  // It implements the storage.Interface.
   148  func (r *Registry) Get(key string) (record.Record, error) {
   149  	provider := r.getMatchingProvider(key)
   150  	if provider == nil {
   151  		return nil, database.ErrNotFound
   152  	}
   153  
   154  	records, err := provider.Get(key)
   155  	if err != nil {
   156  		// instead of returning ErrWriteOnly to the database interface
   157  		// we wrap it in ErrNotFound so the records effectively gets
   158  		// hidden.
   159  		if errors.Is(err, ErrWriteOnly) {
   160  			return nil, database.ErrNotFound
   161  		}
   162  		return nil, err
   163  	}
   164  
   165  	// Get performs an exact match so filter out
   166  	// and values that do not match key.
   167  	for _, r := range records {
   168  		if r.DatabaseKey() == key {
   169  			return r, nil
   170  		}
   171  	}
   172  
   173  	return nil, database.ErrNotFound
   174  }
   175  
   176  // Put stores the record m in the runtime database. Note that
   177  // ErrReadOnly is returned if there's no value provider responsible
   178  // for m.Key().
   179  func (r *Registry) Put(m record.Record) (record.Record, error) {
   180  	provider := r.getMatchingProvider(m.DatabaseKey())
   181  	if provider == nil {
   182  		// if there's no provider for the given value
   183  		// return ErrKeyUnmanaged.
   184  		return nil, ErrKeyUnmanaged
   185  	}
   186  
   187  	res, err := provider.Set(m)
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  	return res, nil
   192  }
   193  
   194  // Query performs a query on the runtime registry returning all
   195  // records across all value providers that match q.
   196  // Query implements the storage.Storage interface.
   197  func (r *Registry) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
   198  	if _, err := q.Check(); err != nil {
   199  		return nil, fmt.Errorf("invalid query: %w", err)
   200  	}
   201  
   202  	searchPrefix := q.DatabaseKeyPrefix()
   203  	providers := r.collectProviderByPrefix(searchPrefix)
   204  	if len(providers) == 0 {
   205  		return nil, fmt.Errorf("%w: for key %s", ErrKeyUnmanaged, searchPrefix)
   206  	}
   207  
   208  	iter := iterator.New()
   209  
   210  	grp := new(errgroup.Group)
   211  	for idx := range providers {
   212  		p := providers[idx]
   213  
   214  		grp.Go(func() (err error) {
   215  			defer recovery(&err)
   216  
   217  			key := p.key
   218  			if len(searchPrefix) > len(key) {
   219  				key = searchPrefix
   220  			}
   221  
   222  			records, err := p.Get(key)
   223  			if err != nil {
   224  				if errors.Is(err, ErrWriteOnly) {
   225  					return nil
   226  				}
   227  				return err
   228  			}
   229  
   230  			for _, r := range records {
   231  				r.Lock()
   232  				var (
   233  					matchesKey = q.MatchesKey(r.DatabaseKey())
   234  					isValid    = r.Meta().CheckValidity()
   235  					isAllowed  = r.Meta().CheckPermission(local, internal)
   236  
   237  					allowed = matchesKey && isValid && isAllowed
   238  				)
   239  				if allowed {
   240  					allowed = q.MatchesRecord(r)
   241  				}
   242  				r.Unlock()
   243  
   244  				if !allowed {
   245  					log.Tracef("runtime: not sending %s for query %s. matchesKey=%v isValid=%v isAllowed=%v", r.DatabaseKey(), searchPrefix, matchesKey, isValid, isAllowed)
   246  					continue
   247  				}
   248  
   249  				select {
   250  				case iter.Next <- r:
   251  				case <-iter.Done:
   252  					return nil
   253  				}
   254  			}
   255  
   256  			return nil
   257  		})
   258  	}
   259  
   260  	go func() {
   261  		err := grp.Wait()
   262  		iter.Finish(err)
   263  	}()
   264  
   265  	return iter, nil
   266  }
   267  
   268  func (r *Registry) getMatchingProvider(key string) *keyedValueProvider {
   269  	r.l.RLock()
   270  	defer r.l.RUnlock()
   271  
   272  	providerKey, provider, ok := r.providers.LongestPrefix(key)
   273  	if !ok {
   274  		return nil
   275  	}
   276  
   277  	if !isPrefixKey(providerKey) && providerKey != key {
   278  		return nil
   279  	}
   280  
   281  	return provider.(*keyedValueProvider) //nolint:forcetypeassert
   282  }
   283  
   284  func (r *Registry) collectProviderByPrefix(prefix string) []*keyedValueProvider {
   285  	r.l.RLock()
   286  	defer r.l.RUnlock()
   287  
   288  	// if there's a LongestPrefix provider that's the only one
   289  	// we need to ask
   290  	if _, p, ok := r.providers.LongestPrefix(prefix); ok {
   291  		return []*keyedValueProvider{p.(*keyedValueProvider)} //nolint:forcetypeassert
   292  	}
   293  
   294  	var providers []*keyedValueProvider
   295  	r.providers.WalkPrefix(prefix, func(key string, p interface{}) bool {
   296  		providers = append(providers, p.(*keyedValueProvider)) //nolint:forcetypeassert
   297  		return false
   298  	})
   299  
   300  	return providers
   301  }
   302  
   303  // GetRegistrationKeys returns a list of all provider registration
   304  // keys or prefixes.
   305  func (r *Registry) GetRegistrationKeys() []string {
   306  	r.l.RLock()
   307  	defer r.l.RUnlock()
   308  
   309  	var keys []string
   310  
   311  	r.providers.Walk(func(key string, p interface{}) bool {
   312  		keys = append(keys, key)
   313  		return false
   314  	})
   315  	return keys
   316  }
   317  
   318  // asStorage returns a storage.Interface compatible struct
   319  // that is backed by r.
   320  func (r *Registry) asStorage() storage.Interface {
   321  	return &storageWrapper{
   322  		Registry: r,
   323  	}
   324  }
   325  
   326  func recovery(err *error) {
   327  	if x := recover(); x != nil {
   328  		if e, ok := x.(error); ok {
   329  			*err = e
   330  			return
   331  		}
   332  
   333  		*err = fmt.Errorf("%v", x)
   334  	}
   335  }