github.com/quickfeed/quickfeed@v0.0.0-20240507093252-ed8ca812a09c/web/interceptor/tokens.go (about) 1 package interceptor 2 3 import ( 4 "context" 5 "fmt" 6 "strings" 7 8 "connectrpc.com/connect" 9 "github.com/quickfeed/quickfeed/web/auth" 10 ) 11 12 type ( 13 // The userIDs interface must be implemented by request types that may need to update the tokens. 14 userIDs interface{ UserIDs() []uint64 } 15 // Marker interface to detect the GroupRequest type needed for DeleteGroup. 16 isGroup interface{ GetGroupID() uint64 } 17 ) 18 19 var defaultTokenUpdater = func(_ context.Context, tm *auth.TokenManager, msg userIDs) error { 20 for _, userID := range msg.UserIDs() { 21 if err := tm.Add(userID); err != nil { 22 return err 23 } 24 } 25 return nil 26 } 27 28 // tokenUpdateMethods is a map of methods that require updating the list of users who need a new JWT. 29 var tokenUpdateMethods = map[string]func(context.Context, *auth.TokenManager, userIDs) error{ 30 "UpdateUser": defaultTokenUpdater, // User has been promoted to admin or demoted. 31 "UpdateGroup": defaultTokenUpdater, // Users added to a group or removed from a group. 32 "UpdateEnrollments": defaultTokenUpdater, // User enrolled into a new course or promoted to TA. 33 34 "CreateCourse": // The signed in user gets the teacher role in the new course. 35 func(ctx context.Context, tm *auth.TokenManager, _ userIDs) error { 36 claims, ok := auth.ClaimsFromContext(ctx) 37 if !ok { 38 return fmt.Errorf("CreateCourse: missing claims in context") 39 } 40 return tm.Add(claims.UserID) 41 }, 42 43 "DeleteGroup": // Group members removed from the group. 44 func(ctx context.Context, tm *auth.TokenManager, msg userIDs) error { 45 if grp, ok := msg.(isGroup); ok { 46 group, err := tm.Database().GetGroup(grp.GetGroupID()) 47 if err != nil { 48 return err 49 } 50 return defaultTokenUpdater(ctx, tm, group) 51 } 52 return connect.NewError(connect.CodePermissionDenied, fmt.Errorf("cannot update token for %s: request does not contain a group", "DeleteGroup")) 53 }, 54 } 55 56 type TokenInterceptor struct { 57 tokenManager *auth.TokenManager 58 } 59 60 func NewTokenInterceptor(tm *auth.TokenManager) *TokenInterceptor { 61 return &TokenInterceptor{tokenManager: tm} 62 } 63 64 func (*TokenInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { 65 return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error { 66 return next(ctx, conn) 67 }) 68 } 69 70 func (*TokenInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { 71 return connect.StreamingClientFunc(func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { 72 return next(ctx, spec) 73 }) 74 } 75 76 // WrapUnary updates list of users who need a new JWT next time they send a request to the server. 77 // This method only logs errors to avoid overwriting the gRPC error messages returned by the server. 78 func (t *TokenInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { 79 return connect.UnaryFunc(func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { 80 procedure := request.Spec().Procedure 81 method := procedure[strings.LastIndex(procedure, "/")+1:] 82 if tokenUpdateFn, ok := tokenUpdateMethods[method]; ok { 83 if msg, ok := request.Any().(userIDs); ok { 84 if err := tokenUpdateFn(ctx, t.tokenManager, msg); err != nil { 85 return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("cannot update token for %s: %w", method, err)) 86 } 87 } else { 88 return nil, connect.NewError(connect.CodeUnimplemented, fmt.Errorf("cannot update token for %s: message type %T does not implement 'userIDs' interface", method, request)) 89 } 90 } 91 return next(ctx, request) 92 }) 93 }