github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/client_test.go (about) 1 /* 2 Copyright hechain. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package comm_test 8 9 import ( 10 "bytes" 11 "context" 12 "crypto/tls" 13 "crypto/x509" 14 "net" 15 "testing" 16 "time" 17 18 "github.com/golang/protobuf/proto" 19 "github.com/hechain20/hechain/internal/pkg/comm" 20 "github.com/hechain20/hechain/internal/pkg/comm/testpb" 21 "github.com/pkg/errors" 22 "github.com/stretchr/testify/require" 23 "google.golang.org/grpc" 24 "google.golang.org/grpc/credentials" 25 ) 26 27 const testTimeout = 1 * time.Second // conservative 28 29 type echoServer struct{} 30 31 func (es *echoServer) EchoCall(ctx context.Context, 32 echo *testpb.Echo) (*testpb.Echo, error) { 33 return echo, nil 34 } 35 36 func TestClientConfigDial(t *testing.T) { 37 t.Parallel() 38 testCerts := comm.LoadTestCerts(t) 39 40 l, err := net.Listen("tcp", "127.0.0.1:0") 41 require.NoError(t, err) 42 badAddress := l.Addr().String() 43 defer l.Close() 44 45 certPool := x509.NewCertPool() 46 ok := certPool.AppendCertsFromPEM(testCerts.CAPEM) 47 if !ok { 48 t.Fatal("failed to create test root cert pool") 49 } 50 51 tests := []struct { 52 name string 53 clientAddress string 54 config comm.ClientConfig 55 serverTLS *tls.Config 56 success bool 57 errorMsg string 58 }{ 59 { 60 name: "client / server same port", 61 config: comm.ClientConfig{ 62 DialTimeout: testTimeout, 63 }, 64 success: true, 65 }, 66 { 67 name: "client / server wrong port", 68 clientAddress: badAddress, 69 config: comm.ClientConfig{ 70 DialTimeout: time.Second, 71 }, 72 success: false, 73 errorMsg: "(connection refused|context deadline exceeded)", 74 }, 75 { 76 name: "client / server wrong port but with asynchronous should succeed", 77 clientAddress: badAddress, 78 config: comm.ClientConfig{ 79 AsyncConnect: true, 80 DialTimeout: testTimeout, 81 }, 82 success: true, 83 }, 84 { 85 name: "client TLS / server no TLS", 86 config: comm.ClientConfig{ 87 SecOpts: comm.SecureOptions{ 88 Certificate: testCerts.CertPEM, 89 Key: testCerts.KeyPEM, 90 UseTLS: true, 91 ServerRootCAs: [][]byte{testCerts.CAPEM}, 92 RequireClientCert: true, 93 }, 94 DialTimeout: time.Second, 95 }, 96 success: false, 97 errorMsg: "context deadline exceeded", 98 }, 99 { 100 name: "client TLS / server TLS match", 101 config: comm.ClientConfig{ 102 SecOpts: comm.SecureOptions{ 103 Certificate: testCerts.CertPEM, 104 Key: testCerts.KeyPEM, 105 UseTLS: true, 106 ServerRootCAs: [][]byte{testCerts.CAPEM}, 107 }, 108 DialTimeout: testTimeout, 109 }, 110 serverTLS: &tls.Config{ 111 Certificates: []tls.Certificate{testCerts.ServerCert}, 112 }, 113 success: true, 114 }, 115 { 116 name: "client TLS / server TLS no server roots", 117 config: comm.ClientConfig{ 118 SecOpts: comm.SecureOptions{ 119 Certificate: testCerts.CertPEM, 120 Key: testCerts.KeyPEM, 121 UseTLS: true, 122 ServerRootCAs: [][]byte{}, 123 }, 124 DialTimeout: testTimeout, 125 }, 126 serverTLS: &tls.Config{ 127 Certificates: []tls.Certificate{testCerts.ServerCert}, 128 }, 129 success: false, 130 errorMsg: "context deadline exceeded", 131 }, 132 { 133 name: "client TLS / server TLS missing client cert", 134 config: comm.ClientConfig{ 135 SecOpts: comm.SecureOptions{ 136 Certificate: testCerts.CertPEM, 137 Key: testCerts.KeyPEM, 138 UseTLS: true, 139 ServerRootCAs: [][]byte{testCerts.CAPEM}, 140 }, 141 DialTimeout: testTimeout, 142 }, 143 serverTLS: &tls.Config{ 144 Certificates: []tls.Certificate{testCerts.ServerCert}, 145 ClientAuth: tls.RequireAndVerifyClientCert, 146 MaxVersion: tls.VersionTLS12, // https://github.com/golang/go/issues/33368 147 }, 148 success: false, 149 errorMsg: "tls: bad certificate", 150 }, 151 { 152 name: "client TLS / server TLS client cert", 153 config: comm.ClientConfig{ 154 SecOpts: comm.SecureOptions{ 155 Certificate: testCerts.CertPEM, 156 Key: testCerts.KeyPEM, 157 UseTLS: true, 158 RequireClientCert: true, 159 ServerRootCAs: [][]byte{testCerts.CAPEM}, 160 }, 161 DialTimeout: testTimeout, 162 }, 163 serverTLS: &tls.Config{ 164 Certificates: []tls.Certificate{testCerts.ServerCert}, 165 ClientAuth: tls.RequireAndVerifyClientCert, 166 ClientCAs: certPool, 167 }, 168 success: true, 169 }, 170 { 171 name: "server TLS pinning success", 172 config: comm.ClientConfig{ 173 SecOpts: comm.SecureOptions{ 174 VerifyCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 175 if bytes.Equal(rawCerts[0], testCerts.ServerCert.Certificate[0]) { 176 return nil 177 } 178 panic("mismatched certificate") 179 }, 180 Certificate: testCerts.CertPEM, 181 Key: testCerts.KeyPEM, 182 UseTLS: true, 183 RequireClientCert: true, 184 ServerRootCAs: [][]byte{testCerts.CAPEM}, 185 }, 186 DialTimeout: testTimeout, 187 }, 188 serverTLS: &tls.Config{ 189 Certificates: []tls.Certificate{testCerts.ServerCert}, 190 ClientAuth: tls.RequireAndVerifyClientCert, 191 ClientCAs: certPool, 192 }, 193 success: true, 194 }, 195 { 196 name: "server TLS pinning failure", 197 config: comm.ClientConfig{ 198 SecOpts: comm.SecureOptions{ 199 VerifyCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 200 return errors.New("TLS certificate mismatch") 201 }, 202 Certificate: testCerts.CertPEM, 203 Key: testCerts.KeyPEM, 204 UseTLS: true, 205 RequireClientCert: true, 206 ServerRootCAs: [][]byte{testCerts.CAPEM}, 207 }, 208 DialTimeout: testTimeout, 209 }, 210 serverTLS: &tls.Config{ 211 Certificates: []tls.Certificate{testCerts.ServerCert}, 212 ClientAuth: tls.RequireAndVerifyClientCert, 213 ClientCAs: certPool, 214 }, 215 success: false, 216 errorMsg: "context deadline exceeded", 217 }, 218 } 219 220 for _, test := range tests { 221 test := test 222 t.Run(test.name, func(t *testing.T) { 223 t.Parallel() 224 lis, err := net.Listen("tcp", "127.0.0.1:0") 225 if err != nil { 226 t.Fatalf("error creating server for test: %v", err) 227 } 228 defer lis.Close() 229 serverOpts := []grpc.ServerOption{} 230 if test.serverTLS != nil { 231 serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(test.serverTLS))) 232 } 233 srv := grpc.NewServer(serverOpts...) 234 defer srv.Stop() 235 go srv.Serve(lis) 236 address := lis.Addr().String() 237 if test.clientAddress != "" { 238 address = test.clientAddress 239 } 240 conn, err := test.config.Dial(address) 241 if test.success { 242 require.NoError(t, err) 243 require.NotNil(t, conn) 244 } else { 245 t.Log(errors.WithStack(err)) 246 require.Regexp(t, test.errorMsg, err.Error()) 247 } 248 }) 249 } 250 } 251 252 func TestSetMessageSize(t *testing.T) { 253 t.Parallel() 254 255 // setup test server 256 lis, err := net.Listen("tcp", "127.0.0.1:0") 257 if err != nil { 258 t.Fatalf("failed to create listener for test server: %v", err) 259 } 260 srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{}) 261 if err != nil { 262 t.Fatalf("failed to create test server: %v", err) 263 } 264 testpb.RegisterEchoServiceServer(srv.Server(), &echoServer{}) 265 defer srv.Stop() 266 go srv.Start() 267 268 tests := []struct { 269 name string 270 maxRecvSize int 271 maxSendSize int 272 failRecv bool 273 failSend bool 274 }{ 275 { 276 name: "defaults should pass", 277 failRecv: false, 278 failSend: false, 279 }, 280 { 281 name: "non-defaults should pass", 282 failRecv: false, 283 failSend: false, 284 maxRecvSize: 20, 285 maxSendSize: 20, 286 }, 287 { 288 name: "recv should fail", 289 failRecv: true, 290 failSend: false, 291 maxRecvSize: 1, 292 }, 293 { 294 name: "send should fail", 295 failRecv: false, 296 failSend: true, 297 maxSendSize: 1, 298 }, 299 } 300 301 // run tests 302 for _, test := range tests { 303 test := test 304 address := lis.Addr().String() 305 t.Run(test.name, func(t *testing.T) { 306 t.Log(test.name) 307 config := comm.ClientConfig{ 308 DialTimeout: testTimeout, 309 MaxRecvMsgSize: test.maxRecvSize, 310 MaxSendMsgSize: test.maxSendSize, 311 } 312 conn, err := config.Dial(address) 313 require.NoError(t, err) 314 defer conn.Close() 315 // create service client from conn 316 svcClient := testpb.NewEchoServiceClient(conn) 317 callCtx := context.Background() 318 callCtx, cancel := context.WithTimeout(callCtx, testTimeout) 319 defer cancel() 320 // invoke service 321 echo := &testpb.Echo{ 322 Payload: []byte{0, 0, 0, 0, 0}, 323 } 324 resp, err := svcClient.EchoCall(callCtx, echo) 325 if !test.failRecv && !test.failSend { 326 require.NoError(t, err) 327 require.True(t, proto.Equal(echo, resp)) 328 } 329 if test.failSend { 330 t.Logf("send error: %v", err) 331 require.Contains(t, err.Error(), "trying to send message larger than max") 332 } 333 if test.failRecv { 334 t.Logf("recv error: %v", err) 335 require.Contains(t, err.Error(), "received message larger than max") 336 } 337 }) 338 } 339 }