github.com/safing/portbase@v0.19.5/runtime/registry_test.go (about)

     1  package runtime
     2  
     3  import (
     4  	"errors"
     5  	"sync"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/safing/portbase/database/query"
    12  	"github.com/safing/portbase/database/record"
    13  )
    14  
    15  type testRecord struct {
    16  	record.Base
    17  	sync.Mutex
    18  	Value string
    19  }
    20  
    21  func makeTestRecord(key, value string) record.Record {
    22  	r := &testRecord{Value: value}
    23  	r.CreateMeta()
    24  	r.SetKey("runtime:" + key)
    25  	return r
    26  }
    27  
    28  type testProvider struct {
    29  	k string
    30  	r []record.Record
    31  }
    32  
    33  func (tp *testProvider) Get(key string) ([]record.Record, error) {
    34  	return tp.r, nil
    35  }
    36  
    37  func (tp *testProvider) Set(r record.Record) (record.Record, error) {
    38  	return nil, errors.New("not implemented")
    39  }
    40  
    41  func getTestRegistry(t *testing.T) *Registry {
    42  	t.Helper()
    43  
    44  	r := NewRegistry()
    45  
    46  	providers := []testProvider{
    47  		{
    48  			k: "p1/",
    49  			r: []record.Record{
    50  				makeTestRecord("p1/f1/v1", "p1.1"),
    51  				makeTestRecord("p1/f2/v2", "p1.2"),
    52  				makeTestRecord("p1/v3", "p1.3"),
    53  			},
    54  		},
    55  		{
    56  			k: "p2/f1",
    57  			r: []record.Record{
    58  				makeTestRecord("p2/f1/v1", "p2.1"),
    59  				makeTestRecord("p2/f1/f2/v2", "p2.2"),
    60  				makeTestRecord("p2/f1/v3", "p2.3"),
    61  			},
    62  		},
    63  	}
    64  
    65  	for idx := range providers {
    66  		p := providers[idx]
    67  		_, err := r.Register(p.k, &p)
    68  		require.NoError(t, err)
    69  	}
    70  
    71  	return r
    72  }
    73  
    74  func TestRegistryGet(t *testing.T) {
    75  	t.Parallel()
    76  
    77  	var (
    78  		r   record.Record
    79  		err error
    80  	)
    81  
    82  	reg := getTestRegistry(t)
    83  
    84  	r, err = reg.Get("p1/f1/v1")
    85  	require.NoError(t, err)
    86  	require.NotNil(t, r)
    87  	assert.Equal(t, "p1.1", r.(*testRecord).Value) //nolint:forcetypeassert
    88  
    89  	r, err = reg.Get("p1/v3")
    90  	require.NoError(t, err)
    91  	require.NotNil(t, r)
    92  	assert.Equal(t, "p1.3", r.(*testRecord).Value) //nolint:forcetypeassert
    93  
    94  	r, err = reg.Get("p1/v4")
    95  	require.Error(t, err)
    96  	assert.Nil(t, r)
    97  
    98  	r, err = reg.Get("no-provider/foo")
    99  	require.Error(t, err)
   100  	assert.Nil(t, r)
   101  }
   102  
   103  func TestRegistryQuery(t *testing.T) {
   104  	t.Parallel()
   105  
   106  	reg := getTestRegistry(t)
   107  
   108  	q := query.New("runtime:p")
   109  	iter, err := reg.Query(q, true, true)
   110  	require.NoError(t, err)
   111  	require.NotNil(t, iter)
   112  	var records []record.Record //nolint:prealloc
   113  	for r := range iter.Next {
   114  		records = append(records, r)
   115  	}
   116  	assert.Len(t, records, 6)
   117  
   118  	q = query.New("runtime:p1/f")
   119  	iter, err = reg.Query(q, true, true)
   120  	require.NoError(t, err)
   121  	require.NotNil(t, iter)
   122  	records = nil
   123  	for r := range iter.Next {
   124  		records = append(records, r)
   125  	}
   126  	assert.Len(t, records, 2)
   127  }
   128  
   129  func TestRegistryRegister(t *testing.T) {
   130  	t.Parallel()
   131  
   132  	r := NewRegistry()
   133  
   134  	cases := []struct {
   135  		inp string
   136  		err bool
   137  	}{
   138  		{"runtime:foo/bar/bar", false},
   139  		{"runtime:foo/bar/bar2", false},
   140  		{"runtime:foo/bar", false},
   141  		{"runtime:foo/bar", true},  // already used
   142  		{"runtime:foo/bar/", true}, // cannot register a prefix if there are providers below
   143  		{"runtime:foo/baz/", false},
   144  		{"runtime:foo/baz2/", false},
   145  		{"runtime:foo/baz3", false},
   146  		{"runtime:foo/baz/bar", true},
   147  	}
   148  
   149  	for _, c := range cases {
   150  		_, err := r.Register(c.inp, nil)
   151  		if c.err {
   152  			assert.Error(t, err, c.inp)
   153  		} else {
   154  			assert.NoError(t, err, c.inp)
   155  		}
   156  	}
   157  }