go.uber.org/yarpc@v1.72.1/internal/testutils/testutils.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package testutils
    22  
    23  import (
    24  	"fmt"
    25  	"net"
    26  	"strconv"
    27  
    28  	"go.uber.org/multierr"
    29  	"go.uber.org/yarpc"
    30  	"go.uber.org/yarpc/api/transport"
    31  	"go.uber.org/yarpc/encoding/protobuf"
    32  	"go.uber.org/yarpc/internal/grpcctx"
    33  	"go.uber.org/yarpc/transport/grpc"
    34  	"go.uber.org/yarpc/transport/http"
    35  	"go.uber.org/yarpc/transport/tchannel"
    36  	"go.uber.org/zap"
    37  	ggrpc "google.golang.org/grpc"
    38  )
    39  
    40  const (
    41  	// TransportTypeHTTP represents using HTTP.
    42  	TransportTypeHTTP TransportType = iota
    43  	// TransportTypeTChannel represents using TChannel.
    44  	TransportTypeTChannel
    45  	// TransportTypeGRPC represents using GRPC.
    46  	TransportTypeGRPC
    47  )
    48  
    49  var (
    50  	// AllTransportTypes are all TransportTypes,
    51  	AllTransportTypes = []TransportType{
    52  		TransportTypeHTTP,
    53  		TransportTypeTChannel,
    54  		TransportTypeGRPC,
    55  	}
    56  )
    57  
    58  // TransportType is a transport type.
    59  type TransportType int
    60  
    61  // String returns a string representation of t.
    62  func (t TransportType) String() string {
    63  	switch t {
    64  	case TransportTypeHTTP:
    65  		return "http"
    66  	case TransportTypeTChannel:
    67  		return "tchannel"
    68  	case TransportTypeGRPC:
    69  		return "grpc"
    70  	default:
    71  		return strconv.Itoa(int(t))
    72  	}
    73  }
    74  
    75  // ParseTransportType parses a transport type from a string.
    76  func ParseTransportType(s string) (TransportType, error) {
    77  	switch s {
    78  	case "http":
    79  		return TransportTypeHTTP, nil
    80  	case "tchannel":
    81  		return TransportTypeTChannel, nil
    82  	case "grpc":
    83  		return TransportTypeGRPC, nil
    84  	default:
    85  		return 0, fmt.Errorf("invalid TransportType: %s", s)
    86  	}
    87  }
    88  
