github.com/quickfeed/quickfeed@v0.0.0-20240507093252-ed8ca812a09c/web/interceptor/token_auth.go (about)

     1  package interceptor
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"strings"
     7  	"sync"
     8  
     9  	"connectrpc.com/connect"
    10  	"github.com/quickfeed/quickfeed/database"
    11  	"github.com/quickfeed/quickfeed/web/auth"
    12  	"go.uber.org/zap"
    13  	"golang.org/x/oauth2"
    14  )
    15  
    16  const tokenHeader = "Authorization"
    17  
    18  type TokenAuthInterceptor struct {
    19  	tm       *auth.TokenManager
    20  	logger   *zap.SugaredLogger
    21  	db       database.Database
    22  	tokenMap map[string]string
    23  	mu       sync.Mutex
    24  }
    25  
    26  func NewTokenAuthInterceptor(logger *zap.SugaredLogger, tm *auth.TokenManager, db database.Database) *TokenAuthInterceptor {
    27  	return &TokenAuthInterceptor{
    28  		tm:       tm,
    29  		logger:   logger,
    30  		db:       db,
    31  		tokenMap: make(map[string]string),
    32  	}
    33  }
    34  
    35  func (t *TokenAuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
    36  	return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error {
    37  		token := conn.RequestHeader().Get(tokenHeader)
    38  		if len(token) == 0 {
    39  			return next(ctx, conn)
    40  		}
    41  
    42  		cookie, err := t.lookupToken(token)
    43  		if err != nil {
    44  			return err
    45  		}
    46  
    47  		conn.RequestHeader().Set(auth.Cookie, cookie)
    48  		if err = next(ctx, conn); err != nil {
    49  			return err
    50  		}
    51  		updatedCookie := conn.ResponseHeader().Get(auth.SetCookie)
    52  		if len(updatedCookie) != 0 && updatedCookie != cookie {
    53  			t.update(token, updatedCookie)
    54  		}
    55  		return nil
    56  	})
    57  }
    58  
    59  func (*TokenAuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
    60  	return connect.StreamingClientFunc(func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
    61  		return next(ctx, spec)
    62  	})
    63  }
    64  
    65  func (t *TokenAuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
    66  	return connect.UnaryFunc(func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
    67  		token := request.Header().Get(tokenHeader)
    68  		if len(token) == 0 {
    69  			return next(ctx, request)
    70  		}
    71  
    72  		cookie, err := t.lookupToken(token)
    73  		if err != nil {
    74  			return nil, err
    75  		}
    76  
    77  		request.Header().Set(auth.Cookie, cookie)
    78  		response, err := next(ctx, request)
    79  		if response != nil {
    80  			updatedCookie := response.Header().Get(auth.SetCookie)
    81  			if len(updatedCookie) != 0 && updatedCookie != cookie {
    82  				t.update(token, updatedCookie)
    83  			}
    84  		}
    85  		return response, err
    86  	})
    87  }
    88  
    89  func (t *TokenAuthInterceptor) lookup(token string) (string, bool) {
    90  	t.mu.Lock()
    91  	defer t.mu.Unlock()
    92  	cookie, exists := t.tokenMap[token]
    93  	return cookie, exists
    94  }
    95  
    96  func (t *TokenAuthInterceptor) update(token, cookie string) {
    97  	t.mu.Lock()
    98  	t.tokenMap[token] = cookie
    99  	t.mu.Unlock()
   100  }
   101  
   102  // lookupToken checks if a given token exists in the tokenMap. If it does
   103  // not, it will attempt to query GitHub for user information associated
   104  // with the token. If a user exists for the token, we verify that the user
   105  // exists in our database, and create a cookie with claims for the user.
   106  func (t *TokenAuthInterceptor) lookupToken(token string) (string, error) {
   107  	if cookie, exists := t.lookup(token); exists {
   108  		return cookie, nil
   109  	}
   110  
   111  	// Verify that token has correct prefixes before continuing
   112  	if !(strings.HasPrefix(token, "ghp_") || strings.HasPrefix(token, "github_pat_")) {
   113  		// could also pass through for next interceptor to determine if the request
   114  		// has a valid cookie
   115  		return "", connect.NewError(connect.CodeInvalidArgument, errors.New("invalid token"))
   116  	}
   117  
   118  	// Attempt to fetch user from GitHub using provided token
   119  	externalUser, err := auth.FetchExternalUser(&oauth2.Token{
   120  		AccessToken: token,
   121  	})
   122  	if err != nil {
   123  		return "", connect.NewError(connect.CodeUnauthenticated, err)
   124  	}
   125  	t.logger.Debug("Retrieved user", externalUser)
   126  	// Fetch user from database using the remote identity received from GitHub.
   127  	user, err := t.db.GetUserByRemoteIdentity(externalUser.ID)
   128  	if err != nil {
   129  		return "", connect.NewError(connect.CodeUnauthenticated, err)
   130  	}
   131  
   132  	// Create a new authentication cookie, which contains
   133  	// claims for the user associated with the token
   134  	// received in the request
   135  	cookie, err := t.tm.NewAuthCookie(user.ID)
   136  	if err != nil {
   137  		return "", connect.NewError(connect.CodeUnauthenticated, err)
   138  	}
   139  
   140  	// Store the generated cookie in our token map
   141  	cookieString := cookie.String()
   142  	t.update(token, cookieString)
   143  	return cookieString, nil
   144  }