github.com/quickfeed/quickfeed@v0.0.0-20240507093252-ed8ca812a09c/web/interceptor/req_validation.go (about) 1 package interceptor 2 3 import ( 4 "context" 5 "errors" 6 "time" 7 8 "connectrpc.com/connect" 9 "go.uber.org/zap" 10 ) 11 12 // MaxWait is the maximum time a request is allowed to stay open before aborting. 13 const MaxWait = 2 * time.Minute 14 15 // validator should be implemented by request types to validate its content. 16 type validator interface { 17 IsValid() bool 18 } 19 20 // idCleaner should be implemented by response types that have a remote ID that should be removed. 21 type idCleaner interface { 22 RemoveRemoteID() 23 } 24 25 type ValidationInterceptor struct { 26 logger *zap.SugaredLogger 27 } 28 29 func NewValidationInterceptor(logger *zap.SugaredLogger) *ValidationInterceptor { 30 return &ValidationInterceptor{logger: logger} 31 } 32 33 func (*ValidationInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { 34 return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error { 35 return next(ctx, conn) 36 }) 37 } 38 39 func (*ValidationInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { 40 return connect.StreamingClientFunc(func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { 41 return next(ctx, spec) 42 }) 43 } 44 45 // WrapUnary returns a new unary server interceptor that validates requests 46 // that implements the validator interface. 47 // Invalid requests are rejected without logging and before it reaches any 48 // user-level code and returns an illegal argument to the client. 49 // Further, the response values are cleaned of any remote IDs. 50 // In addition, the interceptor also implements a cancellation mechanism. 51 func (v *ValidationInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { 52 return connect.UnaryFunc(func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { 53 if request.Any() != nil { 54 if err := validate(v.logger, request.Any()); err != nil { 55 // Reject the request if it is invalid. 56 return nil, err 57 } 58 } 59 resp, err := next(ctx, request) 60 if err != nil { 61 // Do not return the message to the client if an error occurs. 62 // We log the error and return an empty response. 63 v.logger.Errorf("Method '%s' failed: %v", request.Spec().Procedure, err) 64 v.logger.Errorf("Request Message: %T: %v", request.Any(), request.Any()) 65 return nil, err 66 } 67 clean(resp.Any()) 68 return resp, err 69 }) 70 } 71 72 func validate(logger *zap.SugaredLogger, req interface{}) error { 73 if v, ok := req.(validator); ok { 74 if !v.IsValid() { 75 return connect.NewError(connect.CodeInvalidArgument, errors.New("invalid payload")) 76 } 77 } else { 78 // just logging, but still handling the call 79 logger.Debugf("message type %T does not implement validator interface", req) 80 } 81 return nil 82 } 83 84 func clean(resp interface{}) { 85 if resp != nil { 86 if v, ok := resp.(idCleaner); ok { 87 v.RemoveRemoteID() 88 } 89 } 90 }