go.temporal.io/server@v1.23.0/common/rpc/test/rpc_common_test.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package rpc
    26  
    27  import (
    28  	"context"
    29  	"crypto/tls"
    30  	"math/rand"
    31  	"net"
    32  	"strings"
    33  
    34  	"github.com/stretchr/testify/suite"
    35  	"google.golang.org/grpc"
    36  	"google.golang.org/grpc/credentials"
    37  	"google.golang.org/grpc/examples/helloworld/helloworld"
    38  	"google.golang.org/grpc/peer"
    39  
    40  	"go.temporal.io/server/common/config"
    41  	"go.temporal.io/server/common/convert"
    42  	"go.temporal.io/server/common/log"
    43  	"go.temporal.io/server/common/rpc"
    44  )
    45  
// HelloServer is the minimal helloworld.GreeterServer used as the RPC target
// in these TLS tests. Only SayHello is implemented in this file; embedding
// UnimplementedGreeterServer satisfies the rest of the generated interface.
type HelloServer struct {
	helloworld.UnimplementedGreeterServer
}
    50  
    51  type ServerUsageType int32
    52  
    53  const (
    54  	Frontend ServerUsageType = iota
    55  	Internode
    56  	RemoteCluster
    57  )
    58  
    59  const (
    60  	localhostIPv4 = "127.0.0.1"
    61  	localhost     = "localhost"
    62  )
    63  
// TestFactory pairs an rpc.RPCFactory with the kind of server a test case
// exercises. The embedded factory provides the gRPC server options, the gRPC
// listener and the client TLS configs consumed by the helpers in this file.
type TestFactory struct {
	*rpc.RPCFactory
	// serverUsage selects which server options the server side uses and which
	// client TLS config the dialing side requests.
	serverUsage ServerUsageType
}
    68  
    69  // SayHello implements helloworld.GreeterServer
    70  func (s *HelloServer) SayHello(ctx context.Context, in *helloworld.HelloRequest) (*helloworld.HelloReply, error) {
    71  	return &helloworld.HelloReply{Message: "Hello " + in.Name}, nil
    72  }
    73  
    74  var (
    75  	rpcTestCfgDefault = &config.RPC{
    76  		GRPCPort:       0,
    77  		MembershipPort: 7600,
    78  		BindOnIP:       localhostIPv4,
    79  	}
    80  	serverCfgInsecure = &config.Global{
    81  		Membership: config.Membership{
    82  			MaxJoinDuration:  5,
    83  			BroadcastAddress: localhostIPv4,
    84  		},
    85  	}
    86  )
    87  
    88  func startHelloWorldServer(s *suite.Suite, factory *TestFactory) (*grpc.Server, string) {
    89  	var opts []grpc.ServerOption
    90  	var err error
    91  	if factory.serverUsage == Internode {
    92  		opts, err = factory.GetInternodeGRPCServerOptions()
    93  	} else {
    94  		opts, err = factory.GetFrontendGRPCServerOptions()
    95  	}
    96  	s.NoError(err)
    97  
    98  	server := grpc.NewServer(opts...)
    99  	greeter := &HelloServer{}
   100  	helloworld.RegisterGreeterServer(server, greeter)
   101  
   102  	listener := factory.GetGRPCListener()
   103  
   104  	port := strings.Split(listener.Addr().String(), ":")[1]
   105  	s.NoError(err)
   106  	go func() {
   107  		err = server.Serve(listener)
   108  	}()
   109  	s.NoError(err)
   110  	return server, port
   111  }
   112  
   113  func runHelloWorldTest(s *suite.Suite, host string, serverFactory *TestFactory, clientFactory *TestFactory, isValid bool) {
   114  	server, port := startHelloWorldServer(s, serverFactory)
   115  	defer server.Stop()
   116  	err := dialHello(s, host+":"+port, clientFactory, serverFactory.serverUsage)
   117  
   118  	if isValid {
   119  		s.NoError(err)
   120  	} else {
   121  		s.Error(err)
   122  	}
   123  }
   124  
   125  func runHelloWorldMultipleDials(
   126  	s *suite.Suite,
   127  	host string,
   128  	serverFactory *TestFactory,
   129  	clientFactory *TestFactory,
   130  	nDials int,
   131  	validator func(*credentials.TLSInfo, error),
   132  ) {
   133  
   134  	server, port := startHelloWorldServer(s, serverFactory)
   135  	defer server.Stop()
   136  
   137  	for i := 0; i < nDials; i++ {
   138  		tlsInfo, err := dialHelloAndGetTLSInfo(s, host+":"+port, clientFactory, serverFactory.serverUsage)
   139  		validator(tlsInfo, err)
   140  	}
   141  }
   142  
   143  func dialHello(s *suite.Suite, hostport string, clientFactory *TestFactory, serverType ServerUsageType) error {
   144  	_, err := dialHelloAndGetTLSInfo(s, hostport, clientFactory, serverType)
   145  	return err
   146  }
   147  
   148  func dialHelloAndGetTLSInfo(
   149  	s *suite.Suite,
   150  	hostport string,
   151  	clientFactory *TestFactory,
   152  	serverType ServerUsageType,
   153  ) (*credentials.TLSInfo, error) {
   154  
   155  	logger := log.NewNoopLogger()
   156  	var cfg *tls.Config
   157  	var err error
   158  	switch serverType {
   159  	case Internode:
   160  		cfg, err = clientFactory.GetInternodeClientTlsConfig()
   161  		s.NoError(err)
   162  	case Frontend:
   163  		cfg, err = clientFactory.GetFrontendClientTlsConfig()
   164  		s.NoError(err)
   165  	case RemoteCluster:
   166  		host, _, err := net.SplitHostPort(hostport)
   167  		s.NoError(err)
   168  		cfg, err = clientFactory.GetRemoteClusterClientConfig(host)
   169  		s.NoError(err)
   170  	}
   171  
   172  	clientConn, err := rpc.Dial(hostport, cfg, logger)
   173  	s.NoError(err)
   174  
   175  	client := helloworld.NewGreeterClient(clientConn)
   176  
   177  	request := &helloworld.HelloRequest{Name: convert.Uint64ToString(rand.Uint64())}
   178  	var reply *helloworld.HelloReply
   179  	peer := new(peer.Peer)
   180  	reply, err = client.SayHello(context.Background(), request, grpc.Peer(peer))
   181  	tlsInfo, _ := peer.AuthInfo.(credentials.TLSInfo)
   182  
   183  	if err == nil {
   184  		s.NotNil(reply)
   185  		s.True(strings.Contains(reply.Message, request.Name))
   186  	}
   187  
   188  	_ = clientConn.Close()
   189  	return &tlsInfo, err
   190  }