github.com/quickfeed/quickfeed@v0.0.0-20240507093252-ed8ca812a09c/web/interceptor/req_validation.go (about)

     1  package interceptor
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"time"
     7  
     8  	"connectrpc.com/connect"
     9  	"go.uber.org/zap"
    10  )
    11  
    12  // MaxWait is the maximum time a request is allowed to stay open before aborting.
    13  const MaxWait = 2 * time.Minute
    14  
    15  // validator should be implemented by request types to validate its content.
    16  type validator interface {
    17  	IsValid() bool
    18  }
    19  
    20  // idCleaner should be implemented by response types that have a remote ID that should be removed.
    21  type idCleaner interface {
    22  	RemoveRemoteID()
    23  }
    24  
    25  type ValidationInterceptor struct {
    26  	logger *zap.SugaredLogger
    27  }
    28  
    29  func NewValidationInterceptor(logger *zap.SugaredLogger) *ValidationInterceptor {
    30  	return &ValidationInterceptor{logger: logger}
    31  }
    32  
    33  func (*ValidationInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
    34  	return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error {
    35  		return next(ctx, conn)
    36  	})
    37  }
    38  
    39  func (*ValidationInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
    40  	return connect.StreamingClientFunc(func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
    41  		return next(ctx, spec)
    42  	})
    43  }
    44  
    45  // WrapUnary returns a new unary server interceptor that validates requests
    46  // that implements the validator interface.
    47  // Invalid requests are rejected without logging and before it reaches any
    48  // user-level code and returns an illegal argument to the client.
    49  // Further, the response values are cleaned of any remote IDs.
    50  // In addition, the interceptor also implements a cancellation mechanism.
    51  func (v *ValidationInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
    52  	return connect.UnaryFunc(func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
    53  		if request.Any() != nil {
    54  			if err := validate(v.logger, request.Any()); err != nil {
    55  				// Reject the request if it is invalid.
    56  				return nil, err
    57  			}
    58  		}
    59  		resp, err := next(ctx, request)
    60  		if err != nil {
    61  			// Do not return the message to the client if an error occurs.
    62  			// We log the error and return an empty response.
    63  			v.logger.Errorf("Method '%s' failed: %v", request.Spec().Procedure, err)
    64  			v.logger.Errorf("Request Message: %T: %v", request.Any(), request.Any())
    65  			return nil, err
    66  		}
    67  		clean(resp.Any())
    68  		return resp, err
    69  	})
    70  }
    71  
    72  func validate(logger *zap.SugaredLogger, req interface{}) error {
    73  	if v, ok := req.(validator); ok {
    74  		if !v.IsValid() {
    75  			return connect.NewError(connect.CodeInvalidArgument, errors.New("invalid payload"))
    76  		}
    77  	} else {
    78  		// just logging, but still handling the call
    79  		logger.Debugf("message type %T does not implement validator interface", req)
    80  	}
    81  	return nil
    82  }
    83  
    84  func clean(resp interface{}) {
    85  	if resp != nil {
    86  		if v, ok := resp.(idCleaner); ok {
    87  			v.RemoveRemoteID()
    88  		}
    89  	}
    90  }