vitess.io/vitess@v0.16.2/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package grpcvtgateconn
    18  
    19  import (
    20  	"context"
    21  	"io"
    22  	"net"
    23  	"os"
    24  	"testing"
    25  
    26  	"github.com/spf13/pflag"
    27  	"github.com/stretchr/testify/require"
    28  	"google.golang.org/grpc"
    29  
    30  	"vitess.io/vitess/go/vt/grpcclient"
    31  	"vitess.io/vitess/go/vt/servenv"
    32  	"vitess.io/vitess/go/vt/vtgate/grpcvtgateservice"
    33  	"vitess.io/vitess/go/vt/vtgate/vtgateconn"
    34  )
    35  
    36  // TestGRPCVTGateConn makes sure the grpc service works
    37  func TestGRPCVTGateConn(t *testing.T) {
    38  	// fake service
    39  	service := CreateFakeServer(t)
    40  
    41  	// listen on a random port
    42  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    43  	if err != nil {
    44  		t.Fatalf("Cannot listen: %v", err)
    45  	}
    46  
    47  	// Create a gRPC server and listen on the port
    48  	server := grpc.NewServer()
    49  	grpcvtgateservice.RegisterForTest(server, service)
    50  	go server.Serve(listener)
    51  
    52  	// Create a Go RPC client connecting to the server
    53  	ctx := context.Background()
    54  	client, err := dial(ctx, listener.Addr().String())
    55  	if err != nil {
    56  		t.Fatalf("dial failed: %v", err)
    57  	}
    58  	RegisterTestDialProtocol(client)
    59  
    60  	// run the test suite
    61  	RunTests(t, client, service)
    62  	RunErrorTests(t, service)
    63  
    64  	// and clean up
    65  	client.Close()
    66  }
    67  
    68  // TestGRPCVTGateConnAuth makes sure the grpc with auth plugin works
    69  func TestGRPCVTGateConnAuth(t *testing.T) {
    70  	var opts []grpc.ServerOption
    71  	// fake service
    72  	service := CreateFakeServer(t)
    73  
    74  	// listen on a random port
    75  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    76  	if err != nil {
    77  		t.Fatalf("Cannot listen: %v", err)
    78  	}
    79  
    80  	// add auth interceptors
    81  	opts = append(opts, grpc.StreamInterceptor(servenv.FakeAuthStreamInterceptor))
    82  	opts = append(opts, grpc.UnaryInterceptor(servenv.FakeAuthUnaryInterceptor))
    83  
    84  	// Create a gRPC server and listen on the port
    85  	server := grpc.NewServer(opts...)
    86  	grpcvtgateservice.RegisterForTest(server, service)
    87  	go server.Serve(listener)
    88  
    89  	authJSON := `{
    90           "Username": "valid",
    91           "Password": "valid"
    92          }`
    93  
    94  	f, err := os.CreateTemp("", "static_auth_creds.json")
    95  	if err != nil {
    96  		t.Fatal(err)
    97  	}
    98  	defer os.Remove(f.Name())
    99  	if _, err := io.WriteString(f, authJSON); err != nil {
   100  		t.Fatal(err)
   101  	}
   102  	if err := f.Close(); err != nil {
   103  		t.Fatal(err)
   104  	}
   105  
   106  	// Create a Go RPC client connecting to the server
   107  	ctx := context.Background()
   108  	fs := pflag.NewFlagSet("", pflag.ContinueOnError)
   109  	grpcclient.RegisterFlags(fs)
   110  
   111  	err = fs.Parse([]string{
   112  		"--grpc_auth_static_client_creds",
   113  		f.Name(),
   114  	})
   115  	require.NoError(t, err, "failed to set `--grpc_auth_static_client_creds=%s`", f.Name())
   116  	client, err := dial(ctx, listener.Addr().String())
   117  	if err != nil {
   118  		t.Fatalf("dial failed: %v", err)
   119  	}
   120  	RegisterTestDialProtocol(client)
   121  
   122  	// run the test suite
   123  	RunTests(t, client, service)
   124  	RunErrorTests(t, service)
   125  
   126  	// and clean up
   127  	client.Close()
   128  
   129  	invalidAuthJSON := `{
   130  	 "Username": "invalid",
   131  	 "Password": "valid"
   132  	}`
   133  
   134  	f, err = os.CreateTemp("", "static_auth_creds.json")
   135  	if err != nil {
   136  		t.Fatal(err)
   137  	}
   138  	defer os.Remove(f.Name())
   139  	if _, err := io.WriteString(f, invalidAuthJSON); err != nil {
   140  		t.Fatal(err)
   141  	}
   142  	if err := f.Close(); err != nil {
   143  		t.Fatal(err)
   144  	}
   145  
   146  	// Create a Go RPC client connecting to the server
   147  	ctx = context.Background()
   148  	fs = pflag.NewFlagSet("", pflag.ContinueOnError)
   149  	grpcclient.RegisterFlags(fs)
   150  
   151  	err = fs.Parse([]string{
   152  		"--grpc_auth_static_client_creds",
   153  		f.Name(),
   154  	})
   155  	require.NoError(t, err, "failed to set `--grpc_auth_static_client_creds=%s`", f.Name())
   156  	client, err = dial(ctx, listener.Addr().String())
   157  	if err != nil {
   158  		t.Fatalf("dial failed: %v", err)
   159  	}
   160  	RegisterTestDialProtocol(client)
   161  	conn, _ := vtgateconn.DialProtocol(context.Background(), "test", "")
   162  	// run the test suite
   163  	_, err = conn.Session("", nil).Execute(context.Background(), "select * from t", nil)
   164  	want := "rpc error: code = Unauthenticated desc = username and password must be provided"
   165  	if err == nil || err.Error() != want {
   166  		t.Errorf("expected auth failure:\n%v, want\n%s", err, want)
   167  	}
   168  	// and clean up again
   169  	client.Close()
   170  }