// ClientInfo holds the client info for testing.
//
// WithClientInfo populates it as follows:
//   - ClientConfig is the client dispatcher's ClientConfig for the service
//     under test.
//   - GRPCClientConn is a grpc-go connection dialed directly to the server's
//     gRPC port on 127.0.0.1.
//   - ContextWrapper is built with the caller ("<serviceName>-client"), the
//     service name and the protobuf encoding.
//
// NOTE: WithClientInfo constructs this struct positionally, so the field order
// must not be changed without updating that literal.
type ClientInfo struct {
	ClientConfig   transport.ClientConfig
	GRPCClientConn *ggrpc.ClientConn
	ContextWrapper *grpcctx.ContextWrapper
}
    95  
    96  // WithClientInfo wraps a function by setting up a client and server dispatcher and giving
    97  // the function the client configuration to use in tests for the given TransportType.
    98  //
    99  // The server dispatcher will be brought up using all TransportTypes and with the serviceName.
   100  // The client dispatcher will be brought up using the given TransportType for Unary, HTTP for
   101  // Oneway, and the serviceName with a "-client" suffix.
   102  func WithClientInfo(serviceName string, procedures []transport.Procedure, transportType TransportType, logger *zap.Logger, f func(*ClientInfo) error) (err error) {
   103  	if logger == nil {
   104  		logger = zap.NewNop()
   105  	}
   106  	dispatcherConfig, err := NewDispatcherConfig(serviceName)
   107  	if err != nil {
   108  		return err
   109  	}
   110  	serverDispatcher, err := NewServerDispatcher(procedures, dispatcherConfig, logger)
   111  	if err != nil {
   112  		return err
   113  	}
   114  
   115  	clientDispatcher, err := NewClientDispatcher(transportType, dispatcherConfig, logger)
   116  	if err != nil {
   117  		return err
   118  	}
   119  
   120  	if err := serverDispatcher.Start(); err != nil {
   121  		return err
   122  	}
   123  	defer func() { err = multierr.Append(err, serverDispatcher.Stop()) }()
   124  
   125  	if err := clientDispatcher.Start(); err != nil {
   126  		return err
   127  	}
   128  	defer func() { err = multierr.Append(err, clientDispatcher.Stop()) }()
   129  	grpcPort, err := dispatcherConfig.GetPort(TransportTypeGRPC)
   130  	if err != nil {
   131  		return err
   132  	}
   133  	grpcClientConn, err := ggrpc.Dial(fmt.Sprintf("127.0.0.1:%d", grpcPort), ggrpc.WithInsecure())
   134  	if err != nil {
   135  		return err
   136  	}
   137  	return f(
   138  		&ClientInfo{
   139  			clientDispatcher.ClientConfig(serviceName),
   140  			grpcClientConn,
   141  			grpcctx.NewContextWrapper().
   142  				WithCaller(serviceName + "-client").
   143  				WithService(serviceName).
   144  				WithEncoding(string(protobuf.Encoding)),
   145  		},
   146  	)
   147  }
   148  
   149  // NewClientDispatcher returns a new client Dispatcher.
   150  //
   151  // HTTP always will be configured as an outbound for Oneway.
   152  // gRPC always will be configured as an outbound for Stream.
   153  func NewClientDispatcher(transportType TransportType, config *DispatcherConfig, logger *zap.Logger) (*yarpc.Dispatcher, error) {
   154  	port, err := config.GetPort(transportType)
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  	httpPort, err := config.GetPort(TransportTypeHTTP)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	grpcPort, err := config.GetPort(TransportTypeGRPC)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	onewayOutbound := http.NewTransport(http.Logger(logger)).NewSingleOutbound(fmt.Sprintf("http://127.0.0.1:%d", httpPort))
   167  	streamOutbound := grpc.NewTransport(grpc.Logger(logger)).NewSingleOutbound(fmt.Sprintf("127.0.0.1:%d", grpcPort))
   168  	var unaryOutbound transport.UnaryOutbound
   169  	switch transportType {
   170  	case TransportTypeTChannel:
   171  		tchannelTransport, err := tchannel.NewChannelTransport(tchannel.ServiceName(config.GetServiceName()), tchannel.Logger(logger))
   172  		if err != nil {
   173  			return nil, err
   174  		}
   175  		unaryOutbound = tchannelTransport.NewSingleOutbound(fmt.Sprintf("127.0.0.1:%d", port))
   176  	case TransportTypeHTTP:
   177  		unaryOutbound = onewayOutbound
   178  	case TransportTypeGRPC:
   179  		unaryOutbound = streamOutbound
   180  	default:
   181  		return nil, fmt.Errorf("invalid TransportType: %v", transportType)
   182  	}
   183  	return yarpc.NewDispatcher(
   184  		yarpc.Config{
   185  			Name: fmt.Sprintf("%s-client", config.GetServiceName()),
   186  			Outbounds: yarpc.Outbounds{
   187  				config.GetServiceName(): {
   188  					Oneway: onewayOutbound,
   189  					Unary:  unaryOutbound,
   190  					Stream: streamOutbound,
   191  				},
   192  			},
   193  		},
   194  	), nil
   195  }
   196  
   197  // NewServerDispatcher returns a new server Dispatcher.
   198  func NewServerDispatcher(procedures []transport.Procedure, config *DispatcherConfig, logger *zap.Logger) (*yarpc.Dispatcher, error) {
   199  	tchannelPort, err := config.GetPort(TransportTypeTChannel)
   200  	if err != nil {
   201  		return nil, err
   202  	}
   203  	httpPort, err := config.GetPort(TransportTypeHTTP)
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  	grpcPort, err := config.GetPort(TransportTypeGRPC)
   208  	if err != nil {
   209  		return nil, err
   210  	}
   211  	tchannelTransport, err := tchannel.NewChannelTransport(
   212  		tchannel.ServiceName(config.GetServiceName()),
   213  		tchannel.ListenAddr(fmt.Sprintf("127.0.0.1:%d", tchannelPort)),
   214  		tchannel.Logger(logger),
   215  	)
   216  	if err != nil {
   217  		return nil, err
   218  	}
   219  	grpcListener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", grpcPort))
   220  	if err != nil {
   221  		return nil, err
   222  	}
   223  	dispatcher := yarpc.NewDispatcher(
   224  		yarpc.Config{
   225  			Name: config.GetServiceName(),
   226  			Inbounds: yarpc.Inbounds{
   227  				tchannelTransport.NewInbound(),
   228  				http.NewTransport(http.Logger(logger)).NewInbound(fmt.Sprintf("127.0.0.1:%d", httpPort)),
   229  				grpc.NewTransport(grpc.Logger(logger)).NewInbound(grpcListener),
   230  			},
   231  		},
   232  	)
   233  	dispatcher.Register(procedures)
   234  	return dispatcher, nil
   235  }
   236  
