github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/server_test.go (about) 1 // Copyright 2022 Edward McFarlane. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package larking 6 7 import ( 8 "bytes" 9 "context" 10 "crypto/rand" 11 "crypto/rsa" 12 "crypto/tls" 13 "crypto/x509" 14 "crypto/x509/pkix" 15 "encoding/pem" 16 "errors" 17 "io/ioutil" 18 "math/big" 19 "net" 20 "net/http" 21 "testing" 22 "time" 23 24 "github.com/go-logr/logr" 25 testing_logr "github.com/go-logr/logr/testing" 26 "github.com/google/go-cmp/cmp" 27 "golang.org/x/sync/errgroup" 28 "google.golang.org/grpc" 29 "google.golang.org/grpc/credentials" 30 "google.golang.org/grpc/credentials/insecure" 31 "google.golang.org/grpc/metadata" 32 "google.golang.org/grpc/reflection" 33 "google.golang.org/protobuf/encoding/protojson" 34 "google.golang.org/protobuf/proto" 35 "google.golang.org/protobuf/testing/protocmp" 36 37 "github.com/emcfarlane/larking/apipb/healthpb" 38 "github.com/emcfarlane/larking/health" 39 "github.com/emcfarlane/larking/testpb" 40 ) 41 42 func testContext(t *testing.T) context.Context { 43 ctx := context.Background() 44 log := testing_logr.NewTestLogger(t) 45 ctx = logr.NewContext(ctx, log) 46 return ctx 47 } 48 49 func TestServer(t *testing.T) { 50 ms := &testpb.UnimplementedMessagingServer{} 51 52 o := &overrides{} 53 gs := grpc.NewServer(o.streamOption(), o.unaryOption()) 54 testpb.RegisterMessagingServer(gs, ms) 55 reflection.Register(gs) 56 57 lis, err := net.Listen("tcp", "localhost:0") 58 if err != nil { 59 t.Fatalf("failed to listen: %v", err) 60 } 61 defer lis.Close() 62 63 var g errgroup.Group 64 defer func() { 65 if err := g.Wait(); err != nil { 66 t.Fatal(err) 67 } 68 }() 69 70 g.Go(func() error { 71 return gs.Serve(lis) 72 }) 73 defer gs.Stop() 74 75 // Create the client. 76 creds := insecure.NewCredentials() 77 conn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(creds)) 78 if err != nil { 79 t.Fatalf("cannot connect to server: %v", err) 80 } 81 defer conn.Close() 82 83 mux, err := NewMux() 84 if err != nil { 85 t.Fatal(err) 86 } 87 if err := mux.RegisterConn(context.Background(), conn); err != nil { 88 t.Fatal(err) 89 } 90 91 ts, err := NewServer(mux, InsecureServerOption()) 92 if err != nil { 93 t.Fatal(err) 94 } 95 96 lisProxy, err := net.Listen("tcp", "localhost:0") 97 if err != nil { 98 t.Fatalf("failed to listen: %v", err) 99 } 100 defer lisProxy.Close() 101 102 g.Go(func() error { 103 if err := ts.Serve(lisProxy); err != nil && err != http.ErrServerClosed { 104 return err 105 } 106 return nil 107 }) 108 defer func() { 109 if err := ts.Shutdown(context.Background()); err != nil { 110 t.Fatal(err) 111 } 112 }() 113 114 cc, err := grpc.Dial( 115 lisProxy.Addr().String(), 116 grpc.WithTransportCredentials(insecure.NewCredentials()), 117 grpc.WithBlock(), 118 ) 119 if err != nil { 120 t.Fatal(err) 121 } 122 123 cmpOpts := cmp.Options{protocmp.Transform()} 124 125 var unaryStreamDesc = &grpc.StreamDesc{ 126 ClientStreams: false, 127 ServerStreams: false, 128 } 129 130 tests := []struct { 131 name string 132 desc *grpc.StreamDesc 133 method string 134 inouts []interface{} 135 //ins []in 136 //outs []out 137 }{{ 138 name: "unary_message", 139 desc: unaryStreamDesc, 140 method: "/larking.testpb.Messaging/GetMessageOne", 141 inouts: []interface{}{ 142 in{msg: &testpb.GetMessageRequestOne{Name: "proxy"}}, 143 out{msg: &testpb.Message{Text: "success"}}, 144 }, 145 }} 146 147 for _, tt := range tests { 148 t.Run(tt.name, func(t *testing.T) { 149 o.reset(t, "test", tt.inouts) 150 151 ctx := testContext(t) 152 ctx = metadata.AppendToOutgoingContext(ctx, "test", tt.method) 153 154 s, err := cc.NewStream(ctx, tt.desc, tt.method) 155 if err != nil { 156 t.Fatal(err) 157 } 158 159 for i := 0; i < len(tt.inouts); i++ { 160 switch typ := tt.inouts[i].(type) { 161 case in: 162 if err := s.SendMsg(typ.msg); err != nil { 163 t.Fatal(err) 164 } 165 case out: 166 out := proto.Clone(typ.msg) 167 if err := s.RecvMsg(out); err != nil { 168 t.Fatal(err) 169 } 170 diff := cmp.Diff(out, typ.msg, cmpOpts...) 171 if diff != "" { 172 t.Fatal(diff) 173 } 174 } 175 } 176 }) 177 } 178 } 179 180 func TestMuxHandleOption(t *testing.T) { 181 mux, err := NewMux() 182 if err != nil { 183 t.Fatal(err) 184 } 185 186 hs := health.NewServer() 187 defer hs.Shutdown() 188 mux.RegisterService(&healthpb.Health_ServiceDesc, hs) 189 190 s, err := NewServer( 191 mux, 192 InsecureServerOption(), 193 MuxHandleOption("/", "/api/", "/pfx"), 194 ) 195 if err != nil { 196 t.Fatal(err) 197 } 198 199 lis, err := net.Listen("tcp", ":0") 200 if err != nil { 201 t.Fatal(err) 202 } 203 defer lis.Close() 204 205 var g errgroup.Group 206 defer func() { 207 if err := g.Wait(); err != nil { 208 t.Fatal(err) 209 } 210 }() 211 212 g.Go(func() (err error) { 213 if err := s.Serve(lis); err != nil && err != http.ErrServerClosed { 214 return err 215 } 216 return nil 217 }) 218 defer func() { 219 if err := s.Shutdown(context.Background()); err != nil { 220 t.Fatal(err) 221 } 222 }() 223 224 for _, tt := range []struct { 225 path string 226 okay bool 227 }{ 228 {"/v1/health", true}, 229 {"/api/v1/health", true}, 230 {"/pfx/v1/health", true}, 231 {"/bad/v1/health", false}, 232 {"/v1/health/bad", false}, 233 } { 234 t.Run(tt.path, func(t *testing.T) { 235 rsp, err := http.Get("http://" + lis.Addr().String() + tt.path) 236 if err != nil { 237 t.Fatal(err) 238 } 239 okay := rsp.StatusCode == 200 240 if okay != tt.okay { 241 t.Errorf("request got %t for %s", okay, tt.path) 242 } 243 }) 244 } 245 } 246 247 func createCertificateAuthority() ([]byte, []byte, error) { 248 ca := &x509.Certificate{ 249 SerialNumber: big.NewInt(2021), 250 Subject: pkix.Name{ 251 Organization: []string{"Acme Co"}, 252 }, 253 NotBefore: time.Now(), 254 NotAfter: time.Now().AddDate(10, 0, 0), 255 IsCA: true, 256 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, 257 KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, 258 BasicConstraintsValid: true, 259 } 260 261 caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) 262 if err != nil { 263 return nil, nil, err 264 } 265 caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) 266 if err != nil { 267 return nil, nil, err 268 } 269 270 caPEM := new(bytes.Buffer) 271 if err := pem.Encode(caPEM, &pem.Block{ 272 Type: "CERTIFICATE", 273 Bytes: caBytes, 274 }); err != nil { 275 return nil, nil, err 276 } 277 caPrivKeyPEM := new(bytes.Buffer) 278 if err := pem.Encode(caPrivKeyPEM, &pem.Block{ 279 Type: "RSA PRIVATE KEY", 280 Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey), 281 }); err != nil { 282 return nil, nil, err 283 } 284 return caPEM.Bytes(), caPrivKeyPEM.Bytes(), nil 285 } 286 287 func createCertificate(caCertPEM, caKeyPEM []byte, commonName string) ([]byte, []byte, error) { 288 keyPEMBlock, _ := pem.Decode(caKeyPEM) 289 privateKey, err := x509.ParsePKCS1PrivateKey(keyPEMBlock.Bytes) 290 if err != nil { 291 return nil, nil, err 292 } 293 294 certPEMBlock, _ := pem.Decode(caCertPEM) 295 parent, err := x509.ParseCertificate(certPEMBlock.Bytes) 296 if err != nil { 297 return nil, nil, err 298 } 299 300 cert := &x509.Certificate{ 301 SerialNumber: big.NewInt(1658), 302 Subject: pkix.Name{ 303 Organization: []string{"Acme Co"}, 304 CommonName: commonName, 305 }, 306 IPAddresses: []net.IP{ 307 net.IPv4(127, 0, 0, 1), 308 net.IPv6loopback, 309 net.IPv4(0, 0, 0, 0), 310 net.IPv6zero, 311 }, 312 NotBefore: time.Now(), 313 NotAfter: time.Now().AddDate(10, 0, 0), 314 SubjectKeyId: []byte{1, 2, 3, 4, 6}, 315 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, 316 KeyUsage: x509.KeyUsageDigitalSignature, 317 } 318 319 certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) 320 if err != nil { 321 return nil, nil, err 322 } 323 certBytes, err := x509.CreateCertificate(rand.Reader, cert, parent, &certPrivKey.PublicKey, privateKey) 324 if err != nil { 325 return nil, nil, err 326 } 327 328 certPEM := new(bytes.Buffer) 329 if err := pem.Encode(certPEM, &pem.Block{ 330 Type: "CERTIFICATE", 331 Bytes: certBytes, 332 }); err != nil { 333 return nil, nil, err 334 } 335 certPrivKeyPEM := new(bytes.Buffer) 336 if err := pem.Encode(certPrivKeyPEM, &pem.Block{ 337 Type: "RSA PRIVATE KEY", 338 Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), 339 }); err != nil { 340 return nil, nil, err 341 } 342 return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil 343 } 344 345 func TestTLSServer(t *testing.T) { 346 ctx, cancel := context.WithCancel(testContext(t)) 347 defer cancel() 348 349 // certPool 350 certPool := x509.NewCertPool() 351 caCertPEM, caKeyPEM, err := createCertificateAuthority() 352 if err != nil { 353 t.Fatal(err) 354 } 355 if ok := certPool.AppendCertsFromPEM(caCertPEM); !ok { 356 t.Fatal("failed to append client certs") 357 } 358 359 certPEM, keyPEM, err := createCertificate(caCertPEM, caKeyPEM, "Server") 360 if err != nil { 361 t.Fatal(err) 362 } 363 certificate, err := tls.X509KeyPair(certPEM, keyPEM) 364 if err != nil { 365 t.Fatal(err) 366 } 367 tlsConfig := &tls.Config{ 368 ClientAuth: tls.RequireAndVerifyClientCert, 369 Certificates: []tls.Certificate{certificate}, 370 ClientCAs: certPool, 371 } 372 373 // TODO! 374 verfiyPeer := func(ctx context.Context) error { 375 // p, ok := peer.FromContext(ctx) 376 // if !ok { 377 // return status.Error(codes.Unauthenticated, "no peer found") 378 // } 379 // tlsAuth, ok := p.AuthInfo.(credentials.TLSInfo) 380 // if !ok { 381 // return status.Error(codes.Unauthenticated, "unexpected peer transport credentials") 382 // } 383 // if len(tlsAuth.State.VerifiedChains) == 0 || len(tlsAuth.State.VerifiedChains[0]) == 0 { 384 // return status.Error(codes.Unauthenticated, "could not verify peer certificate") 385 // } 386 // fmt.Println( 387 // "tlsAuth.State.VerifiedChains[0][0].Subject.CommonName", 388 // tlsAuth.State.VerifiedChains[0][0].Subject.CommonName, 389 // ) 390 // // Check subject common name against configured username 391 // if tlsAuth.State.VerifiedChains[0][0].Subject.CommonName != "Client" { 392 // return status.Error(codes.Unauthenticated, "invalid subject common name") 393 // } 394 return nil 395 } 396 397 mux, err := NewMux( 398 UnaryServerInterceptorOption( 399 func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 400 if err := verfiyPeer(ctx); err != nil { 401 return nil, err 402 } 403 return handler(ctx, req) 404 }, 405 ), 406 ) 407 if err != nil { 408 t.Fatal(err) 409 } 410 healthServer := health.NewServer() 411 defer healthServer.Shutdown() 412 mux.RegisterService(&healthpb.Health_ServiceDesc, healthServer) 413 414 s, err := NewServer(mux, 415 TLSCredsOption(tlsConfig), 416 ) 417 if err != nil { 418 t.Fatal(err) 419 } 420 421 l, err := net.Listen("tcp", ":0") 422 if err != nil { 423 t.Fatal(err) 424 } 425 426 g := errgroup.Group{} 427 g.Go(func() error { return s.Serve(l) }) 428 defer func() { 429 if err := s.Shutdown(ctx); err != nil { 430 t.Error(err) 431 } 432 if err := g.Wait(); err != nil && err != http.ErrServerClosed { 433 t.Error(err) 434 } 435 }() 436 437 certPEM, keyPEM, err = createCertificate(caCertPEM, caKeyPEM, "Client") 438 if err != nil { 439 t.Fatal(err) 440 } 441 certificate, err = tls.X509KeyPair(certPEM, keyPEM) 442 if err != nil { 443 t.Fatal(err) 444 } 445 tlsConfig = &tls.Config{ 446 Certificates: []tls.Certificate{certificate}, 447 RootCAs: certPool, 448 } 449 tlsInsecure := &tls.Config{ 450 InsecureSkipVerify: true, 451 } 452 453 t.Run("httpClient", func(t *testing.T) { 454 client := &http.Client{ 455 Transport: &http.Transport{ 456 TLSClientConfig: tlsConfig, 457 }, 458 } 459 rsp, err := client.Get("https://" + l.Addr().String() + "/v1/health") 460 if err != nil { 461 t.Fatal(err) 462 } 463 if rsp.StatusCode != http.StatusOK { 464 t.Fatal("invalid status code", rsp.StatusCode) 465 } 466 defer rsp.Body.Close() 467 b, err := ioutil.ReadAll(rsp.Body) 468 if err != nil { 469 t.Fatal(err) 470 } 471 472 var check healthpb.HealthCheckResponse 473 if err := protojson.Unmarshal(b, &check); err != nil { 474 t.Fatal(err) 475 } 476 t.Logf("http threads: %+v", &check) 477 }) 478 t.Run("grpcClient", func(t *testing.T) { 479 creds := credentials.NewTLS(tlsConfig) 480 cc, err := grpc.DialContext(ctx, l.Addr().String(), 481 grpc.WithTransportCredentials(creds), 482 ) 483 if err != nil { 484 t.Fatal(err) 485 } 486 client := healthpb.NewHealthClient(cc) 487 488 check, err := client.Check(ctx, &healthpb.HealthCheckRequest{}) 489 if err != nil { 490 t.Fatal(err) 491 } 492 t.Logf("grpc threads: %+v", check) 493 }) 494 t.Run("httpNoMTLS", func(t *testing.T) { 495 client := &http.Client{ 496 Transport: &http.Transport{ 497 TLSClientConfig: tlsInsecure, 498 }, 499 } 500 _, err := client.Get("https://" + l.Addr().String() + "/v1/health") 501 if err == nil { 502 t.Fatal("got nil error") 503 } 504 var nerr *net.OpError 505 if errors.As(err, &nerr) { 506 t.Log("nerr", nerr) 507 } else { 508 t.Fatal("unknown error:", err) 509 } 510 //for err != nil { 511 // t.Logf("%T", err) 512 // err = errors.Unwrap(err) 513 //} 514 }) 515 t.Run("grpcNoMTLS", func(t *testing.T) { 516 creds := credentials.NewTLS(tlsInsecure) 517 cc, err := grpc.DialContext(ctx, l.Addr().String(), 518 grpc.WithTransportCredentials(creds), 519 ) 520 if err != nil { 521 t.Fatal(err) 522 } 523 client := healthpb.NewHealthClient(cc) 524 525 // TODO: why NIL NIL!?! 526 check, err := client.Check(ctx, &healthpb.HealthCheckRequest{}) 527 if check != nil && err != nil { 528 t.Fatal("got nil error", check, err) 529 } 530 for err != nil { 531 t.Logf("%T", err) 532 err = errors.Unwrap(err) 533 } 534 535 }) 536 }