github.com/oinume/lekcije@v0.0.0-20231017100347-5b4c5eb6ab24/backend/interface/http/oauth.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"os"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/morikuni/failure"
    12  	"go.uber.org/zap"
    13  	"goji.io/v3"
    14  	"goji.io/v3/pat"
    15  	"golang.org/x/oauth2"
    16  	"golang.org/x/oauth2/google"
    17  	google_auth2 "google.golang.org/api/oauth2/v2"
    18  	"google.golang.org/api/option"
    19  
    20  	"github.com/oinume/lekcije/backend/context_data"
    21  	"github.com/oinume/lekcije/backend/domain/config"
    22  	"github.com/oinume/lekcije/backend/errors"
    23  	"github.com/oinume/lekcije/backend/infrastructure/ga_measurement"
    24  	"github.com/oinume/lekcije/backend/model2"
    25  	"github.com/oinume/lekcije/backend/randoms"
    26  	"github.com/oinume/lekcije/backend/registration_email"
    27  	"github.com/oinume/lekcije/backend/usecase"
    28  )
    29  
    30  var googleOAuthConfig = oauth2.Config{
    31  	ClientID:     os.Getenv("GOOGLE_CLIENT_ID"),
    32  	ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
    33  	Endpoint:     google.Endpoint,
    34  	RedirectURL:  "",
    35  	Scopes: []string{
    36  		"openid email",
    37  		"openid profile",
    38  	},
    39  }
    40  
    41  type oauthError int
    42  
    43  const (
    44  	oauthErrorUnknown oauthError = 1 + iota
    45  	oauthErrorAccessDenied
    46  )
    47  
    48  func (e oauthError) Error() string {
    49  	switch e {
    50  	case oauthErrorUnknown:
    51  		return "oauthError: unknown"
    52  	case oauthErrorAccessDenied:
    53  		return "oauthError: access denied"
    54  	}
    55  	return fmt.Sprintf("oauthError: unknown error: %d", int(e))
    56  }
    57  
    58  func checkState(r *http.Request) error {
    59  	state := r.FormValue("state")
    60  	oauthState, err := r.Cookie("oauthState")
    61  	if err != nil {
    62  		return failure.Wrap(err, failure.Messagef("Failed to get cookie oauthState: userAgent=%v, remoteAddr=%v",
    63  			r.UserAgent(), getRemoteAddress(r)))
    64  	}
    65  	if state != oauthState.Value {
    66  		return failure.Wrap(err, failure.Messagef("state mismatch"))
    67  	}
    68  	return nil
    69  }
    70  
    71  func exchange(r *http.Request) (*oauth2.Token, string, error) {
    72  	if e := r.FormValue("error"); e != "" {
    73  		switch e {
    74  		case "access_denied":
    75  			return nil, "", oauthErrorAccessDenied
    76  		default:
    77  			return nil, "", oauthErrorUnknown
    78  		}
    79  	}
    80  	code := r.FormValue("code")
    81  	c := getGoogleOAuthConfig(r)
    82  	token, err := c.Exchange(context.Background(), code)
    83  	if err != nil {
    84  		return nil, "", failure.Wrap(err, failure.Messagef("failed to exchange"))
    85  	}
    86  	idToken, ok := token.Extra("id_token").(string)
    87  	if !ok {
    88  		return nil, "", failure.New(errors.Internal, failure.Messagef("failed to get id_token"))
    89  	}
    90  	return token, idToken, nil
    91  }
    92  
    93  // Returns userId, name, email, error
    94  func getGoogleUserInfo(token *oauth2.Token, idToken string) (string, string, string, error) {
    95  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    96  	defer cancel()
    97  	oauth2Client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
    98  	service, err := google_auth2.NewService(
    99  		context.Background(),
   100  		// TODO: Not sure which is correct
   101  		//option.WithTokenSource(oauth2.StaticTokenSource(token)),
   102  		option.WithHTTPClient(oauth2Client),
   103  	)
   104  	if err != nil {
   105  		return "", "", "", failure.Wrap(err, failure.Messagef("failed to create oauth2.client"))
   106  	}
   107  
   108  	userinfo, err := service.Userinfo.V2.Me.Get().Do()
   109  	if err != nil {
   110  		return "", "", "", failure.Wrap(err, failure.Messagef("failed to get userinfo"))
   111  	}
   112  
   113  	return userinfo.Id, userinfo.Name, userinfo.Email, nil
   114  }
   115  
   116  func getGoogleOAuthConfig(r *http.Request) oauth2.Config {
   117  	c := googleOAuthConfig
   118  	host := r.Header.Get("X-Original-Host") // For ngrok
   119  	if host == "" {
   120  		host = r.Host
   121  	}
   122  	c.RedirectURL = fmt.Sprintf("%s://%s/oauth/google/callback", config.DefaultVars.WebURLScheme(r), host)
   123  	return c
   124  }
   125  
