github.com/quickfeed/quickfeed@v0.0.0-20240507093252-ed8ca812a09c/web/interceptor/token_auth.go (about) 1 package interceptor 2 3 import ( 4 "context" 5 "errors" 6 "strings" 7 "sync" 8 9 "connectrpc.com/connect" 10 "github.com/quickfeed/quickfeed/database" 11 "github.com/quickfeed/quickfeed/web/auth" 12 "go.uber.org/zap" 13 "golang.org/x/oauth2" 14 ) 15 16 const tokenHeader = "Authorization" 17 18 type TokenAuthInterceptor struct { 19 tm *auth.TokenManager 20 logger *zap.SugaredLogger 21 db database.Database 22 tokenMap map[string]string 23 mu sync.Mutex 24 } 25 26 func NewTokenAuthInterceptor(logger *zap.SugaredLogger, tm *auth.TokenManager, db database.Database) *TokenAuthInterceptor { 27 return &TokenAuthInterceptor{ 28 tm: tm, 29 logger: logger, 30 db: db, 31 tokenMap: make(map[string]string), 32 } 33 } 34 35 func (t *TokenAuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { 36 return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error { 37 token := conn.RequestHeader().Get(tokenHeader) 38 if len(token) == 0 { 39 return next(ctx, conn) 40 } 41 42 cookie, err := t.lookupToken(token) 43 if err != nil { 44 return err 45 } 46 47 conn.RequestHeader().Set(auth.Cookie, cookie) 48 if err = next(ctx, conn); err != nil { 49 return err 50 } 51 updatedCookie := conn.ResponseHeader().Get(auth.SetCookie) 52 if len(updatedCookie) != 0 && updatedCookie != cookie { 53 t.update(token, updatedCookie) 54 } 55 return nil 56 }) 57 } 58 59 func (*TokenAuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { 60 return connect.StreamingClientFunc(func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { 61 return next(ctx, spec) 62 }) 63 } 64 65 func (t *TokenAuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { 66 return connect.UnaryFunc(func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { 67 token := request.Header().Get(tokenHeader) 68 if len(token) == 0 { 69 return next(ctx, request) 70 } 71 72 cookie, err := t.lookupToken(token) 73 if err != nil { 74 return nil, err 75 } 76 77 request.Header().Set(auth.Cookie, cookie) 78 response, err := next(ctx, request) 79 if response != nil { 80 updatedCookie := response.Header().Get(auth.SetCookie) 81 if len(updatedCookie) != 0 && updatedCookie != cookie { 82 t.update(token, updatedCookie) 83 } 84 } 85 return response, err 86 }) 87 } 88 89 func (t *TokenAuthInterceptor) lookup(token string) (string, bool) { 90 t.mu.Lock() 91 defer t.mu.Unlock() 92 cookie, exists := t.tokenMap[token] 93 return cookie, exists 94 } 95 96 func (t *TokenAuthInterceptor) update(token, cookie string) { 97 t.mu.Lock() 98 t.tokenMap[token] = cookie 99 t.mu.Unlock() 100 } 101 102 // lookupToken checks if a given token exists in the tokenMap. If it does 103 // not, it will attempt to query GitHub for user information associated 104 // with the token. If a user exists for the token, we verify that the user 105 // exists in our database, and create a cookie with claims for the user. 106 func (t *TokenAuthInterceptor) lookupToken(token string) (string, error) { 107 if cookie, exists := t.lookup(token); exists { 108 return cookie, nil 109 } 110 111 // Verify that token has correct prefixes before continuing 112 if !(strings.HasPrefix(token, "ghp_") || strings.HasPrefix(token, "github_pat_")) { 113 // could also pass through for next interceptor to determine if the request 114 // has a valid cookie 115 return "", connect.NewError(connect.CodeInvalidArgument, errors.New("invalid token")) 116 } 117 118 // Attempt to fetch user from GitHub using provided token 119 externalUser, err := auth.FetchExternalUser(&oauth2.Token{ 120 AccessToken: token, 121 }) 122 if err != nil { 123 return "", connect.NewError(connect.CodeUnauthenticated, err) 124 } 125 t.logger.Debug("Retrieved user", externalUser) 126 // Fetch user from database using the remote identity received from GitHub. 127 user, err := t.db.GetUserByRemoteIdentity(externalUser.ID) 128 if err != nil { 129 return "", connect.NewError(connect.CodeUnauthenticated, err) 130 } 131 132 // Create a new authentication cookie, which contains 133 // claims for the user associated with the token 134 // received in the request 135 cookie, err := t.tm.NewAuthCookie(user.ID) 136 if err != nil { 137 return "", connect.NewError(connect.CodeUnauthenticated, err) 138 } 139 140 // Store the generated cookie in our token map 141 cookieString := cookie.String() 142 t.update(token, cookieString) 143 return cookieString, nil 144 }