github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/go-grpc-middleware/validator/validator.go (about) 1 // Copyright 2016 Michal Witkowski. All Rights Reserved. 2 // See LICENSE for licensing terms. 3 4 package grpc_validator 5 6 import ( 7 "github.com/hxx258456/ccgo/grpc" 8 "github.com/hxx258456/ccgo/grpc/codes" 9 "github.com/hxx258456/ccgo/net/context" 10 ) 11 12 type validator interface { 13 Validate() error 14 } 15 16 // UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages. 17 // 18 // Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers. 19 func UnaryServerInterceptor() grpc.UnaryServerInterceptor { 20 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 21 if v, ok := req.(validator); ok { 22 if err := v.Validate(); err != nil { 23 return nil, grpc.Errorf(codes.InvalidArgument, err.Error()) 24 } 25 } 26 return handler(ctx, req) 27 } 28 } 29 30 // StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages. 31 // 32 // The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the 33 // type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace 34 // handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on 35 // calls to `stream.Recv()`. 36 func StreamServerInterceptor() grpc.StreamServerInterceptor { 37 return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 38 wrapper := &recvWrapper{stream} 39 return handler(srv, wrapper) 40 } 41 } 42 43 type recvWrapper struct { 44 grpc.ServerStream 45 } 46 47 func (s *recvWrapper) RecvMsg(m interface{}) error { 48 if err := s.ServerStream.RecvMsg(m); err != nil { 49 return err 50 } 51 if v, ok := m.(validator); ok { 52 if err := v.Validate(); err != nil { 53 return grpc.Errorf(codes.InvalidArgument, err.Error()) 54 } 55 } 56 return nil 57 }