// OAuthServer serves the Google OAuth sign-in endpoints that Setup registers.
//
// Fields, as used by the handlers in this file:
//   - appLogger: debug and error logging in the callback.
//   - errorRecorder: receives errors from failed requests and background work.
//   - gaMeasurementClient: not referenced by the handlers shown here.
//   - gaMeasurementUsecase: sends the "login"/"create" GA user events.
//   - senderHTTPClient: HTTP client handed to the registration email sender.
//   - userUsecase: looks up users by Google ID and creates new ones.
//   - userAPITokenUsecase: issues the API token set as a cookie after sign-in.
type OAuthServer struct {
	appLogger            *zap.Logger
	errorRecorder        *usecase.ErrorRecorder
	gaMeasurementClient  ga_measurement.Client
	gaMeasurementUsecase *usecase.GAMeasurement
	senderHTTPClient     *http.Client
	userUsecase          *usecase.User
	userAPITokenUsecase  *usecase.UserAPIToken
}
   135  
   136  func NewOAuthServer(
   137  	appLogger *zap.Logger,
   138  	errorRecorder *usecase.ErrorRecorder,
   139  	gaMeasurementClient ga_measurement.Client,
   140  	gaMeasurementUsecase *usecase.GAMeasurement,
   141  	senderHTTPClient *http.Client,
   142  	userUsecase *usecase.User,
   143  	userAPITokenUsecase *usecase.UserAPIToken,
   144  ) *OAuthServer {
   145  	return &OAuthServer{
   146  		appLogger:            appLogger,
   147  		errorRecorder:        errorRecorder,
   148  		gaMeasurementClient:  gaMeasurementClient,
   149  		gaMeasurementUsecase: gaMeasurementUsecase,
   150  		senderHTTPClient:     senderHTTPClient,
   151  		userUsecase:          userUsecase,
   152  		userAPITokenUsecase:  userAPITokenUsecase,
   153  	}
   154  }
   155  
   156  func (s *OAuthServer) Setup(mux *goji.Mux) {
   157  	mux.HandleFunc(pat.Get("/oauth/google"), s.oauthGoogle)
   158  	mux.HandleFunc(pat.Get("/oauth/google/callback"), s.oauthGoogleCallback)
   159  }
   160  
   161  func (s *OAuthServer) oauthGoogle(w http.ResponseWriter, r *http.Request) {
   162  	state := randoms.MustNewString(32)
   163  	cookie := &http.Cookie{
   164  		Name:     "oauthState",
   165  		Value:    state,
   166  		Path:     "/",
   167  		Expires:  time.Now().Add(time.Minute * 30),
   168  		HttpOnly: true,
   169  		// TODO: Secure: true
   170  	}
   171  	http.SetCookie(w, cookie)
   172  	c := getGoogleOAuthConfig(r)
   173  	http.Redirect(w, r, c.AuthCodeURL(state), http.StatusFound)
   174  }
   175  
   176  func (s *OAuthServer) oauthGoogleCallback(w http.ResponseWriter, r *http.Request) {
   177  	if err := checkState(r); err != nil {
   178  		internalServerError(r.Context(), s.errorRecorder, w, err, 0)
   179  		return
   180  	}
   181  	token, idToken, err := exchange(r)
   182  	if err != nil {
   183  		if err == oauthErrorAccessDenied {
   184  			http.Redirect(w, r, "/", http.StatusFound)
   185  			return
   186  		}
   187  		internalServerError(r.Context(), s.errorRecorder, w, err, 0)
   188  		return
   189  	}
   190  	googleID, name, email, err := getGoogleUserInfo(token, idToken)
   191  	if err != nil {
   192  		internalServerError(r.Context(), s.errorRecorder, w, err, 0)
   193  		return
   194  	}
   195  
   196  	ctx := r.Context()
   197  	user, err := s.userUsecase.FindByGoogleID(ctx, googleID)
   198  	userCreated := false
   199  	if err == nil {
   200  		go func() {
   201  			if err := s.gaMeasurementUsecase.SendEvent(
   202  				r.Context(),
   203  				context_data.MustGAMeasurementEvent(ctx),
   204  				model2.GAMeasurementEventCategoryUser,
   205  				"login",
   206  				fmt.Sprint(user.ID),
   207  				0,
   208  				uint32(user.ID),
   209  			); err != nil {
   210  				s.errorRecorder.Record(ctx, err, fmt.Sprint(user.ID))
   211  			}
   212  		}()
   213  	} else {
   214  		if !errors.IsNotFound(err) {
   215  			internalServerError(r.Context(), s.errorRecorder, w, err, 0)
   216  			return
   217  		}
   218  		u, _, err := s.userUsecase.CreateWithGoogle(ctx, name, email, googleID)
   219  		if err != nil {
   220  			if strings.Contains(err.Error(), "Error 1062: Duplicate entry") {
   221  				s.appLogger.Error("duplicate entry from CreateWithGoogle", zap.Error(err), zap.String("googleID", googleID))
   222  			}
   223  			internalServerError(r.Context(), s.errorRecorder, w, err, 0)
   224  			return
   225  		}
   226  		userCreated = true
   227  		user = u
   228  		go func() {
   229  			if err := s.gaMeasurementUsecase.SendEvent(
   230  				r.Context(),
   231  				context_data.MustGAMeasurementEvent(ctx),
   232  				model2.GAMeasurementEventCategoryUser,
   233  				"create",
   234  				fmt.Sprint(user.ID),
   235  				0,
   236  				uint32(user.ID),
   237  			); err != nil {
   238  				s.errorRecorder.Record(ctx, err, fmt.Sprint(user.ID))
   239  			}
   240  		}()
   241  	}
   242  
   243  	userAPIToken, err := s.userAPITokenUsecase.Create(ctx, user.ID)
   244  	if err != nil {
   245  		internalServerError(r.Context(), s.errorRecorder, w, err, uint32(user.ID))
   246  		return
   247  	}
   248  	s.appLogger.Debug(fmt.Sprintf("userCreated = %v", userCreated))
   249  
   250  	if userCreated {
   251  		// TODO: Move to usecase layer
   252  		// Record registration email
   253  		go func() {
   254  			sender := registration_email.NewEmailSender(s.senderHTTPClient, s.appLogger)
   255  			if err := sender.Send(r.Context(), user); err != nil {
   256  				s.appLogger.Error(
   257  					"Failed to send registration email",
   258  					zap.String("email", user.Email), zap.Error(err),
   259  				)
   260  				s.errorRecorder.Record(r.Context(), err, fmt.Sprint(user.ID))
   261  			}
   262  		}()
   263  	}
   264  
   265  	cookie := &http.Cookie{
   266  		Name:     APITokenCookieName,
   267  		Value:    userAPIToken.Token,
   268  		Path:     "/",
   269  		Expires:  time.Now().Add(model2.UserAPITokenExpiration),
   270  		HttpOnly: false,
   271  	}
   272  	http.SetCookie(w, cookie)
   273  	http.Redirect(w, r, "/me", http.StatusFound)
   274  }