vitess.io/vitess@v0.16.2/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package grpcvtgateconn 18 19 import ( 20 "context" 21 "io" 22 "net" 23 "os" 24 "testing" 25 26 "github.com/spf13/pflag" 27 "github.com/stretchr/testify/require" 28 "google.golang.org/grpc" 29 30 "vitess.io/vitess/go/vt/grpcclient" 31 "vitess.io/vitess/go/vt/servenv" 32 "vitess.io/vitess/go/vt/vtgate/grpcvtgateservice" 33 "vitess.io/vitess/go/vt/vtgate/vtgateconn" 34 ) 35 36 // TestGRPCVTGateConn makes sure the grpc service works 37 func TestGRPCVTGateConn(t *testing.T) { 38 // fake service 39 service := CreateFakeServer(t) 40 41 // listen on a random port 42 listener, err := net.Listen("tcp", "127.0.0.1:0") 43 if err != nil { 44 t.Fatalf("Cannot listen: %v", err) 45 } 46 47 // Create a gRPC server and listen on the port 48 server := grpc.NewServer() 49 grpcvtgateservice.RegisterForTest(server, service) 50 go server.Serve(listener) 51 52 // Create a Go RPC client connecting to the server 53 ctx := context.Background() 54 client, err := dial(ctx, listener.Addr().String()) 55 if err != nil { 56 t.Fatalf("dial failed: %v", err) 57 } 58 RegisterTestDialProtocol(client) 59 60 // run the test suite 61 RunTests(t, client, service) 62 RunErrorTests(t, service) 63 64 // and clean up 65 client.Close() 66 } 67 68 // TestGRPCVTGateConnAuth makes sure the grpc with auth plugin works 69 func TestGRPCVTGateConnAuth(t *testing.T) { 70 var opts []grpc.ServerOption 71 // fake service 72 service := CreateFakeServer(t) 73 74 // listen on a random port 75 listener, err := net.Listen("tcp", "127.0.0.1:0") 76 if err != nil { 77 t.Fatalf("Cannot listen: %v", err) 78 } 79 80 // add auth interceptors 81 opts = append(opts, grpc.StreamInterceptor(servenv.FakeAuthStreamInterceptor)) 82 opts = append(opts, grpc.UnaryInterceptor(servenv.FakeAuthUnaryInterceptor)) 83 84 // Create a gRPC server and listen on the port 85 server := grpc.NewServer(opts...) 86 grpcvtgateservice.RegisterForTest(server, service) 87 go server.Serve(listener) 88 89 authJSON := `{ 90 "Username": "valid", 91 "Password": "valid" 92 }` 93 94 f, err := os.CreateTemp("", "static_auth_creds.json") 95 if err != nil { 96 t.Fatal(err) 97 } 98 defer os.Remove(f.Name()) 99 if _, err := io.WriteString(f, authJSON); err != nil { 100 t.Fatal(err) 101 } 102 if err := f.Close(); err != nil { 103 t.Fatal(err) 104 } 105 106 // Create a Go RPC client connecting to the server 107 ctx := context.Background() 108 fs := pflag.NewFlagSet("", pflag.ContinueOnError) 109 grpcclient.RegisterFlags(fs) 110 111 err = fs.Parse([]string{ 112 "--grpc_auth_static_client_creds", 113 f.Name(), 114 }) 115 require.NoError(t, err, "failed to set `--grpc_auth_static_client_creds=%s`", f.Name()) 116 client, err := dial(ctx, listener.Addr().String()) 117 if err != nil { 118 t.Fatalf("dial failed: %v", err) 119 } 120 RegisterTestDialProtocol(client) 121 122 // run the test suite 123 RunTests(t, client, service) 124 RunErrorTests(t, service) 125 126 // and clean up 127 client.Close() 128 129 invalidAuthJSON := `{ 130 "Username": "invalid", 131 "Password": "valid" 132 }` 133 134 f, err = os.CreateTemp("", "static_auth_creds.json") 135 if err != nil { 136 t.Fatal(err) 137 } 138 defer os.Remove(f.Name()) 139 if _, err := io.WriteString(f, invalidAuthJSON); err != nil { 140 t.Fatal(err) 141 } 142 if err := f.Close(); err != nil { 143 t.Fatal(err) 144 } 145 146 // Create a Go RPC client connecting to the server 147 ctx = context.Background() 148 fs = pflag.NewFlagSet("", pflag.ContinueOnError) 149 grpcclient.RegisterFlags(fs) 150 151 err = fs.Parse([]string{ 152 "--grpc_auth_static_client_creds", 153 f.Name(), 154 }) 155 require.NoError(t, err, "failed to set `--grpc_auth_static_client_creds=%s`", f.Name()) 156 client, err = dial(ctx, listener.Addr().String()) 157 if err != nil { 158 t.Fatalf("dial failed: %v", err) 159 } 160 RegisterTestDialProtocol(client) 161 conn, _ := vtgateconn.DialProtocol(context.Background(), "test", "") 162 // run the test suite 163 _, err = conn.Session("", nil).Execute(context.Background(), "select * from t", nil) 164 want := "rpc error: code = Unauthenticated desc = username and password must be provided" 165 if err == nil || err.Error() != want { 166 t.Errorf("expected auth failure:\n%v, want\n%s", err, want) 167 } 168 // and clean up again 169 client.Close() 170 }