github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/handler.go (about)

     1  // Copyright 2021 Edward McFarlane. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package larking
     6  
     7  import (
     8  	"fmt"
     9  	"log"
    10  	"reflect"
    11  
    12  	"google.golang.org/grpc"
    13  	"google.golang.org/grpc/metadata"
    14  	"google.golang.org/protobuf/reflect/protoreflect"
    15  )
    16  
    17  type handlerFunc func(*muxOptions, grpc.ServerStream) error
    18  
    19  type handler struct {
    20  	method     string
    21  	descriptor protoreflect.MethodDescriptor
    22  	handler    handlerFunc
    23  }
    24  
    25  // TODO: use grpclog?
    26  //var logger = grpclog.Component("core")
    27  
    28  // RegisterService satisfies gprc generated service code.
    29  func (m *Mux) RegisterService(sd *grpc.ServiceDesc, ss interface{}) {
    30  	if ss != nil {
    31  		ht := reflect.TypeOf(sd.HandlerType).Elem()
    32  		st := reflect.TypeOf(ss)
    33  		if !st.Implements(ht) {
    34  			log.Fatalf("grpc: RegisterService found the handler of type %v that does not satisfy %v", st, ht)
    35  		}
    36  	}
    37  	if err := m.registerService(sd, ss); err != nil {
    38  		log.Fatalf("larking: RegisterService error: %v", err)
    39  	}
    40  }
    41  
    42  func (m *Mux) registerService(gsd *grpc.ServiceDesc, ss interface{}) error {
    43  
    44  	// Load the state for writing.
    45  	m.mu.Lock()
    46  	defer m.mu.Unlock()
    47  	s := m.loadState().clone()
    48  
    49  	d, err := m.opts.files.FindDescriptorByName(protoreflect.FullName(gsd.ServiceName))
    50  	if err != nil {
    51  		return err
    52  	}
    53  	sd, ok := d.(protoreflect.ServiceDescriptor)
    54  	if !ok {
    55  		return fmt.Errorf("invalid method descriptor %T", d)
    56  	}
    57  	mds := sd.Methods()
    58  
    59  	findMethod := func(methodName string) (protoreflect.MethodDescriptor, error) {
    60  		md := mds.ByName(protoreflect.Name(methodName))
    61  		if md == nil {
    62  			return nil, fmt.Errorf("missing method descriptor for %v", methodName)
    63  		}
    64  		return md, nil
    65  	}
    66  
    67  	for i := range gsd.Methods {
    68  		d := &gsd.Methods[i]
    69  		method := "/" + gsd.ServiceName + "/" + d.MethodName
    70  
    71  		md, err := findMethod(d.MethodName)
    72  		if err != nil {
    73  			return err
    74  		}
    75  
    76  		rule := getExtensionHTTP(md.Options())
    77  		if rule == nil {
    78  			continue
    79  		}
    80  
    81  		h := &handler{
    82  			method:     method,
    83  			descriptor: md,
    84  			handler: func(opts *muxOptions, stream grpc.ServerStream) error {
    85  				ctx := stream.Context()
    86  
    87  				// TODO: opts?
    88  				reply, err := d.Handler(ss, ctx, stream.RecvMsg, opts.unaryInterceptor)
    89  				if err != nil {
    90  					return err
    91  				}
    92  				return stream.SendMsg(reply)
    93  			},
    94  		}
    95  
    96  		if err := s.appendHandler(rule, md, h); err != nil {
    97  			return err
    98  		}
    99  	}
   100  	for i := range gsd.Streams {
   101  		d := &gsd.Streams[i]
   102  		method := "/" + gsd.ServiceName + "/" + d.StreamName
   103  		md, err := findMethod(d.StreamName)
   104  		if err != nil {
   105  			return err
   106  		}
   107  
   108  		rule := getExtensionHTTP(md.Options())
   109  		if rule == nil {
   110  			continue
   111  		}
   112  
   113  		h := &handler{
   114  			method:     method,
   115  			descriptor: md,
   116  			handler: func(opts *muxOptions, stream grpc.ServerStream) error {
   117  				info := &grpc.StreamServerInfo{
   118  					FullMethod:     method,
   119  					IsClientStream: d.ClientStreams,
   120  					IsServerStream: d.ServerStreams,
   121  				}
   122  
   123  				return opts.stream(ss, stream, info, d.Handler)
   124  			},
   125  		}
   126  		if err := s.appendHandler(rule, md, h); err != nil {
   127  			return err
   128  		}
   129  	}
   130  
   131  	m.storeState(s)
   132  	return nil
   133  }
   134  
   135  var _ grpc.ServerTransportStream = (*serverTransportStream)(nil)
   136  
   137  // serverTransportStream wraps gprc.SeverStream to support header/trailers.
   138  type serverTransportStream struct {
   139  	grpc.ServerStream
   140  	method string
   141  }
   142  
   143  func (s *serverTransportStream) Method() string { return s.method }
   144  func (s *serverTransportStream) SetHeader(md metadata.MD) error {
   145  	return s.ServerStream.SetHeader(md)
   146  }
   147  func (s *serverTransportStream) SendHeader(md metadata.MD) error {
   148  	return s.ServerStream.SendHeader(md)
   149  }
   150  func (s *serverTransportStream) SetTrailer(md metadata.MD) error {
   151  	s.ServerStream.SetTrailer(md)
   152  	return nil
   153  }