vitess.io/vitess@v0.16.2/go/test/endtoend/vtgate/grpc_server_auth_static/main_test.go (about) 1 /* 2 Copyright 2023 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package grpcserverauthstatic 18 19 import ( 20 "context" 21 "flag" 22 "fmt" 23 "os" 24 "path" 25 "testing" 26 27 "github.com/stretchr/testify/assert" 28 "github.com/stretchr/testify/require" 29 "google.golang.org/grpc" 30 31 "vitess.io/vitess/go/test/endtoend/cluster" 32 "vitess.io/vitess/go/vt/grpcclient" 33 "vitess.io/vitess/go/vt/vtgate/grpcvtgateconn" 34 "vitess.io/vitess/go/vt/vtgate/vtgateconn" 35 ) 36 37 var ( 38 clusterInstance *cluster.LocalProcessCluster 39 vtgateGrpcAddress string 40 hostname = "localhost" 41 keyspaceName = "ks" 42 cell = "zone1" 43 sqlSchema = ` 44 create table test_table ( 45 id bigint, 46 val varchar(128), 47 primary key(id) 48 ) Engine=InnoDB; 49 ` 50 grpcServerAuthStaticJSON = ` 51 [ 52 { 53 "Username": "user_with_access", 54 "Password": "test_password" 55 }, 56 { 57 "Username": "user_no_access", 58 "Password": "test_password" 59 } 60 ] 61 ` 62 tableACLJSON = ` 63 { 64 "table_groups": [ 65 { 66 "name": "default", 67 "table_names_or_prefixes": ["%"], 68 "readers": ["user_with_access"], 69 "writers": ["user_with_access"], 70 "admins": ["user_with_access"] 71 } 72 ] 73 } 74 ` 75 ) 76 77 func TestMain(m *testing.M) { 78 defer cluster.PanicHandler(nil) 79 flag.Parse() 80 81 exitcode := func() int { 82 clusterInstance = cluster.NewCluster(cell, hostname) 83 defer clusterInstance.Teardown() 84 85 // Start topo server 86 if err := clusterInstance.StartTopo(); err != nil { 87 return 1 88 } 89 90 // Directory for authn / authz config files 91 authDirectory := path.Join(clusterInstance.TmpDirectory, "auth") 92 if err := os.Mkdir(authDirectory, 0700); err != nil { 93 return 1 94 } 95 96 // Create grpc_server_auth_static.json file 97 grpcServerAuthStaticPath := path.Join(authDirectory, "grpc_server_auth_static.json") 98 if err := createFile(grpcServerAuthStaticPath, grpcServerAuthStaticJSON); err != nil { 99 return 1 100 } 101 102 // Create table_acl.json file 103 tableACLPath := path.Join(authDirectory, "table_acl.json") 104 if err := createFile(tableACLPath, tableACLJSON); err != nil { 105 return 1 106 } 107 108 // Configure vtgate to use static auth 109 clusterInstance.VtGateExtraArgs = []string{ 110 "--grpc_auth_mode", "static", 111 "--grpc_auth_static_password_file", grpcServerAuthStaticPath, 112 "--grpc-use-static-authentication-callerid", 113 } 114 115 // Configure vttablet to use table ACL 116 clusterInstance.VtTabletExtraArgs = []string{ 117 "--enforce-tableacl-config", 118 "--queryserver-config-strict-table-acl", 119 "--table-acl-config", tableACLPath, 120 } 121 122 // Start keyspace 123 keyspace := &cluster.Keyspace{ 124 Name: keyspaceName, 125 SchemaSQL: sqlSchema, 126 } 127 if err := clusterInstance.StartUnshardedKeyspace(*keyspace, 1, false); err != nil { 128 return 1 129 } 130 131 // Start vtgate 132 if err := clusterInstance.StartVtgate(); err != nil { 133 clusterInstance.VtgateProcess = cluster.VtgateProcess{} 134 return 1 135 } 136 vtgateGrpcAddress = fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateGrpcPort) 137 138 return m.Run() 139 }() 140 os.Exit(exitcode) 141 } 142 143 // TestAuthenticatedUserWithAccess verifies that an authenticated gRPC static user with ACL access can execute queries 144 func TestAuthenticatedUserWithAccess(t *testing.T) { 145 ctx, cancel := context.WithCancel(context.Background()) 146 defer cancel() 147 148 vtgateConn, err := dialVTGate(ctx, t, "user_with_access", "test_password") 149 if err != nil { 150 t.Fatal(err) 151 } 152 defer vtgateConn.Close() 153 154 session := vtgateConn.Session(keyspaceName+"@primary", nil) 155 query := "SELECT id FROM test_table" 156 _, err = session.Execute(ctx, query, nil) 157 assert.NoError(t, err) 158 } 159 160 // TestAuthenticatedUserNoAccess verifies that an authenticated gRPC static user with no ACL access cannot execute queries 161 func TestAuthenticatedUserNoAccess(t *testing.T) { 162 ctx, cancel := context.WithCancel(context.Background()) 163 defer cancel() 164 165 vtgateConn, err := dialVTGate(ctx, t, "user_no_access", "test_password") 166 if err != nil { 167 t.Fatal(err) 168 } 169 defer vtgateConn.Close() 170 171 session := vtgateConn.Session(keyspaceName+"@primary", nil) 172 query := "SELECT id FROM test_table" 173 _, err = session.Execute(ctx, query, nil) 174 require.Error(t, err) 175 assert.Contains(t, err.Error(), "Select command denied to user") 176 assert.Contains(t, err.Error(), "for table 'test_table' (ACL check error)") 177 } 178 179 // TestUnauthenticatedUser verifies that an unauthenticated gRPC user cannot execute queries 180 func TestUnauthenticatedUser(t *testing.T) { 181 ctx, cancel := context.WithCancel(context.Background()) 182 defer cancel() 183 184 vtgateConn, err := dialVTGate(ctx, t, "", "") 185 if err != nil { 186 t.Fatal(err) 187 } 188 defer vtgateConn.Close() 189 190 session := vtgateConn.Session(keyspaceName+"@primary", nil) 191 query := "SELECT id FROM test_table" 192 _, err = session.Execute(ctx, query, nil) 193 require.Error(t, err) 194 assert.Contains(t, err.Error(), "invalid credentials") 195 } 196 197 func dialVTGate(ctx context.Context, t *testing.T, username string, password string) (*vtgateconn.VTGateConn, error) { 198 clientCreds := &grpcclient.StaticAuthClientCreds{Username: username, Password: password} 199 creds := grpc.WithPerRPCCredentials(clientCreds) 200 dialerFunc := grpcvtgateconn.DialWithOpts(ctx, creds) 201 dialerName := t.Name() 202 vtgateconn.RegisterDialer(dialerName, dialerFunc) 203 return vtgateconn.DialProtocol(ctx, dialerName, vtgateGrpcAddress) 204 } 205 206 func createFile(path string, contents string) error { 207 f, err := os.Create(path) 208 if err != nil { 209 return err 210 } 211 _, err = f.WriteString(contents) 212 if err != nil { 213 return err 214 } 215 return f.Close() 216 }