go.uber.org/yarpc@v1.72.1/x/yarpctest/service.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package yarpctest
    22  
    23  import (
    24  	"errors"
    25  	"fmt"
    26  	"net"
    27  	"testing"
    28  
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  	"github.com/uber-go/tally"
    32  	"go.uber.org/multierr"
    33  	"go.uber.org/yarpc"
    34  	"go.uber.org/yarpc/api/transport"
    35  	"go.uber.org/yarpc/transport/grpc"
    36  	"go.uber.org/yarpc/transport/http"
    37  	"go.uber.org/yarpc/transport/tchannel"
    38  	"go.uber.org/yarpc/x/yarpctest/api"
    39  )
    40  
    41  // HTTPService will create a runnable HTTP service.
    42  func HTTPService(options ...api.ServiceOption) api.Lifecycle {
    43  	return startThatCreatesStopFunc(func(t testing.TB) (stopper func(testing.TB) error, startErr error) {
    44  		opts := api.ServiceOpts{}
    45  		for _, option := range options {
    46  			option.ApplyService(&opts)
    47  		}
    48  		if opts.Listener != nil {
    49  			require.NoError(t, opts.Listener.Close())
    50  		}
    51  		inbound := http.NewTransport().NewInbound(fmt.Sprintf("127.0.0.1:%d", opts.Port))
    52  		s := createService(opts.Name, inbound, opts.Procedures, options)
    53  		return s.Stop, s.Start(t)
    54  	})
    55  }
    56  
    57  // TChannelService will create a runnable TChannel service.
    58  func TChannelService(options ...api.ServiceOption) api.Lifecycle {
    59  	return startThatCreatesStopFunc(func(t testing.TB) (stopper func(testing.TB) error, startErr error) {
    60  		opts := api.ServiceOpts{}
    61  		for _, option := range options {
    62  			option.ApplyService(&opts)
    63  		}
    64  		listener := opts.Listener
    65  		var err error
    66  		if listener == nil {
    67  			listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", opts.Port))
    68  			require.NoError(t, err)
    69  		}
    70  		trans, err := tchannel.NewTransport(
    71  			tchannel.ServiceName(opts.Name),
    72  			tchannel.Listener(listener),
    73  		)
    74  		require.NoError(t, err)
    75  		inbound := trans.NewInbound()
    76  		s := createService(opts.Name, inbound, opts.Procedures, options)
    77  		return s.Stop, s.Start(t)
    78  	})
    79  }
    80  
    81  // GRPCService will create a runnable GRPC service.
    82  func GRPCService(options ...api.ServiceOption) api.Lifecycle {
    83  	return startThatCreatesStopFunc(func(t testing.TB) (stopper func(testing.TB) error, startErr error) {
    84  		opts := api.ServiceOpts{}
    85  		for _, option := range options {
    86  			option.ApplyService(&opts)
    87  		}
    88  		trans := grpc.NewTransport()
    89  		listener := opts.Listener
    90  		var err error
    91  		if listener == nil {
    92  			listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", opts.Port))
    93  			require.NoError(t, err)
    94  		}
    95  		inbound := trans.NewInbound(listener)
    96  		service := createService(opts.Name, inbound, opts.Procedures, options)
    97  		return service.Stop, service.Start(t)
    98  	})
    99  }
   100  
   101  func startThatCreatesStopFunc(startToStop func(t testing.TB) (stopper func(testing.TB) error, startErr error)) api.Lifecycle {
   102  	return &startToStopper{
   103  		startWithReturnedStop: startToStop,
   104  	}
   105  }
   106  
   107  type startToStopper struct {
   108  	startWithReturnedStop func(testing.TB) (stopper func(testing.TB) error, startErr error)
   109  	stop                  func(testing.TB) error
   110  }
   111  
   112  func (s *startToStopper) Start(t testing.TB) error {
   113  	var err error
   114  	s.stop, err = s.startWithReturnedStop(t)
   115  	return err
   116  }
   117  
   118  func (s *startToStopper) Stop(t testing.TB) error {
   119  	if s.stop == nil {
   120  		return errors.New("did not start lifecycle")
   121  	}
   122  	return s.stop(t)
   123  }
   124  
   125  func createService(
   126  	name string,
   127  	inbound transport.Inbound,
   128  	procedures []transport.Procedure,
   129  	options []api.ServiceOption,
   130  ) *wrappedDispatcher {
   131  	d := yarpc.NewDispatcher(
   132  		yarpc.Config{
   133  			Name:     name,
   134  			Inbounds: yarpc.Inbounds{inbound},
   135  			Metrics: yarpc.MetricsConfig{
   136  				Tally: tally.NoopScope,
   137  			},
   138  		},
   139  	)
   140  	d.Register(procedures)
   141  	return &wrappedDispatcher{
   142  		Dispatcher: d,
   143  		options:    options,
   144  		procedures: procedures,
   145  	}
   146  }
   147  
   148  type wrappedDispatcher struct {
   149  	*yarpc.Dispatcher
   150  	options    []api.ServiceOption
   151  	procedures []transport.Procedure
   152  }
   153  
   154  func (w *wrappedDispatcher) Start(t testing.TB) error {
   155  	var err error
   156  	for _, option := range w.options {
   157  		err = multierr.Append(err, option.Start(t))
   158  	}
   159  	for _, procedure := range w.procedures {
   160  		if unary := procedure.HandlerSpec.Unary(); unary != nil {
   161  			if lc, ok := unary.(api.Lifecycle); ok {
   162  				err = multierr.Append(err, lc.Start(t))
   163  			}
   164  		}
   165  		if oneway := procedure.HandlerSpec.Oneway(); oneway != nil {
   166  			if lc, ok := oneway.(api.Lifecycle); ok {
   167  				err = multierr.Append(err, lc.Start(t))
   168  			}
   169  		}
   170  		if stream := procedure.HandlerSpec.Stream(); stream != nil {
   171  			if lc, ok := stream.(api.Lifecycle); ok {
   172  				err = multierr.Append(err, lc.Start(t))
   173  			}
   174  		}
   175  	}
   176  	err = multierr.Append(err, w.Dispatcher.Start())
   177  	assert.NoError(t, err, "error starting dispatcher: %s", w.Name())
   178  	return err
   179  }
   180  
   181  func (w *wrappedDispatcher) Stop(t testing.TB) error {
   182  	var err error
   183  	for _, option := range w.options {
   184  		err = multierr.Append(err, option.Stop(t))
   185  	}
   186  	for _, procedure := range w.procedures {
   187  		if unary := procedure.HandlerSpec.Unary(); unary != nil {
   188  			if lc, ok := unary.(api.Lifecycle); ok {
   189  				err = multierr.Append(err, lc.Stop(t))
   190  			}
   191  		}
   192  		if oneway := procedure.HandlerSpec.Oneway(); oneway != nil {
   193  			if lc, ok := oneway.(api.Lifecycle); ok {
   194  				err = multierr.Append(err, lc.Stop(t))
   195  			}
   196  		}
   197  		if stream := procedure.HandlerSpec.Stream(); stream != nil {
   198  			if lc, ok := stream.(api.Lifecycle); ok {
   199  				err = multierr.Append(err, lc.Stop(t))
   200  			}
   201  		}
   202  	}
   203  	err = multierr.Append(err, w.Dispatcher.Stop())
   204  	assert.NoError(t, err, "error stopping dispatcher: %s", w.Name())
   205  	return err
   206  }