github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/server_test.go (about)

     1  // Copyright 2022 Edward McFarlane. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package larking
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto/rand"
    11  	"crypto/rsa"
    12  	"crypto/tls"
    13  	"crypto/x509"
    14  	"crypto/x509/pkix"
    15  	"encoding/pem"
    16  	"errors"
    17  	"io/ioutil"
    18  	"math/big"
    19  	"net"
    20  	"net/http"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/go-logr/logr"
    25  	testing_logr "github.com/go-logr/logr/testing"
    26  	"github.com/google/go-cmp/cmp"
    27  	"golang.org/x/sync/errgroup"
    28  	"google.golang.org/grpc"
    29  	"google.golang.org/grpc/credentials"
    30  	"google.golang.org/grpc/credentials/insecure"
    31  	"google.golang.org/grpc/metadata"
    32  	"google.golang.org/grpc/reflection"
    33  	"google.golang.org/protobuf/encoding/protojson"
    34  	"google.golang.org/protobuf/proto"
    35  	"google.golang.org/protobuf/testing/protocmp"
    36  
    37  	"github.com/emcfarlane/larking/apipb/healthpb"
    38  	"github.com/emcfarlane/larking/health"
    39  	"github.com/emcfarlane/larking/testpb"
    40  )
    41  
    42  func testContext(t *testing.T) context.Context {
    43  	ctx := context.Background()
    44  	log := testing_logr.NewTestLogger(t)
    45  	ctx = logr.NewContext(ctx, log)
    46  	return ctx
    47  }
    48  
    49  func TestServer(t *testing.T) {
    50  	ms := &testpb.UnimplementedMessagingServer{}
    51  
    52  	o := &overrides{}
    53  	gs := grpc.NewServer(o.streamOption(), o.unaryOption())
    54  	testpb.RegisterMessagingServer(gs, ms)
    55  	reflection.Register(gs)
    56  
    57  	lis, err := net.Listen("tcp", "localhost:0")
    58  	if err != nil {
    59  		t.Fatalf("failed to listen: %v", err)
    60  	}
    61  	defer lis.Close()
    62  
    63  	var g errgroup.Group
    64  	defer func() {
    65  		if err := g.Wait(); err != nil {
    66  			t.Fatal(err)
    67  		}
    68  	}()
    69  
    70  	g.Go(func() error {
    71  		return gs.Serve(lis)
    72  	})
    73  	defer gs.Stop()
    74  
    75  	// Create the client.
    76  	creds := insecure.NewCredentials()
    77  	conn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(creds))
    78  	if err != nil {
    79  		t.Fatalf("cannot connect to server: %v", err)
    80  	}
    81  	defer conn.Close()
    82  
    83  	mux, err := NewMux()
    84  	if err != nil {
    85  		t.Fatal(err)
    86  	}
    87  	if err := mux.RegisterConn(context.Background(), conn); err != nil {
    88  		t.Fatal(err)
    89  	}
    90  
    91  	ts, err := NewServer(mux, InsecureServerOption())
    92  	if err != nil {
    93  		t.Fatal(err)
    94  	}
    95  
    96  	lisProxy, err := net.Listen("tcp", "localhost:0")
    97  	if err != nil {
    98  		t.Fatalf("failed to listen: %v", err)
    99  	}
   100  	defer lisProxy.Close()
   101  
   102  	g.Go(func() error {
   103  		if err := ts.Serve(lisProxy); err != nil && err != http.ErrServerClosed {
   104  			return err
   105  		}
   106  		return nil
   107  	})
   108  	defer func() {
   109  		if err := ts.Shutdown(context.Background()); err != nil {
   110  			t.Fatal(err)
   111  		}
   112  	}()
   113  
   114  	cc, err := grpc.Dial(
   115  		lisProxy.Addr().String(),
   116  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   117  		grpc.WithBlock(),
   118  	)
   119  	if err != nil {
   120  		t.Fatal(err)
   121  	}
   122  
   123  	cmpOpts := cmp.Options{protocmp.Transform()}
   124  
   125  	var unaryStreamDesc = &grpc.StreamDesc{
   126  		ClientStreams: false,
   127  		ServerStreams: false,
   128  	}
   129  
   130  	tests := []struct {
   131  		name   string
   132  		desc   *grpc.StreamDesc
   133  		method string
   134  		inouts []interface{}
   135  		//ins    []in
   136  		//outs   []out
   137  	}{{
   138  		name:   "unary_message",
   139  		desc:   unaryStreamDesc,
   140  		method: "/larking.testpb.Messaging/GetMessageOne",
   141  		inouts: []interface{}{
   142  			in{msg: &testpb.GetMessageRequestOne{Name: "proxy"}},
   143  			out{msg: &testpb.Message{Text: "success"}},
   144  		},
   145  	}}
   146  
   147  	for _, tt := range tests {
   148  		t.Run(tt.name, func(t *testing.T) {
   149  			o.reset(t, "test", tt.inouts)
   150  
   151  			ctx := testContext(t)
   152  			ctx = metadata.AppendToOutgoingContext(ctx, "test", tt.method)
   153  
   154  			s, err := cc.NewStream(ctx, tt.desc, tt.method)
   155  			if err != nil {
   156  				t.Fatal(err)
   157  			}
   158  
   159  			for i := 0; i < len(tt.inouts); i++ {
   160  				switch typ := tt.inouts[i].(type) {
   161  				case in:
   162  					if err := s.SendMsg(typ.msg); err != nil {
   163  						t.Fatal(err)
   164  					}
   165  				case out:
   166  					out := proto.Clone(typ.msg)
   167  					if err := s.RecvMsg(out); err != nil {
   168  						t.Fatal(err)
   169  					}
   170  					diff := cmp.Diff(out, typ.msg, cmpOpts...)
   171  					if diff != "" {
   172  						t.Fatal(diff)
   173  					}
   174  				}
   175  			}
   176  		})
   177  	}
   178  }
   179  
   180  func TestMuxHandleOption(t *testing.T) {
   181  	mux, err := NewMux()
   182  	if err != nil {
   183  		t.Fatal(err)
   184  	}
   185  
   186  	hs := health.NewServer()
   187  	defer hs.Shutdown()
   188  	mux.RegisterService(&healthpb.Health_ServiceDesc, hs)
   189  
   190  	s, err := NewServer(
   191  		mux,
   192  		InsecureServerOption(),
   193  		MuxHandleOption("/", "/api/", "/pfx"),
   194  	)
   195  	if err != nil {
   196  		t.Fatal(err)
   197  	}
   198  
   199  	lis, err := net.Listen("tcp", ":0")
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  	defer lis.Close()
   204  
   205  	var g errgroup.Group
   206  	defer func() {
   207  		if err := g.Wait(); err != nil {
   208  			t.Fatal(err)
   209  		}
   210  	}()
   211  
   212  	g.Go(func() (err error) {
   213  		if err := s.Serve(lis); err != nil && err != http.ErrServerClosed {
   214  			return err
   215  		}
   216  		return nil
   217  	})
   218  	defer func() {
   219  		if err := s.Shutdown(context.Background()); err != nil {
   220  			t.Fatal(err)
   221  		}
   222  	}()
   223  
   224  	for _, tt := range []struct {
   225  		path string
   226  		okay bool
   227  	}{
   228  		{"/v1/health", true},
   229  		{"/api/v1/health", true},
   230  		{"/pfx/v1/health", true},
   231  		{"/bad/v1/health", false},
   232  		{"/v1/health/bad", false},
   233  	} {
   234  		t.Run(tt.path, func(t *testing.T) {
   235  			rsp, err := http.Get("http://" + lis.Addr().String() + tt.path)
   236  			if err != nil {
   237  				t.Fatal(err)
   238  			}
   239  			okay := rsp.StatusCode == 200
   240  			if okay != tt.okay {
   241  				t.Errorf("request got %t for %s", okay, tt.path)
   242  			}
   243  		})
   244  	}
   245  }
   246  
   247  func createCertificateAuthority() ([]byte, []byte, error) {
   248  	ca := &x509.Certificate{
   249  		SerialNumber: big.NewInt(2021),
   250  		Subject: pkix.Name{
   251  			Organization: []string{"Acme Co"},
   252  		},
   253  		NotBefore:             time.Now(),
   254  		NotAfter:              time.Now().AddDate(10, 0, 0),
   255  		IsCA:                  true,
   256  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
   257  		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
   258  		BasicConstraintsValid: true,
   259  	}
   260  
   261  	caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
   262  	if err != nil {
   263  		return nil, nil, err
   264  	}
   265  	caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
   266  	if err != nil {
   267  		return nil, nil, err
   268  	}
   269  
   270  	caPEM := new(bytes.Buffer)
   271  	if err := pem.Encode(caPEM, &pem.Block{
   272  		Type:  "CERTIFICATE",
   273  		Bytes: caBytes,
   274  	}); err != nil {
   275  		return nil, nil, err
   276  	}
   277  	caPrivKeyPEM := new(bytes.Buffer)
   278  	if err := pem.Encode(caPrivKeyPEM, &pem.Block{
   279  		Type:  "RSA PRIVATE KEY",
   280  		Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
   281  	}); err != nil {
   282  		return nil, nil, err
   283  	}
   284  	return caPEM.Bytes(), caPrivKeyPEM.Bytes(), nil
   285  }
   286  
   287  func createCertificate(caCertPEM, caKeyPEM []byte, commonName string) ([]byte, []byte, error) {
   288  	keyPEMBlock, _ := pem.Decode(caKeyPEM)
   289  	privateKey, err := x509.ParsePKCS1PrivateKey(keyPEMBlock.Bytes)
   290  	if err != nil {
   291  		return nil, nil, err
   292  	}
   293  
   294  	certPEMBlock, _ := pem.Decode(caCertPEM)
   295  	parent, err := x509.ParseCertificate(certPEMBlock.Bytes)
   296  	if err != nil {
   297  		return nil, nil, err
   298  	}
   299  
   300  	cert := &x509.Certificate{
   301  		SerialNumber: big.NewInt(1658),
   302  		Subject: pkix.Name{
   303  			Organization: []string{"Acme Co"},
   304  			CommonName:   commonName,
   305  		},
   306  		IPAddresses: []net.IP{
   307  			net.IPv4(127, 0, 0, 1),
   308  			net.IPv6loopback,
   309  			net.IPv4(0, 0, 0, 0),
   310  			net.IPv6zero,
   311  		},
   312  		NotBefore:    time.Now(),
   313  		NotAfter:     time.Now().AddDate(10, 0, 0),
   314  		SubjectKeyId: []byte{1, 2, 3, 4, 6},
   315  		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
   316  		KeyUsage:     x509.KeyUsageDigitalSignature,
   317  	}
   318  
   319  	certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
   320  	if err != nil {
   321  		return nil, nil, err
   322  	}
   323  	certBytes, err := x509.CreateCertificate(rand.Reader, cert, parent, &certPrivKey.PublicKey, privateKey)
   324  	if err != nil {
   325  		return nil, nil, err
   326  	}
   327  
   328  	certPEM := new(bytes.Buffer)
   329  	if err := pem.Encode(certPEM, &pem.Block{
   330  		Type:  "CERTIFICATE",
   331  		Bytes: certBytes,
   332  	}); err != nil {
   333  		return nil, nil, err
   334  	}
   335  	certPrivKeyPEM := new(bytes.Buffer)
   336  	if err := pem.Encode(certPrivKeyPEM, &pem.Block{
   337  		Type:  "RSA PRIVATE KEY",
   338  		Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
   339  	}); err != nil {
   340  		return nil, nil, err
   341  	}
   342  	return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
   343  }
   344  
   345  func TestTLSServer(t *testing.T) {
   346  	ctx, cancel := context.WithCancel(testContext(t))
   347  	defer cancel()
   348  
   349  	// certPool
   350  	certPool := x509.NewCertPool()
   351  	caCertPEM, caKeyPEM, err := createCertificateAuthority()
   352  	if err != nil {
   353  		t.Fatal(err)
   354  	}
   355  	if ok := certPool.AppendCertsFromPEM(caCertPEM); !ok {
   356  		t.Fatal("failed to append client certs")
   357  	}
   358  
   359  	certPEM, keyPEM, err := createCertificate(caCertPEM, caKeyPEM, "Server")
   360  	if err != nil {
   361  		t.Fatal(err)
   362  	}
   363  	certificate, err := tls.X509KeyPair(certPEM, keyPEM)
   364  	if err != nil {
   365  		t.Fatal(err)
   366  	}
   367  	tlsConfig := &tls.Config{
   368  		ClientAuth:   tls.RequireAndVerifyClientCert,
   369  		Certificates: []tls.Certificate{certificate},
   370  		ClientCAs:    certPool,
   371  	}
   372  
   373  	// TODO!
   374  	verfiyPeer := func(ctx context.Context) error {
   375  		//	p, ok := peer.FromContext(ctx)
   376  		//	if !ok {
   377  		//		return status.Error(codes.Unauthenticated, "no peer found")
   378  		//	}
   379  		//	tlsAuth, ok := p.AuthInfo.(credentials.TLSInfo)
   380  		//	if !ok {
   381  		//		return status.Error(codes.Unauthenticated, "unexpected peer transport credentials")
   382  		//	}
   383  		//	if len(tlsAuth.State.VerifiedChains) == 0 || len(tlsAuth.State.VerifiedChains[0]) == 0 {
   384  		//		return status.Error(codes.Unauthenticated, "could not verify peer certificate")
   385  		//	}
   386  		//	fmt.Println(
   387  		//		"tlsAuth.State.VerifiedChains[0][0].Subject.CommonName",
   388  		//		tlsAuth.State.VerifiedChains[0][0].Subject.CommonName,
   389  		//	)
   390  		//	// Check subject common name against configured username
   391  		//	if tlsAuth.State.VerifiedChains[0][0].Subject.CommonName != "Client" {
   392  		//		return status.Error(codes.Unauthenticated, "invalid subject common name")
   393  		//	}
   394  		return nil
   395  	}
   396  
   397  	mux, err := NewMux(
   398  		UnaryServerInterceptorOption(
   399  			func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
   400  				if err := verfiyPeer(ctx); err != nil {
   401  					return nil, err
   402  				}
   403  				return handler(ctx, req)
   404  			},
   405  		),
   406  	)
   407  	if err != nil {
   408  		t.Fatal(err)
   409  	}
   410  	healthServer := health.NewServer()
   411  	defer healthServer.Shutdown()
   412  	mux.RegisterService(&healthpb.Health_ServiceDesc, healthServer)
   413  
   414  	s, err := NewServer(mux,
   415  		TLSCredsOption(tlsConfig),
   416  	)
   417  	if err != nil {
   418  		t.Fatal(err)
   419  	}
   420  
   421  	l, err := net.Listen("tcp", ":0")
   422  	if err != nil {
   423  		t.Fatal(err)
   424  	}
   425  
   426  	g := errgroup.Group{}
   427  	g.Go(func() error { return s.Serve(l) })
   428  	defer func() {
   429  		if err := s.Shutdown(ctx); err != nil {
   430  			t.Error(err)
   431  		}
   432  		if err := g.Wait(); err != nil && err != http.ErrServerClosed {
   433  			t.Error(err)
   434  		}
   435  	}()
   436  
   437  	certPEM, keyPEM, err = createCertificate(caCertPEM, caKeyPEM, "Client")
   438  	if err != nil {
   439  		t.Fatal(err)
   440  	}
   441  	certificate, err = tls.X509KeyPair(certPEM, keyPEM)
   442  	if err != nil {
   443  		t.Fatal(err)
   444  	}
   445  	tlsConfig = &tls.Config{
   446  		Certificates: []tls.Certificate{certificate},
   447  		RootCAs:      certPool,
   448  	}
   449  	tlsInsecure := &tls.Config{
   450  		InsecureSkipVerify: true,
   451  	}
   452  
   453  	t.Run("httpClient", func(t *testing.T) {
   454  		client := &http.Client{
   455  			Transport: &http.Transport{
   456  				TLSClientConfig: tlsConfig,
   457  			},
   458  		}
   459  		rsp, err := client.Get("https://" + l.Addr().String() + "/v1/health")
   460  		if err != nil {
   461  			t.Fatal(err)
   462  		}
   463  		if rsp.StatusCode != http.StatusOK {
   464  			t.Fatal("invalid status code", rsp.StatusCode)
   465  		}
   466  		defer rsp.Body.Close()
   467  		b, err := ioutil.ReadAll(rsp.Body)
   468  		if err != nil {
   469  			t.Fatal(err)
   470  		}
   471  
   472  		var check healthpb.HealthCheckResponse
   473  		if err := protojson.Unmarshal(b, &check); err != nil {
   474  			t.Fatal(err)
   475  		}
   476  		t.Logf("http threads: %+v", &check)
   477  	})
   478  	t.Run("grpcClient", func(t *testing.T) {
   479  		creds := credentials.NewTLS(tlsConfig)
   480  		cc, err := grpc.DialContext(ctx, l.Addr().String(),
   481  			grpc.WithTransportCredentials(creds),
   482  		)
   483  		if err != nil {
   484  			t.Fatal(err)
   485  		}
   486  		client := healthpb.NewHealthClient(cc)
   487  
   488  		check, err := client.Check(ctx, &healthpb.HealthCheckRequest{})
   489  		if err != nil {
   490  			t.Fatal(err)
   491  		}
   492  		t.Logf("grpc threads: %+v", check)
   493  	})
   494  	t.Run("httpNoMTLS", func(t *testing.T) {
   495  		client := &http.Client{
   496  			Transport: &http.Transport{
   497  				TLSClientConfig: tlsInsecure,
   498  			},
   499  		}
   500  		_, err := client.Get("https://" + l.Addr().String() + "/v1/health")
   501  		if err == nil {
   502  			t.Fatal("got nil error")
   503  		}
   504  		var nerr *net.OpError
   505  		if errors.As(err, &nerr) {
   506  			t.Log("nerr", nerr)
   507  		} else {
   508  			t.Fatal("unknown error:", err)
   509  		}
   510  		//for err != nil {
   511  		//	t.Logf("%T", err)
   512  		//	err = errors.Unwrap(err)
   513  		//}
   514  	})
   515  	t.Run("grpcNoMTLS", func(t *testing.T) {
   516  		creds := credentials.NewTLS(tlsInsecure)
   517  		cc, err := grpc.DialContext(ctx, l.Addr().String(),
   518  			grpc.WithTransportCredentials(creds),
   519  		)
   520  		if err != nil {
   521  			t.Fatal(err)
   522  		}
   523  		client := healthpb.NewHealthClient(cc)
   524  
   525  		// TODO: why NIL NIL!?!
   526  		check, err := client.Check(ctx, &healthpb.HealthCheckRequest{})
   527  		if check != nil && err != nil {
   528  			t.Fatal("got nil error", check, err)
   529  		}
   530  		for err != nil {
   531  			t.Logf("%T", err)
   532  			err = errors.Unwrap(err)
   533  		}
   534  
   535  	})
   536  }