github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/client_test.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm_test
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/tls"
    13  	"crypto/x509"
    14  	"net"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/golang/protobuf/proto"
    19  	"github.com/hechain20/hechain/internal/pkg/comm"
    20  	"github.com/hechain20/hechain/internal/pkg/comm/testpb"
    21  	"github.com/pkg/errors"
    22  	"github.com/stretchr/testify/require"
    23  	"google.golang.org/grpc"
    24  	"google.golang.org/grpc/credentials"
    25  )
    26  
    27  const testTimeout = 1 * time.Second // conservative
    28  
    29  type echoServer struct{}
    30  
    31  func (es *echoServer) EchoCall(ctx context.Context,
    32  	echo *testpb.Echo) (*testpb.Echo, error) {
    33  	return echo, nil
    34  }
    35  
    36  func TestClientConfigDial(t *testing.T) {
    37  	t.Parallel()
    38  	testCerts := comm.LoadTestCerts(t)
    39  
    40  	l, err := net.Listen("tcp", "127.0.0.1:0")
    41  	require.NoError(t, err)
    42  	badAddress := l.Addr().String()
    43  	defer l.Close()
    44  
    45  	certPool := x509.NewCertPool()
    46  	ok := certPool.AppendCertsFromPEM(testCerts.CAPEM)
    47  	if !ok {
    48  		t.Fatal("failed to create test root cert pool")
    49  	}
    50  
    51  	tests := []struct {
    52  		name          string
    53  		clientAddress string
    54  		config        comm.ClientConfig
    55  		serverTLS     *tls.Config
    56  		success       bool
    57  		errorMsg      string
    58  	}{
    59  		{
    60  			name: "client / server same port",
    61  			config: comm.ClientConfig{
    62  				DialTimeout: testTimeout,
    63  			},
    64  			success: true,
    65  		},
    66  		{
    67  			name:          "client / server wrong port",
    68  			clientAddress: badAddress,
    69  			config: comm.ClientConfig{
    70  				DialTimeout: time.Second,
    71  			},
    72  			success:  false,
    73  			errorMsg: "(connection refused|context deadline exceeded)",
    74  		},
    75  		{
    76  			name:          "client / server wrong port but with asynchronous should succeed",
    77  			clientAddress: badAddress,
    78  			config: comm.ClientConfig{
    79  				AsyncConnect: true,
    80  				DialTimeout:  testTimeout,
    81  			},
    82  			success: true,
    83  		},
    84  		{
    85  			name: "client TLS / server no TLS",
    86  			config: comm.ClientConfig{
    87  				SecOpts: comm.SecureOptions{
    88  					Certificate:       testCerts.CertPEM,
    89  					Key:               testCerts.KeyPEM,
    90  					UseTLS:            true,
    91  					ServerRootCAs:     [][]byte{testCerts.CAPEM},
    92  					RequireClientCert: true,
    93  				},
    94  				DialTimeout: time.Second,
    95  			},
    96  			success:  false,
    97  			errorMsg: "context deadline exceeded",
    98  		},
    99  		{
   100  			name: "client TLS / server TLS match",
   101  			config: comm.ClientConfig{
   102  				SecOpts: comm.SecureOptions{
   103  					Certificate:   testCerts.CertPEM,
   104  					Key:           testCerts.KeyPEM,
   105  					UseTLS:        true,
   106  					ServerRootCAs: [][]byte{testCerts.CAPEM},
   107  				},
   108  				DialTimeout: testTimeout,
   109  			},
   110  			serverTLS: &tls.Config{
   111  				Certificates: []tls.Certificate{testCerts.ServerCert},
   112  			},
   113  			success: true,
   114  		},
   115  		{
   116  			name: "client TLS / server TLS no server roots",
   117  			config: comm.ClientConfig{
   118  				SecOpts: comm.SecureOptions{
   119  					Certificate:   testCerts.CertPEM,
   120  					Key:           testCerts.KeyPEM,
   121  					UseTLS:        true,
   122  					ServerRootCAs: [][]byte{},
   123  				},
   124  				DialTimeout: testTimeout,
   125  			},
   126  			serverTLS: &tls.Config{
   127  				Certificates: []tls.Certificate{testCerts.ServerCert},
   128  			},
   129  			success:  false,
   130  			errorMsg: "context deadline exceeded",
   131  		},
   132  		{
   133  			name: "client TLS / server TLS missing client cert",
   134  			config: comm.ClientConfig{
   135  				SecOpts: comm.SecureOptions{
   136  					Certificate:   testCerts.CertPEM,
   137  					Key:           testCerts.KeyPEM,
   138  					UseTLS:        true,
   139  					ServerRootCAs: [][]byte{testCerts.CAPEM},
   140  				},
   141  				DialTimeout: testTimeout,
   142  			},
   143  			serverTLS: &tls.Config{
   144  				Certificates: []tls.Certificate{testCerts.ServerCert},
   145  				ClientAuth:   tls.RequireAndVerifyClientCert,
   146  				MaxVersion:   tls.VersionTLS12, // https://github.com/golang/go/issues/33368
   147  			},
   148  			success:  false,
   149  			errorMsg: "tls: bad certificate",
   150  		},
   151  		{
   152  			name: "client TLS / server TLS client cert",
   153  			config: comm.ClientConfig{
   154  				SecOpts: comm.SecureOptions{
   155  					Certificate:       testCerts.CertPEM,
   156  					Key:               testCerts.KeyPEM,
   157  					UseTLS:            true,
   158  					RequireClientCert: true,
   159  					ServerRootCAs:     [][]byte{testCerts.CAPEM},
   160  				},
   161  				DialTimeout: testTimeout,
   162  			},
   163  			serverTLS: &tls.Config{
   164  				Certificates: []tls.Certificate{testCerts.ServerCert},
   165  				ClientAuth:   tls.RequireAndVerifyClientCert,
   166  				ClientCAs:    certPool,
   167  			},
   168  			success: true,
   169  		},
   170  		{
   171  			name: "server TLS pinning success",
   172  			config: comm.ClientConfig{
   173  				SecOpts: comm.SecureOptions{
   174  					VerifyCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   175  						if bytes.Equal(rawCerts[0], testCerts.ServerCert.Certificate[0]) {
   176  							return nil
   177  						}
   178  						panic("mismatched certificate")
   179  					},
   180  					Certificate:       testCerts.CertPEM,
   181  					Key:               testCerts.KeyPEM,
   182  					UseTLS:            true,
   183  					RequireClientCert: true,
   184  					ServerRootCAs:     [][]byte{testCerts.CAPEM},
   185  				},
   186  				DialTimeout: testTimeout,
   187  			},
   188  			serverTLS: &tls.Config{
   189  				Certificates: []tls.Certificate{testCerts.ServerCert},
   190  				ClientAuth:   tls.RequireAndVerifyClientCert,
   191  				ClientCAs:    certPool,
   192  			},
   193  			success: true,
   194  		},
   195  		{
   196  			name: "server TLS pinning failure",
   197  			config: comm.ClientConfig{
   198  				SecOpts: comm.SecureOptions{
   199  					VerifyCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   200  						return errors.New("TLS certificate mismatch")
   201  					},
   202  					Certificate:       testCerts.CertPEM,
   203  					Key:               testCerts.KeyPEM,
   204  					UseTLS:            true,
   205  					RequireClientCert: true,
   206  					ServerRootCAs:     [][]byte{testCerts.CAPEM},
   207  				},
   208  				DialTimeout: testTimeout,
   209  			},
   210  			serverTLS: &tls.Config{
   211  				Certificates: []tls.Certificate{testCerts.ServerCert},
   212  				ClientAuth:   tls.RequireAndVerifyClientCert,
   213  				ClientCAs:    certPool,
   214  			},
   215  			success:  false,
   216  			errorMsg: "context deadline exceeded",
   217  		},
   218  	}
   219  
   220  	for _, test := range tests {
   221  		test := test
   222  		t.Run(test.name, func(t *testing.T) {
   223  			t.Parallel()
   224  			lis, err := net.Listen("tcp", "127.0.0.1:0")
   225  			if err != nil {
   226  				t.Fatalf("error creating server for test: %v", err)
   227  			}
   228  			defer lis.Close()
   229  			serverOpts := []grpc.ServerOption{}
   230  			if test.serverTLS != nil {
   231  				serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(test.serverTLS)))
   232  			}
   233  			srv := grpc.NewServer(serverOpts...)
   234  			defer srv.Stop()
   235  			go srv.Serve(lis)
   236  			address := lis.Addr().String()
   237  			if test.clientAddress != "" {
   238  				address = test.clientAddress
   239  			}
   240  			conn, err := test.config.Dial(address)
   241  			if test.success {
   242  				require.NoError(t, err)
   243  				require.NotNil(t, conn)
   244  			} else {
   245  				t.Log(errors.WithStack(err))
   246  				require.Regexp(t, test.errorMsg, err.Error())
   247  			}
   248  		})
   249  	}
   250  }
   251  
   252  func TestSetMessageSize(t *testing.T) {
   253  	t.Parallel()
   254  
   255  	// setup test server
   256  	lis, err := net.Listen("tcp", "127.0.0.1:0")
   257  	if err != nil {
   258  		t.Fatalf("failed to create listener for test server: %v", err)
   259  	}
   260  	srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{})
   261  	if err != nil {
   262  		t.Fatalf("failed to create test server: %v", err)
   263  	}
   264  	testpb.RegisterEchoServiceServer(srv.Server(), &echoServer{})
   265  	defer srv.Stop()
   266  	go srv.Start()
   267  
   268  	tests := []struct {
   269  		name        string
   270  		maxRecvSize int
   271  		maxSendSize int
   272  		failRecv    bool
   273  		failSend    bool
   274  	}{
   275  		{
   276  			name:     "defaults should pass",
   277  			failRecv: false,
   278  			failSend: false,
   279  		},
   280  		{
   281  			name:        "non-defaults should pass",
   282  			failRecv:    false,
   283  			failSend:    false,
   284  			maxRecvSize: 20,
   285  			maxSendSize: 20,
   286  		},
   287  		{
   288  			name:        "recv should fail",
   289  			failRecv:    true,
   290  			failSend:    false,
   291  			maxRecvSize: 1,
   292  		},
   293  		{
   294  			name:        "send should fail",
   295  			failRecv:    false,
   296  			failSend:    true,
   297  			maxSendSize: 1,
   298  		},
   299  	}
   300  
   301  	// run tests
   302  	for _, test := range tests {
   303  		test := test
   304  		address := lis.Addr().String()
   305  		t.Run(test.name, func(t *testing.T) {
   306  			t.Log(test.name)
   307  			config := comm.ClientConfig{
   308  				DialTimeout:    testTimeout,
   309  				MaxRecvMsgSize: test.maxRecvSize,
   310  				MaxSendMsgSize: test.maxSendSize,
   311  			}
   312  			conn, err := config.Dial(address)
   313  			require.NoError(t, err)
   314  			defer conn.Close()
   315  			// create service client from conn
   316  			svcClient := testpb.NewEchoServiceClient(conn)
   317  			callCtx := context.Background()
   318  			callCtx, cancel := context.WithTimeout(callCtx, testTimeout)
   319  			defer cancel()
   320  			// invoke service
   321  			echo := &testpb.Echo{
   322  				Payload: []byte{0, 0, 0, 0, 0},
   323  			}
   324  			resp, err := svcClient.EchoCall(callCtx, echo)
   325  			if !test.failRecv && !test.failSend {
   326  				require.NoError(t, err)
   327  				require.True(t, proto.Equal(echo, resp))
   328  			}
   329  			if test.failSend {
   330  				t.Logf("send error: %v", err)
   331  				require.Contains(t, err.Error(), "trying to send message larger than max")
   332  			}
   333  			if test.failRecv {
   334  				t.Logf("recv error: %v", err)
   335  				require.Contains(t, err.Error(), "received message larger than max")
   336  			}
   337  		})
   338  	}
   339  }