github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/crdb/keys_test.go (about) 1 package crdb 2 3 import ( 4 "context" 5 "net" 6 "sort" 7 "strings" 8 "testing" 9 10 "github.com/authzed/authzed-go/pkg/requestmeta" 11 "github.com/dustin/go-humanize" 12 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb" 13 "github.com/stretchr/testify/require" 14 "golang.org/x/exp/maps" 15 "google.golang.org/grpc" 16 "google.golang.org/grpc/credentials/insecure" 17 "google.golang.org/grpc/metadata" 18 "google.golang.org/grpc/test/bufconn" 19 ) 20 21 func TestOverlapKeyAddition(t *testing.T) { 22 cases := []struct { 23 name string 24 keyer overlapKeyer 25 namespaces []string 26 expected keySet 27 }{ 28 { 29 name: "none", 30 keyer: noOverlapKeyer, 31 namespaces: []string{"a", "a/b", "c", "a/b/c"}, 32 expected: map[string]struct{}{}, 33 }, 34 { 35 name: "static", 36 keyer: appendStaticKey("test"), 37 namespaces: []string{"a", "a/b", "c", "a/b/c"}, 38 expected: map[string]struct{}{"test": {}}, 39 }, 40 { 41 name: "prefix with default", 42 keyer: prefixKeyer, 43 namespaces: []string{"a", "a/b", "c", "a/b/c"}, 44 expected: map[string]struct{}{ 45 defaultOverlapKey: {}, 46 "a": {}, 47 }, 48 }, 49 { 50 name: "prefix no default", 51 keyer: prefixKeyer, 52 namespaces: []string{"a/b", "a/b/c"}, 53 expected: map[string]struct{}{ 54 "a": {}, 55 }, 56 }, 57 } 58 for _, tt := range cases { 59 tt := tt 60 t.Run(tt.name, func(t *testing.T) { 61 set := newKeySet(context.Background()) 62 for _, n := range tt.namespaces { 63 tt.keyer.addKey(set, n) 64 } 65 require.EqualValues(t, tt.expected, set) 66 }) 67 } 68 } 69 70 type testServer struct { 71 testpb.UnimplementedTestServiceServer 72 } 73 74 func (t testServer) Ping(ctx context.Context, _ *testpb.PingRequest) (*testpb.PingResponse, error) { 75 keys := maps.Keys(overlapKeysFromContext(ctx)) 76 sort.Strings(keys) 77 return &testpb.PingResponse{Value: strings.Join(keys, ",")}, nil 78 } 79 80 func TestOverlapKeysFromContext(t *testing.T) { 81 overlapKey := string(requestmeta.RequestOverlapKey) 82 tests := []struct { 83 name string 84 headers []map[string]string 85 expected string 86 }{ 87 { 88 name: "no overlap keys", 89 expected: "", 90 }, 91 { 92 name: "an overlap key", 93 headers: []map[string]string{{ 94 overlapKey: "test", 95 }}, 96 expected: "test", 97 }, 98 { 99 name: "collapses duplicate overlap keys", 100 headers: []map[string]string{{ 101 overlapKey: "test,test", 102 }}, 103 expected: "test", 104 }, 105 { 106 name: "collapses duplicate overlap keys in different headers", 107 headers: []map[string]string{{ 108 overlapKey: "test,test", 109 }, { 110 overlapKey: "test,test", 111 }}, 112 expected: "test", 113 }, 114 { 115 name: "collects overlap keys from different headers, ignoring duplicates", 116 headers: []map[string]string{{ 117 overlapKey: "test,test1", 118 }, { 119 overlapKey: "test,test2", 120 }}, 121 expected: "test,test1,test2", 122 }, 123 { 124 name: "sanitizes space", 125 headers: []map[string]string{{ 126 overlapKey: " test,test1 , ", 127 }, { 128 overlapKey: "test, test2 , , ", 129 }}, 130 expected: "test,test1,test2", 131 }, 132 } 133 for _, tt := range tests { 134 listener := bufconn.Listen(humanize.MiByte) 135 s := grpc.NewServer() 136 testpb.RegisterTestServiceServer(s, &testServer{}) 137 go func() { 138 // Ignore any errors 139 _ = s.Serve(listener) 140 }() 141 142 conn, err := grpc.DialContext( 143 context.Background(), 144 "", 145 grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { 146 return listener.Dial() 147 }), 148 grpc.WithTransportCredentials(insecure.NewCredentials()), 149 grpc.WithBlock(), 150 ) 151 require.NoError(t, err) 152 153 t.Cleanup(func() { 154 conn.Close() 155 listener.Close() 156 s.Stop() 157 }) 158 client := testpb.NewTestServiceClient(conn) 159 160 t.Run(tt.name, func(t *testing.T) { 161 md := metadata.New(map[string]string{}) 162 for _, h := range tt.headers { 163 part := metadata.New(h) 164 md = metadata.Join(md, part) 165 } 166 ctx := metadata.NewOutgoingContext(context.Background(), md) 167 resp, err := client.Ping(ctx, &testpb.PingRequest{}) 168 require.NoError(t, err) 169 require.Equal(t, tt.expected, resp.Value) 170 }) 171 } 172 }