github.com/safing/portbase@v0.19.5/runtime/registry_test.go (about) 1 package runtime 2 3 import ( 4 "errors" 5 "sync" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 "github.com/stretchr/testify/require" 10 11 "github.com/safing/portbase/database/query" 12 "github.com/safing/portbase/database/record" 13 ) 14 15 type testRecord struct { 16 record.Base 17 sync.Mutex 18 Value string 19 } 20 21 func makeTestRecord(key, value string) record.Record { 22 r := &testRecord{Value: value} 23 r.CreateMeta() 24 r.SetKey("runtime:" + key) 25 return r 26 } 27 28 type testProvider struct { 29 k string 30 r []record.Record 31 } 32 33 func (tp *testProvider) Get(key string) ([]record.Record, error) { 34 return tp.r, nil 35 } 36 37 func (tp *testProvider) Set(r record.Record) (record.Record, error) { 38 return nil, errors.New("not implemented") 39 } 40 41 func getTestRegistry(t *testing.T) *Registry { 42 t.Helper() 43 44 r := NewRegistry() 45 46 providers := []testProvider{ 47 { 48 k: "p1/", 49 r: []record.Record{ 50 makeTestRecord("p1/f1/v1", "p1.1"), 51 makeTestRecord("p1/f2/v2", "p1.2"), 52 makeTestRecord("p1/v3", "p1.3"), 53 }, 54 }, 55 { 56 k: "p2/f1", 57 r: []record.Record{ 58 makeTestRecord("p2/f1/v1", "p2.1"), 59 makeTestRecord("p2/f1/f2/v2", "p2.2"), 60 makeTestRecord("p2/f1/v3", "p2.3"), 61 }, 62 }, 63 } 64 65 for idx := range providers { 66 p := providers[idx] 67 _, err := r.Register(p.k, &p) 68 require.NoError(t, err) 69 } 70 71 return r 72 } 73 74 func TestRegistryGet(t *testing.T) { 75 t.Parallel() 76 77 var ( 78 r record.Record 79 err error 80 ) 81 82 reg := getTestRegistry(t) 83 84 r, err = reg.Get("p1/f1/v1") 85 require.NoError(t, err) 86 require.NotNil(t, r) 87 assert.Equal(t, "p1.1", r.(*testRecord).Value) //nolint:forcetypeassert 88 89 r, err = reg.Get("p1/v3") 90 require.NoError(t, err) 91 require.NotNil(t, r) 92 assert.Equal(t, "p1.3", r.(*testRecord).Value) //nolint:forcetypeassert 93 94 r, err = reg.Get("p1/v4") 95 require.Error(t, err) 96 assert.Nil(t, r) 97 98 r, err = reg.Get("no-provider/foo") 99 require.Error(t, err) 100 assert.Nil(t, r) 101 } 102 103 func TestRegistryQuery(t *testing.T) { 104 t.Parallel() 105 106 reg := getTestRegistry(t) 107 108 q := query.New("runtime:p") 109 iter, err := reg.Query(q, true, true) 110 require.NoError(t, err) 111 require.NotNil(t, iter) 112 var records []record.Record //nolint:prealloc 113 for r := range iter.Next { 114 records = append(records, r) 115 } 116 assert.Len(t, records, 6) 117 118 q = query.New("runtime:p1/f") 119 iter, err = reg.Query(q, true, true) 120 require.NoError(t, err) 121 require.NotNil(t, iter) 122 records = nil 123 for r := range iter.Next { 124 records = append(records, r) 125 } 126 assert.Len(t, records, 2) 127 } 128 129 func TestRegistryRegister(t *testing.T) { 130 t.Parallel() 131 132 r := NewRegistry() 133 134 cases := []struct { 135 inp string 136 err bool 137 }{ 138 {"runtime:foo/bar/bar", false}, 139 {"runtime:foo/bar/bar2", false}, 140 {"runtime:foo/bar", false}, 141 {"runtime:foo/bar", true}, // already used 142 {"runtime:foo/bar/", true}, // cannot register a prefix if there are providers below 143 {"runtime:foo/baz/", false}, 144 {"runtime:foo/baz2/", false}, 145 {"runtime:foo/baz3", false}, 146 {"runtime:foo/baz/bar", true}, 147 } 148 149 for _, c := range cases { 150 _, err := r.Register(c.inp, nil) 151 if c.err { 152 assert.Error(t, err, c.inp) 153 } else { 154 assert.NoError(t, err, c.inp) 155 } 156 } 157 }