github.com/quickfeed/quickfeed@v0.0.0-20240507093252-ed8ca812a09c/web/interceptor/user_auth.go (about)

     1  package interceptor
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  
     8  	"connectrpc.com/connect"
     9  	"go.uber.org/zap"
    10  
    11  	"github.com/quickfeed/quickfeed/web/auth"
    12  )
    13  
// UserInterceptor is a connect interceptor that authenticates incoming
// requests using the session cookie found in the request header.
// It resolves JWT claims through its token manager and, when the token
// manager returns a refreshed cookie, adds it to the response.
type UserInterceptor struct {
	// tm extracts claims from session cookies and produces updated cookies.
	tm *auth.TokenManager
	// logger is stored by the constructor; it is not used by the methods in this file.
	logger *zap.SugaredLogger
}
    18  
    19  func NewUserInterceptor(logger *zap.SugaredLogger, tm *auth.TokenManager) *UserInterceptor {
    20  	return &UserInterceptor{
    21  		tm:     tm,
    22  		logger: logger,
    23  	}
    24  }
    25  
    26  func (u *UserInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
    27  	return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error {
    28  		claims, updatedCookie, err := u.processHeader(conn.RequestHeader())
    29  		if err != nil {
    30  			return err
    31  		}
    32  		if updatedCookie != nil {
    33  			conn.ResponseHeader().Set(auth.SetCookie, updatedCookie.String())
    34  		}
    35  		return next(claims.Context(ctx), conn)
    36  	})
    37  }
    38  
    39  func (*UserInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
    40  	return connect.StreamingClientFunc(func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
    41  		return next(ctx, spec)
    42  	})
    43  }
    44  
    45  // WrapUnary returns a unary server interceptor verifying that the user is authenticated.
    46  // The request's session cookie is verified that it contains a valid JWT claim.
    47  // If a valid claim is found, the interceptor injects the user ID as metadata in the incoming context
    48  // for service methods that come after this interceptor.
    49  // The interceptor also updates the session cookie if needed.
    50  func (u *UserInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
    51  	return connect.UnaryFunc(func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
    52  		claims, updatedCookie, err := u.processHeader(request.Header())
    53  		if err != nil {
    54  			return nil, err
    55  		}
    56  		response, err := next(claims.Context(ctx), request)
    57  		if err != nil {
    58  			return nil, err
    59  		}
    60  		if updatedCookie != nil {
    61  			response.Header().Set(auth.SetCookie, updatedCookie.String())
    62  		}
    63  		return response, nil
    64  	})
    65  }
    66  
    67  // processHeader returns claims extracted from the given http.Header's cookie
    68  // and an updated cookie if needed. An error is returned if the cookie is invalid
    69  // or could not be updated.
    70  func (u *UserInterceptor) processHeader(header http.Header) (*auth.Claims, *http.Cookie, error) {
    71  	cookie := header.Get(auth.Cookie)
    72  	claims, err := u.tm.GetClaims(cookie)
    73  	if err != nil {
    74  		return nil, nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("failed to extract JWT claims from session cookie: %w", err))
    75  	}
    76  	updatedCookie, err := u.tm.UpdateCookie(claims)
    77  	if err != nil {
    78  		return claims, nil, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("failed to update session cookie: %w", err))
    79  	}
    80  	if updatedCookie == nil {
    81  		return claims, nil, nil
    82  	}
    83  	return claims, updatedCookie, nil
    84  }