// Copyright (c) 2018 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

20 21 package grpc 22 23 import ( 24 "context" 25 "encoding/json" 26 "fmt" 27 "io" 28 "strings" 29 "time" 30 31 "github.com/fullstorydev/grpcurl" 32 "github.com/golang/protobuf/protoc-gen-go/descriptor" 33 "github.com/uber/prototool/internal/desc" 34 "github.com/uber/prototool/internal/extract" 35 "go.uber.org/zap" 36 "google.golang.org/grpc" 37 "google.golang.org/grpc/keepalive" 38 ) 39 40 type handler struct { 41 logger *zap.Logger 42 callTimeout time.Duration 43 connectTimeout time.Duration 44 keepaliveTime time.Duration 45 headers []string 46 47 getter extract.Getter 48 } 49 50 func newHandler(options ...HandlerOption) *handler { 51 handler := &handler{ 52 logger: zap.NewNop(), 53 } 54 for _, option := range options { 55 option(handler) 56 } 57 if handler.callTimeout == 0 { 58 handler.callTimeout = DefaultCallTimeout 59 } 60 if handler.connectTimeout == 0 { 61 handler.connectTimeout = DefaultConnectTimeout 62 } 63 // TODO(pedge): composition 64 handler.getter = extract.NewGetter( 65 extract.GetterWithLogger(handler.logger), 66 ) 67 return handler 68 } 69 70 func (h *handler) Invoke(fileDescriptorSets []*descriptor.FileDescriptorSet, address string, method string, inputReader io.Reader, outputWriter io.Writer) error { 71 descriptorSource, err := h.getDescriptorSourceForMethod(fileDescriptorSets, method) 72 if err != nil { 73 return err 74 } 75 clientConn, err := h.dial(address) 76 if err != nil { 77 return err 78 } 79 defer func() { _ = clientConn.Close() }() 80 invocationEventHandler := newInvocationEventHandler(outputWriter, h.logger) 81 ctx, cancel := context.WithTimeout(context.Background(), h.callTimeout) 82 defer cancel() 83 if err := grpcurl.InvokeRpc( 84 ctx, 85 descriptorSource, 86 clientConn, 87 method, 88 h.headers, 89 invocationEventHandler, 90 decodeFunc(inputReader), 91 ); err != nil { 92 return err 93 } 94 return invocationEventHandler.Err() 95 } 96 97 func (h *handler) dial(address string) (*grpc.ClientConn, error) { 98 ctx, cancel := 
context.WithTimeout(context.Background(), h.connectTimeout) 99 defer cancel() 100 return grpcurl.BlockingDial(ctx, "tcp", address, nil, h.getDialOptions()...) 101 } 102 103 func (h *handler) getDialOptions() []grpc.DialOption { 104 var dialOptions []grpc.DialOption 105 if h.keepaliveTime != 0 { 106 dialOptions = append( 107 dialOptions, 108 grpc.WithKeepaliveParams( 109 keepalive.ClientParameters{ 110 Time: h.keepaliveTime, 111 Timeout: h.keepaliveTime, 112 }, 113 ), 114 ) 115 } 116 return dialOptions 117 } 118 119 func (h *handler) getDescriptorSourceForMethod(fileDescriptorSets []*descriptor.FileDescriptorSet, method string) (grpcurl.DescriptorSource, error) { 120 servicePath, err := getServiceForMethod(method) 121 if err != nil { 122 return nil, err 123 } 124 service, err := h.getter.GetService(fileDescriptorSets, servicePath) 125 if err != nil { 126 return nil, err 127 } 128 fileDescriptorSet, err := desc.SortFileDescriptorSet(service.FileDescriptorSet, service.FileDescriptorProto) 129 if err != nil { 130 return nil, err 131 } 132 return grpcurl.DescriptorSourceFromFileDescriptorSet(fileDescriptorSet) 133 } 134 135 func getServiceForMethod(method string) (string, error) { 136 split := strings.Split(method, "/") 137 if len(split) != 2 { 138 return "", fmt.Errorf("invalid gRPC method: %s", method) 139 } 140 return split[0], nil 141 } 142 143 func decodeFunc(reader io.Reader) func() ([]byte, error) { 144 decoder := json.NewDecoder(reader) 145 return func() ([]byte, error) { 146 var rawMessage json.RawMessage 147 if err := decoder.Decode(&rawMessage); err != nil { 148 return nil, err 149 } 150 return rawMessage, nil 151 } 152 }