go.uber.org/yarpc@v1.72.1/x/yarpctest/service.go (about) 1 // Copyright (c) 2022 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package yarpctest 22 23 import ( 24 "errors" 25 "fmt" 26 "net" 27 "testing" 28 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/require" 31 "github.com/uber-go/tally" 32 "go.uber.org/multierr" 33 "go.uber.org/yarpc" 34 "go.uber.org/yarpc/api/transport" 35 "go.uber.org/yarpc/transport/grpc" 36 "go.uber.org/yarpc/transport/http" 37 "go.uber.org/yarpc/transport/tchannel" 38 "go.uber.org/yarpc/x/yarpctest/api" 39 ) 40 41 // HTTPService will create a runnable HTTP service. 42 func HTTPService(options ...api.ServiceOption) api.Lifecycle { 43 return startThatCreatesStopFunc(func(t testing.TB) (stopper func(testing.TB) error, startErr error) { 44 opts := api.ServiceOpts{} 45 for _, option := range options { 46 option.ApplyService(&opts) 47 } 48 if opts.Listener != nil { 49 require.NoError(t, opts.Listener.Close()) 50 } 51 inbound := http.NewTransport().NewInbound(fmt.Sprintf("127.0.0.1:%d", opts.Port)) 52 s := createService(opts.Name, inbound, opts.Procedures, options) 53 return s.Stop, s.Start(t) 54 }) 55 } 56 57 // TChannelService will create a runnable TChannel service. 58 func TChannelService(options ...api.ServiceOption) api.Lifecycle { 59 return startThatCreatesStopFunc(func(t testing.TB) (stopper func(testing.TB) error, startErr error) { 60 opts := api.ServiceOpts{} 61 for _, option := range options { 62 option.ApplyService(&opts) 63 } 64 listener := opts.Listener 65 var err error 66 if listener == nil { 67 listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", opts.Port)) 68 require.NoError(t, err) 69 } 70 trans, err := tchannel.NewTransport( 71 tchannel.ServiceName(opts.Name), 72 tchannel.Listener(listener), 73 ) 74 require.NoError(t, err) 75 inbound := trans.NewInbound() 76 s := createService(opts.Name, inbound, opts.Procedures, options) 77 return s.Stop, s.Start(t) 78 }) 79 } 80 81 // GRPCService will create a runnable GRPC service. 82 func GRPCService(options ...api.ServiceOption) api.Lifecycle { 83 return startThatCreatesStopFunc(func(t testing.TB) (stopper func(testing.TB) error, startErr error) { 84 opts := api.ServiceOpts{} 85 for _, option := range options { 86 option.ApplyService(&opts) 87 } 88 trans := grpc.NewTransport() 89 listener := opts.Listener 90 var err error 91 if listener == nil { 92 listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", opts.Port)) 93 require.NoError(t, err) 94 } 95 inbound := trans.NewInbound(listener) 96 service := createService(opts.Name, inbound, opts.Procedures, options) 97 return service.Stop, service.Start(t) 98 }) 99 } 100 101 func startThatCreatesStopFunc(startToStop func(t testing.TB) (stopper func(testing.TB) error, startErr error)) api.Lifecycle { 102 return &startToStopper{ 103 startWithReturnedStop: startToStop, 104 } 105 } 106 107 type startToStopper struct { 108 startWithReturnedStop func(testing.TB) (stopper func(testing.TB) error, startErr error) 109 stop func(testing.TB) error 110 } 111 112 func (s *startToStopper) Start(t testing.TB) error { 113 var err error 114 s.stop, err = s.startWithReturnedStop(t) 115 return err 116 } 117 118 func (s *startToStopper) Stop(t testing.TB) error { 119 if s.stop == nil { 120 return errors.New("did not start lifecycle") 121 } 122 return s.stop(t) 123 } 124 125 func createService( 126 name string, 127 inbound transport.Inbound, 128 procedures []transport.Procedure, 129 options []api.ServiceOption, 130 ) *wrappedDispatcher { 131 d := yarpc.NewDispatcher( 132 yarpc.Config{ 133 Name: name, 134 Inbounds: yarpc.Inbounds{inbound}, 135 Metrics: yarpc.MetricsConfig{ 136 Tally: tally.NoopScope, 137 }, 138 }, 139 ) 140 d.Register(procedures) 141 return &wrappedDispatcher{ 142 Dispatcher: d, 143 options: options, 144 procedures: procedures, 145 } 146 } 147 148 type wrappedDispatcher struct { 149 *yarpc.Dispatcher 150 options []api.ServiceOption 151 procedures []transport.Procedure 152 } 153 154 func (w *wrappedDispatcher) Start(t testing.TB) error { 155 var err error 156 for _, option := range w.options { 157 err = multierr.Append(err, option.Start(t)) 158 } 159 for _, procedure := range w.procedures { 160 if unary := procedure.HandlerSpec.Unary(); unary != nil { 161 if lc, ok := unary.(api.Lifecycle); ok { 162 err = multierr.Append(err, lc.Start(t)) 163 } 164 } 165 if oneway := procedure.HandlerSpec.Oneway(); oneway != nil { 166 if lc, ok := oneway.(api.Lifecycle); ok { 167 err = multierr.Append(err, lc.Start(t)) 168 } 169 } 170 if stream := procedure.HandlerSpec.Stream(); stream != nil { 171 if lc, ok := stream.(api.Lifecycle); ok { 172 err = multierr.Append(err, lc.Start(t)) 173 } 174 } 175 } 176 err = multierr.Append(err, w.Dispatcher.Start()) 177 assert.NoError(t, err, "error starting dispatcher: %s", w.Name()) 178 return err 179 } 180 181 func (w *wrappedDispatcher) Stop(t testing.TB) error { 182 var err error 183 for _, option := range w.options { 184 err = multierr.Append(err, option.Stop(t)) 185 } 186 for _, procedure := range w.procedures { 187 if unary := procedure.HandlerSpec.Unary(); unary != nil { 188 if lc, ok := unary.(api.Lifecycle); ok { 189 err = multierr.Append(err, lc.Stop(t)) 190 } 191 } 192 if oneway := procedure.HandlerSpec.Oneway(); oneway != nil { 193 if lc, ok := oneway.(api.Lifecycle); ok { 194 err = multierr.Append(err, lc.Stop(t)) 195 } 196 } 197 if stream := procedure.HandlerSpec.Stream(); stream != nil { 198 if lc, ok := stream.(api.Lifecycle); ok { 199 err = multierr.Append(err, lc.Stop(t)) 200 } 201 } 202 } 203 err = multierr.Append(err, w.Dispatcher.Stop()) 204 assert.NoError(t, err, "error stopping dispatcher: %s", w.Name()) 205 return err 206 }