// DispatcherConfig is the configuration for a Dispatcher.
//
// It pairs a service name with the port assigned to each TransportType
// (transportTypeToPort). The same DispatcherConfig is passed to both
// NewServerDispatcher and NewClientDispatcher so that the client's outbounds
// target the ports the server's inbounds use.
//
// NOTE: NewDispatcherConfig constructs this struct positionally, so the field
// order must not be changed without updating that literal.
type DispatcherConfig struct {
	serviceName         string
	transportTypeToPort map[TransportType]uint16
}
   242  
   243  // NewDispatcherConfig returns a new DispatcherConfig with assigned ports.
   244  func NewDispatcherConfig(serviceName string) (*DispatcherConfig, error) {
   245  	transportTypeToPort, err := getTransportTypeToPort()
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  	return &DispatcherConfig{
   250  		serviceName,
   251  		transportTypeToPort,
   252  	}, nil
   253  }
   254  
// GetServiceName gets the service name.
//
// This is the name the DispatcherConfig was created with. The server
// dispatcher uses it as its own name, and the client dispatcher uses it as the
// outbound key and, with a "-client" suffix, as its own name.
func (d *DispatcherConfig) GetServiceName() string {
	return d.serviceName
}
   259  
   260  // GetPort gets the port for the TransportType.
   261  func (d *DispatcherConfig) GetPort(transportType TransportType) (uint16, error) {
   262  	port, ok := d.transportTypeToPort[transportType]
   263  	if !ok {
   264  		return 0, fmt.Errorf("no port for TransportType %v", transportType)
   265  	}
   266  	return port, nil
   267  }
   268  
   269  func getTransportTypeToPort() (map[TransportType]uint16, error) {
   270  	m := make(map[TransportType]uint16, len(AllTransportTypes))
   271  	for _, transportType := range AllTransportTypes {
   272  		port, err := getFreePort()
   273  		if err != nil {
   274  			return nil, err
   275  		}
   276  		m[transportType] = port
   277  	}
   278  	return m, nil
   279  }
   280  
   281  func getFreePort() (uint16, error) {
   282  	address, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
   283  	if err != nil {
   284  		return 0, err
   285  	}
   286  
   287  	listener, err := net.ListenTCP("tcp", address)
   288  	if err != nil {
   289  		return 0, err
   290  	}
   291  	port := uint16(listener.Addr().(*net.TCPAddr).Port)
   292  	if err := listener.Close(); err != nil {
   293  		return 0, err
   294  	}
   295  	return port, nil
   296  }