vitess.io/vitess@v0.16.2/go/test/endtoend/vtgate/grpc_server_auth_static/main_test.go (about)

     1  /*
     2  Copyright 2023 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package grpcserverauthstatic
    18  
    19  import (
    20  	"context"
    21  	"flag"
    22  	"fmt"
    23  	"os"
    24  	"path"
    25  	"testing"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  	"google.golang.org/grpc"
    30  
    31  	"vitess.io/vitess/go/test/endtoend/cluster"
    32  	"vitess.io/vitess/go/vt/grpcclient"
    33  	"vitess.io/vitess/go/vt/vtgate/grpcvtgateconn"
    34  	"vitess.io/vitess/go/vt/vtgate/vtgateconn"
    35  )
    36  
    37  var (
    38  	clusterInstance   *cluster.LocalProcessCluster
    39  	vtgateGrpcAddress string
    40  	hostname          = "localhost"
    41  	keyspaceName      = "ks"
    42  	cell              = "zone1"
    43  	sqlSchema         = `
    44  		create table test_table (
    45  			id bigint,
    46  			val varchar(128),
    47  			primary key(id)
    48  		) Engine=InnoDB;
    49  `
    50  	grpcServerAuthStaticJSON = `
    51  		[
    52  		  {
    53  			"Username": "user_with_access",
    54  			"Password": "test_password"
    55  		  },
    56  		  {
    57  			"Username": "user_no_access",
    58  			"Password": "test_password"
    59  		  }
    60  		]
    61  `
    62  	tableACLJSON = `
    63  		{
    64  		  "table_groups": [
    65  			{
    66  			  "name": "default",
    67  			  "table_names_or_prefixes": ["%"],
    68  			  "readers": ["user_with_access"],
    69  			  "writers": ["user_with_access"],
    70  			  "admins": ["user_with_access"]
    71  			}
    72  		  ]
    73  		}
    74  `
    75  )
    76  
    77  func TestMain(m *testing.M) {
    78  	defer cluster.PanicHandler(nil)
    79  	flag.Parse()
    80  
    81  	exitcode := func() int {
    82  		clusterInstance = cluster.NewCluster(cell, hostname)
    83  		defer clusterInstance.Teardown()
    84  
    85  		// Start topo server
    86  		if err := clusterInstance.StartTopo(); err != nil {
    87  			return 1
    88  		}
    89  
    90  		// Directory for authn / authz config files
    91  		authDirectory := path.Join(clusterInstance.TmpDirectory, "auth")
    92  		if err := os.Mkdir(authDirectory, 0700); err != nil {
    93  			return 1
    94  		}
    95  
    96  		// Create grpc_server_auth_static.json file
    97  		grpcServerAuthStaticPath := path.Join(authDirectory, "grpc_server_auth_static.json")
    98  		if err := createFile(grpcServerAuthStaticPath, grpcServerAuthStaticJSON); err != nil {
    99  			return 1
   100  		}
   101  
   102  		// Create table_acl.json file
   103  		tableACLPath := path.Join(authDirectory, "table_acl.json")
   104  		if err := createFile(tableACLPath, tableACLJSON); err != nil {
   105  			return 1
   106  		}
   107  
   108  		// Configure vtgate to use static auth
   109  		clusterInstance.VtGateExtraArgs = []string{
   110  			"--grpc_auth_mode", "static",
   111  			"--grpc_auth_static_password_file", grpcServerAuthStaticPath,
   112  			"--grpc-use-static-authentication-callerid",
   113  		}
   114  
   115  		// Configure vttablet to use table ACL
   116  		clusterInstance.VtTabletExtraArgs = []string{
   117  			"--enforce-tableacl-config",
   118  			"--queryserver-config-strict-table-acl",
   119  			"--table-acl-config", tableACLPath,
   120  		}
   121  
   122  		// Start keyspace
   123  		keyspace := &cluster.Keyspace{
   124  			Name:      keyspaceName,
   125  			SchemaSQL: sqlSchema,
   126  		}
   127  		if err := clusterInstance.StartUnshardedKeyspace(*keyspace, 1, false); err != nil {
   128  			return 1
   129  		}
   130  
   131  		// Start vtgate
   132  		if err := clusterInstance.StartVtgate(); err != nil {
   133  			clusterInstance.VtgateProcess = cluster.VtgateProcess{}
   134  			return 1
   135  		}
   136  		vtgateGrpcAddress = fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateGrpcPort)
   137  
   138  		return m.Run()
   139  	}()
   140  	os.Exit(exitcode)
   141  }
   142  
   143  // TestAuthenticatedUserWithAccess verifies that an authenticated gRPC static user with ACL access can execute queries
   144  func TestAuthenticatedUserWithAccess(t *testing.T) {
   145  	ctx, cancel := context.WithCancel(context.Background())
   146  	defer cancel()
   147  
   148  	vtgateConn, err := dialVTGate(ctx, t, "user_with_access", "test_password")
   149  	if err != nil {
   150  		t.Fatal(err)
   151  	}
   152  	defer vtgateConn.Close()
   153  
   154  	session := vtgateConn.Session(keyspaceName+"@primary", nil)
   155  	query := "SELECT id FROM test_table"
   156  	_, err = session.Execute(ctx, query, nil)
   157  	assert.NoError(t, err)
   158  }
   159  
   160  // TestAuthenticatedUserNoAccess verifies that an authenticated gRPC static user with no ACL access cannot execute queries
   161  func TestAuthenticatedUserNoAccess(t *testing.T) {
   162  	ctx, cancel := context.WithCancel(context.Background())
   163  	defer cancel()
   164  
   165  	vtgateConn, err := dialVTGate(ctx, t, "user_no_access", "test_password")
   166  	if err != nil {
   167  		t.Fatal(err)
   168  	}
   169  	defer vtgateConn.Close()
   170  
   171  	session := vtgateConn.Session(keyspaceName+"@primary", nil)
   172  	query := "SELECT id FROM test_table"
   173  	_, err = session.Execute(ctx, query, nil)
   174  	require.Error(t, err)
   175  	assert.Contains(t, err.Error(), "Select command denied to user")
   176  	assert.Contains(t, err.Error(), "for table 'test_table' (ACL check error)")
   177  }
   178  
   179  // TestUnauthenticatedUser verifies that an unauthenticated gRPC user cannot execute queries
   180  func TestUnauthenticatedUser(t *testing.T) {
   181  	ctx, cancel := context.WithCancel(context.Background())
   182  	defer cancel()
   183  
   184  	vtgateConn, err := dialVTGate(ctx, t, "", "")
   185  	if err != nil {
   186  		t.Fatal(err)
   187  	}
   188  	defer vtgateConn.Close()
   189  
   190  	session := vtgateConn.Session(keyspaceName+"@primary", nil)
   191  	query := "SELECT id FROM test_table"
   192  	_, err = session.Execute(ctx, query, nil)
   193  	require.Error(t, err)
   194  	assert.Contains(t, err.Error(), "invalid credentials")
   195  }
   196  
   197  func dialVTGate(ctx context.Context, t *testing.T, username string, password string) (*vtgateconn.VTGateConn, error) {
   198  	clientCreds := &grpcclient.StaticAuthClientCreds{Username: username, Password: password}
   199  	creds := grpc.WithPerRPCCredentials(clientCreds)
   200  	dialerFunc := grpcvtgateconn.DialWithOpts(ctx, creds)
   201  	dialerName := t.Name()
   202  	vtgateconn.RegisterDialer(dialerName, dialerFunc)
   203  	return vtgateconn.DialProtocol(ctx, dialerName, vtgateGrpcAddress)
   204  }
   205  
   206  func createFile(path string, contents string) error {
   207  	f, err := os.Create(path)
   208  	if err != nil {
   209  		return err
   210  	}
   211  	_, err = f.WriteString(contents)
   212  	if err != nil {
   213  		return err
   214  	}
   215  	return f.Close()
   216  }