github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/handler.go (about) 1 // Copyright 2021 Edward McFarlane. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package larking 6 7 import ( 8 "fmt" 9 "log" 10 "reflect" 11 12 "google.golang.org/grpc" 13 "google.golang.org/grpc/metadata" 14 "google.golang.org/protobuf/reflect/protoreflect" 15 ) 16 17 type handlerFunc func(*muxOptions, grpc.ServerStream) error 18 19 type handler struct { 20 method string 21 descriptor protoreflect.MethodDescriptor 22 handler handlerFunc 23 } 24 25 // TODO: use grpclog? 26 //var logger = grpclog.Component("core") 27 28 // RegisterService satisfies gprc generated service code. 29 func (m *Mux) RegisterService(sd *grpc.ServiceDesc, ss interface{}) { 30 if ss != nil { 31 ht := reflect.TypeOf(sd.HandlerType).Elem() 32 st := reflect.TypeOf(ss) 33 if !st.Implements(ht) { 34 log.Fatalf("grpc: RegisterService found the handler of type %v that does not satisfy %v", st, ht) 35 } 36 } 37 if err := m.registerService(sd, ss); err != nil { 38 log.Fatalf("larking: RegisterService error: %v", err) 39 } 40 } 41 42 func (m *Mux) registerService(gsd *grpc.ServiceDesc, ss interface{}) error { 43 44 // Load the state for writing. 45 m.mu.Lock() 46 defer m.mu.Unlock() 47 s := m.loadState().clone() 48 49 d, err := m.opts.files.FindDescriptorByName(protoreflect.FullName(gsd.ServiceName)) 50 if err != nil { 51 return err 52 } 53 sd, ok := d.(protoreflect.ServiceDescriptor) 54 if !ok { 55 return fmt.Errorf("invalid method descriptor %T", d) 56 } 57 mds := sd.Methods() 58 59 findMethod := func(methodName string) (protoreflect.MethodDescriptor, error) { 60 md := mds.ByName(protoreflect.Name(methodName)) 61 if md == nil { 62 return nil, fmt.Errorf("missing method descriptor for %v", methodName) 63 } 64 return md, nil 65 } 66 67 for i := range gsd.Methods { 68 d := &gsd.Methods[i] 69 method := "/" + gsd.ServiceName + "/" + d.MethodName 70 71 md, err := findMethod(d.MethodName) 72 if err != nil { 73 return err 74 } 75 76 rule := getExtensionHTTP(md.Options()) 77 if rule == nil { 78 continue 79 } 80 81 h := &handler{ 82 method: method, 83 descriptor: md, 84 handler: func(opts *muxOptions, stream grpc.ServerStream) error { 85 ctx := stream.Context() 86 87 // TODO: opts? 88 reply, err := d.Handler(ss, ctx, stream.RecvMsg, opts.unaryInterceptor) 89 if err != nil { 90 return err 91 } 92 return stream.SendMsg(reply) 93 }, 94 } 95 96 if err := s.appendHandler(rule, md, h); err != nil { 97 return err 98 } 99 } 100 for i := range gsd.Streams { 101 d := &gsd.Streams[i] 102 method := "/" + gsd.ServiceName + "/" + d.StreamName 103 md, err := findMethod(d.StreamName) 104 if err != nil { 105 return err 106 } 107 108 rule := getExtensionHTTP(md.Options()) 109 if rule == nil { 110 continue 111 } 112 113 h := &handler{ 114 method: method, 115 descriptor: md, 116 handler: func(opts *muxOptions, stream grpc.ServerStream) error { 117 info := &grpc.StreamServerInfo{ 118 FullMethod: method, 119 IsClientStream: d.ClientStreams, 120 IsServerStream: d.ServerStreams, 121 } 122 123 return opts.stream(ss, stream, info, d.Handler) 124 }, 125 } 126 if err := s.appendHandler(rule, md, h); err != nil { 127 return err 128 } 129 } 130 131 m.storeState(s) 132 return nil 133 } 134 135 var _ grpc.ServerTransportStream = (*serverTransportStream)(nil) 136 137 // serverTransportStream wraps gprc.SeverStream to support header/trailers. 138 type serverTransportStream struct { 139 grpc.ServerStream 140 method string 141 } 142 143 func (s *serverTransportStream) Method() string { return s.method } 144 func (s *serverTransportStream) SetHeader(md metadata.MD) error { 145 return s.ServerStream.SetHeader(md) 146 } 147 func (s *serverTransportStream) SendHeader(md metadata.MD) error { 148 return s.ServerStream.SendHeader(md) 149 } 150 func (s *serverTransportStream) SetTrailer(md metadata.MD) error { 151 s.ServerStream.SetTrailer(md) 152 return nil 153 }