github.com/safing/portbase@v0.19.5/runtime/registry.go (about) 1 package runtime 2 3 import ( 4 "errors" 5 "fmt" 6 "strings" 7 "sync" 8 9 "github.com/armon/go-radix" 10 "golang.org/x/sync/errgroup" 11 12 "github.com/safing/portbase/database" 13 "github.com/safing/portbase/database/iterator" 14 "github.com/safing/portbase/database/query" 15 "github.com/safing/portbase/database/record" 16 "github.com/safing/portbase/database/storage" 17 "github.com/safing/portbase/log" 18 ) 19 20 var ( 21 // ErrKeyTaken is returned when trying to register 22 // a value provider at database key or prefix that 23 // is already occupied by another provider. 24 ErrKeyTaken = errors.New("runtime key or prefix already used") 25 // ErrKeyUnmanaged is returned when a Put operation 26 // on an unmanaged key is performed. 27 ErrKeyUnmanaged = errors.New("runtime key not managed by any provider") 28 // ErrInjected is returned by Registry.InjectAsDatabase 29 // if the registry has already been injected. 30 ErrInjected = errors.New("registry already injected") 31 ) 32 33 // Registry keeps track of registered runtime 34 // value providers and exposes them via an 35 // injected database. Users normally just need 36 // to use the defaul registry provided by this 37 // package but may consider creating a dedicated 38 // runtime registry on their own. Registry uses 39 // a radix tree for value providers and their 40 // chosen database key/prefix. 41 type Registry struct { 42 l sync.RWMutex 43 providers *radix.Tree 44 dbController *database.Controller 45 dbName string 46 } 47 48 // keyedValueProvider simply wraps a value provider with it's 49 // registration prefix. 50 type keyedValueProvider struct { 51 ValueProvider 52 key string 53 } 54 55 // NewRegistry returns a new registry. 56 func NewRegistry() *Registry { 57 return &Registry{ 58 providers: radix.New(), 59 } 60 } 61 62 func isPrefixKey(key string) bool { 63 return strings.HasSuffix(key, "/") 64 } 65 66 // DatabaseName returns the name of the database where the 67 // registry has been injected. It returns an empty string 68 // if InjectAsDatabase has not been called. 69 func (r *Registry) DatabaseName() string { 70 r.l.RLock() 71 defer r.l.RUnlock() 72 73 return r.dbName 74 } 75 76 // InjectAsDatabase injects the registry as the storage 77 // database for name. 78 func (r *Registry) InjectAsDatabase(name string) error { 79 r.l.Lock() 80 defer r.l.Unlock() 81 82 if r.dbController != nil { 83 return ErrInjected 84 } 85 86 ctrl, err := database.InjectDatabase(name, r.asStorage()) 87 if err != nil { 88 return err 89 } 90 91 r.dbName = name 92 r.dbController = ctrl 93 94 return nil 95 } 96 97 // Register registers a new value provider p under keyOrPrefix. The 98 // returned PushFunc can be used to send update notitifcations to 99 // database subscribers. Note that keyOrPrefix must end in '/' to be 100 // accepted as a prefix. 101 func (r *Registry) Register(keyOrPrefix string, p ValueProvider) (PushFunc, error) { 102 r.l.Lock() 103 defer r.l.Unlock() 104 105 // search if there's a provider registered for a prefix 106 // that matches or is equal to keyOrPrefix. 107 key, _, ok := r.providers.LongestPrefix(keyOrPrefix) 108 if ok && (isPrefixKey(key) || key == keyOrPrefix) { 109 return nil, fmt.Errorf("%w: found provider on %s", ErrKeyTaken, key) 110 } 111 112 // if keyOrPrefix is a prefix there must not be any provider 113 // registered for a key that matches keyOrPrefix. 114 if isPrefixKey(keyOrPrefix) { 115 foundProvider := "" 116 r.providers.WalkPrefix(keyOrPrefix, func(s string, _ interface{}) bool { 117 foundProvider = s 118 return true 119 }) 120 if foundProvider != "" { 121 return nil, fmt.Errorf("%w: found provider on %s", ErrKeyTaken, foundProvider) 122 } 123 } 124 125 r.providers.Insert(keyOrPrefix, &keyedValueProvider{ 126 ValueProvider: TraceProvider(p), 127 key: keyOrPrefix, 128 }) 129 130 log.Tracef("runtime: registered new provider at %s", keyOrPrefix) 131 132 return func(records ...record.Record) { 133 r.l.RLock() 134 defer r.l.RUnlock() 135 136 if r.dbController == nil { 137 return 138 } 139 140 for _, rec := range records { 141 r.dbController.PushUpdate(rec) 142 } 143 }, nil 144 } 145 146 // Get returns the runtime value that is identified by key. 147 // It implements the storage.Interface. 148 func (r *Registry) Get(key string) (record.Record, error) { 149 provider := r.getMatchingProvider(key) 150 if provider == nil { 151 return nil, database.ErrNotFound 152 } 153 154 records, err := provider.Get(key) 155 if err != nil { 156 // instead of returning ErrWriteOnly to the database interface 157 // we wrap it in ErrNotFound so the records effectively gets 158 // hidden. 159 if errors.Is(err, ErrWriteOnly) { 160 return nil, database.ErrNotFound 161 } 162 return nil, err 163 } 164 165 // Get performs an exact match so filter out 166 // and values that do not match key. 167 for _, r := range records { 168 if r.DatabaseKey() == key { 169 return r, nil 170 } 171 } 172 173 return nil, database.ErrNotFound 174 } 175 176 // Put stores the record m in the runtime database. Note that 177 // ErrReadOnly is returned if there's no value provider responsible 178 // for m.Key(). 179 func (r *Registry) Put(m record.Record) (record.Record, error) { 180 provider := r.getMatchingProvider(m.DatabaseKey()) 181 if provider == nil { 182 // if there's no provider for the given value 183 // return ErrKeyUnmanaged. 184 return nil, ErrKeyUnmanaged 185 } 186 187 res, err := provider.Set(m) 188 if err != nil { 189 return nil, err 190 } 191 return res, nil 192 } 193 194 // Query performs a query on the runtime registry returning all 195 // records across all value providers that match q. 196 // Query implements the storage.Storage interface. 197 func (r *Registry) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { 198 if _, err := q.Check(); err != nil { 199 return nil, fmt.Errorf("invalid query: %w", err) 200 } 201 202 searchPrefix := q.DatabaseKeyPrefix() 203 providers := r.collectProviderByPrefix(searchPrefix) 204 if len(providers) == 0 { 205 return nil, fmt.Errorf("%w: for key %s", ErrKeyUnmanaged, searchPrefix) 206 } 207 208 iter := iterator.New() 209 210 grp := new(errgroup.Group) 211 for idx := range providers { 212 p := providers[idx] 213 214 grp.Go(func() (err error) { 215 defer recovery(&err) 216 217 key := p.key 218 if len(searchPrefix) > len(key) { 219 key = searchPrefix 220 } 221 222 records, err := p.Get(key) 223 if err != nil { 224 if errors.Is(err, ErrWriteOnly) { 225 return nil 226 } 227 return err 228 } 229 230 for _, r := range records { 231 r.Lock() 232 var ( 233 matchesKey = q.MatchesKey(r.DatabaseKey()) 234 isValid = r.Meta().CheckValidity() 235 isAllowed = r.Meta().CheckPermission(local, internal) 236 237 allowed = matchesKey && isValid && isAllowed 238 ) 239 if allowed { 240 allowed = q.MatchesRecord(r) 241 } 242 r.Unlock() 243 244 if !allowed { 245 log.Tracef("runtime: not sending %s for query %s. matchesKey=%v isValid=%v isAllowed=%v", r.DatabaseKey(), searchPrefix, matchesKey, isValid, isAllowed) 246 continue 247 } 248 249 select { 250 case iter.Next <- r: 251 case <-iter.Done: 252 return nil 253 } 254 } 255 256 return nil 257 }) 258 } 259 260 go func() { 261 err := grp.Wait() 262 iter.Finish(err) 263 }() 264 265 return iter, nil 266 } 267 268 func (r *Registry) getMatchingProvider(key string) *keyedValueProvider { 269 r.l.RLock() 270 defer r.l.RUnlock() 271 272 providerKey, provider, ok := r.providers.LongestPrefix(key) 273 if !ok { 274 return nil 275 } 276 277 if !isPrefixKey(providerKey) && providerKey != key { 278 return nil 279 } 280 281 return provider.(*keyedValueProvider) //nolint:forcetypeassert 282 } 283 284 func (r *Registry) collectProviderByPrefix(prefix string) []*keyedValueProvider { 285 r.l.RLock() 286 defer r.l.RUnlock() 287 288 // if there's a LongestPrefix provider that's the only one 289 // we need to ask 290 if _, p, ok := r.providers.LongestPrefix(prefix); ok { 291 return []*keyedValueProvider{p.(*keyedValueProvider)} //nolint:forcetypeassert 292 } 293 294 var providers []*keyedValueProvider 295 r.providers.WalkPrefix(prefix, func(key string, p interface{}) bool { 296 providers = append(providers, p.(*keyedValueProvider)) //nolint:forcetypeassert 297 return false 298 }) 299 300 return providers 301 } 302 303 // GetRegistrationKeys returns a list of all provider registration 304 // keys or prefixes. 305 func (r *Registry) GetRegistrationKeys() []string { 306 r.l.RLock() 307 defer r.l.RUnlock() 308 309 var keys []string 310 311 r.providers.Walk(func(key string, p interface{}) bool { 312 keys = append(keys, key) 313 return false 314 }) 315 return keys 316 } 317 318 // asStorage returns a storage.Interface compatible struct 319 // that is backed by r. 320 func (r *Registry) asStorage() storage.Interface { 321 return &storageWrapper{ 322 Registry: r, 323 } 324 } 325 326 func recovery(err *error) { 327 if x := recover(); x != nil { 328 if e, ok := x.(error); ok { 329 *err = e 330 return 331 } 332 333 *err = fmt.Errorf("%v", x) 334 } 335 }