vitess.io/vitess@v0.16.2/go/vt/grpcoptionaltls/server_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 */ 15 package grpcoptionaltls 16 17 import ( 18 "context" 19 "crypto/tls" 20 "net" 21 "testing" 22 "time" 23 24 "google.golang.org/grpc/credentials/insecure" 25 26 "google.golang.org/grpc" 27 "google.golang.org/grpc/credentials" 28 pb "google.golang.org/grpc/examples/helloworld/helloworld" 29 30 "vitess.io/vitess/go/vt/tlstest" 31 ) 32 33 // server is used to implement helloworld.GreeterServer. 34 type server struct { 35 pb.UnimplementedGreeterServer 36 } 37 38 // SayHello implements helloworld.GreeterServer 39 func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { 40 return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil 41 } 42 43 func createUnstartedServer(creds credentials.TransportCredentials) *grpc.Server { 44 s := grpc.NewServer(grpc.Creds(creds)) 45 pb.RegisterGreeterServer(s, &server{}) 46 return s 47 } 48 49 type testCredentials struct { 50 client credentials.TransportCredentials 51 server credentials.TransportCredentials 52 } 53 54 func createCredentials(t *testing.T) (*testCredentials, error) { 55 // Create a temporary directory. 56 certDir := t.TempDir() 57 58 certs := tlstest.CreateClientServerCertPairs(certDir) 59 cert, err := tls.LoadX509KeyPair(certs.ServerCert, certs.ServerKey) 60 if err != nil { 61 return nil, err 62 } 63 64 clientCredentials, err := credentials.NewClientTLSFromFile(certs.ServerCA, certs.ServerName) 65 if err != nil { 66 return nil, err 67 } 68 tc := &testCredentials{ 69 client: clientCredentials, 70 server: credentials.NewServerTLSFromCert(&cert), 71 } 72 return tc, nil 73 } 74 75 func TestOptionalTLS(t *testing.T) { 76 testCtx, testCancel := context.WithCancel(context.Background()) 77 defer testCancel() 78 79 tc, err := createCredentials(t) 80 if err != nil { 81 t.Fatalf("failed to create credentials %v", err) 82 } 83 84 lis, err := net.Listen("tcp", "127.0.0.1:0") 85 if err != nil { 86 t.Fatalf("failed to listen %v", err) 87 } 88 defer lis.Close() 89 addr := lis.Addr().String() 90 91 srv := createUnstartedServer(New(tc.server)) 92 go func() { 93 srv.Serve(lis) 94 }() 95 defer srv.Stop() 96 97 testFunc := func(t *testing.T, dialOpt grpc.DialOption) { 98 ctx, cancel := context.WithTimeout(testCtx, 5*time.Second) 99 defer cancel() 100 conn, err := grpc.DialContext(ctx, addr, dialOpt) 101 if err != nil { 102 t.Fatalf("failed to connect to the server %v", err) 103 } 104 defer conn.Close() 105 c := pb.NewGreeterClient(conn) 106 resp, err := c.SayHello(ctx, &pb.HelloRequest{Name: "Vittes"}) 107 if err != nil { 108 t.Fatalf("could not greet: %v", err) 109 } 110 if resp.Message != "Hello Vittes" { 111 t.Fatalf("unexpected reply %s", resp.Message) 112 } 113 } 114 115 t.Run("Plain2TLS", func(t *testing.T) { 116 for i := 0; i < 5; i++ { 117 testFunc(t, grpc.WithTransportCredentials(insecure.NewCredentials())) 118 } 119 }) 120 t.Run("TLS2TLS", func(t *testing.T) { 121 for i := 0; i < 5; i++ { 122 testFunc(t, grpc.WithTransportCredentials(tc.client)) 123 } 124 }) 125 }