github.com/quickfeed/quickfeed@v0.0.0-20240507093252-ed8ca812a09c/web/interceptor/tokens.go (about)

     1  package interceptor
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  
     8  	"connectrpc.com/connect"
     9  	"github.com/quickfeed/quickfeed/web/auth"
    10  )
    11  
    12  type (
    13  	// The userIDs interface must be implemented by request types that may need to update the tokens.
    14  	userIDs interface{ UserIDs() []uint64 }
    15  	// Marker interface to detect the GroupRequest type needed for DeleteGroup.
    16  	isGroup interface{ GetGroupID() uint64 }
    17  )
    18  
    19  var defaultTokenUpdater = func(_ context.Context, tm *auth.TokenManager, msg userIDs) error {
    20  	for _, userID := range msg.UserIDs() {
    21  		if err := tm.Add(userID); err != nil {
    22  			return err
    23  		}
    24  	}
    25  	return nil
    26  }
    27  
    28  // tokenUpdateMethods is a map of methods that require updating the list of users who need a new JWT.
    29  var tokenUpdateMethods = map[string]func(context.Context, *auth.TokenManager, userIDs) error{
    30  	"UpdateUser":        defaultTokenUpdater, // User has been promoted to admin or demoted.
    31  	"UpdateGroup":       defaultTokenUpdater, // Users added to a group or removed from a group.
    32  	"UpdateEnrollments": defaultTokenUpdater, // User enrolled into a new course or promoted to TA.
    33  
    34  	"CreateCourse": // The signed in user gets the teacher role in the new course.
    35  	func(ctx context.Context, tm *auth.TokenManager, _ userIDs) error {
    36  		claims, ok := auth.ClaimsFromContext(ctx)
    37  		if !ok {
    38  			return fmt.Errorf("CreateCourse: missing claims in context")
    39  		}
    40  		return tm.Add(claims.UserID)
    41  	},
    42  
    43  	"DeleteGroup": // Group members removed from the group.
    44  	func(ctx context.Context, tm *auth.TokenManager, msg userIDs) error {
    45  		if grp, ok := msg.(isGroup); ok {
    46  			group, err := tm.Database().GetGroup(grp.GetGroupID())
    47  			if err != nil {
    48  				return err
    49  			}
    50  			return defaultTokenUpdater(ctx, tm, group)
    51  		}
    52  		return connect.NewError(connect.CodePermissionDenied, fmt.Errorf("cannot update token for %s: request does not contain a group", "DeleteGroup"))
    53  	},
    54  }
    55  
    56  type TokenInterceptor struct {
    57  	tokenManager *auth.TokenManager
    58  }
    59  
    60  func NewTokenInterceptor(tm *auth.TokenManager) *TokenInterceptor {
    61  	return &TokenInterceptor{tokenManager: tm}
    62  }
    63  
    64  func (*TokenInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
    65  	return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error {
    66  		return next(ctx, conn)
    67  	})
    68  }
    69  
    70  func (*TokenInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
    71  	return connect.StreamingClientFunc(func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
    72  		return next(ctx, spec)
    73  	})
    74  }
    75  
    76  // WrapUnary updates list of users who need a new JWT next time they send a request to the server.
    77  // This method only logs errors to avoid overwriting the gRPC error messages returned by the server.
    78  func (t *TokenInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
    79  	return connect.UnaryFunc(func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
    80  		procedure := request.Spec().Procedure
    81  		method := procedure[strings.LastIndex(procedure, "/")+1:]
    82  		if tokenUpdateFn, ok := tokenUpdateMethods[method]; ok {
    83  			if msg, ok := request.Any().(userIDs); ok {
    84  				if err := tokenUpdateFn(ctx, t.tokenManager, msg); err != nil {
    85  					return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("cannot update token for %s: %w", method, err))
    86  				}
    87  			} else {
    88  				return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("cannot update token for %s: message type %T does not implement 'userIDs' interface", method, request))
    89  			}
    90  		}
    91  		return next(ctx, request)
    92  	})
    93  }