github.com/oinume/lekcije@v0.0.0-20231017100347-5b4c5eb6ab24/backend/interface/http/oauth.go (about) 1 package http 2 3 import ( 4 "context" 5 "fmt" 6 "net/http" 7 "os" 8 "strings" 9 "time" 10 11 "github.com/morikuni/failure" 12 "go.uber.org/zap" 13 "goji.io/v3" 14 "goji.io/v3/pat" 15 "golang.org/x/oauth2" 16 "golang.org/x/oauth2/google" 17 google_auth2 "google.golang.org/api/oauth2/v2" 18 "google.golang.org/api/option" 19 20 "github.com/oinume/lekcije/backend/context_data" 21 "github.com/oinume/lekcije/backend/domain/config" 22 "github.com/oinume/lekcije/backend/errors" 23 "github.com/oinume/lekcije/backend/infrastructure/ga_measurement" 24 "github.com/oinume/lekcije/backend/model2" 25 "github.com/oinume/lekcije/backend/randoms" 26 "github.com/oinume/lekcije/backend/registration_email" 27 "github.com/oinume/lekcije/backend/usecase" 28 ) 29 30 var googleOAuthConfig = oauth2.Config{ 31 ClientID: os.Getenv("GOOGLE_CLIENT_ID"), 32 ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"), 33 Endpoint: google.Endpoint, 34 RedirectURL: "", 35 Scopes: []string{ 36 "openid email", 37 "openid profile", 38 }, 39 } 40 41 type oauthError int 42 43 const ( 44 oauthErrorUnknown oauthError = 1 + iota 45 oauthErrorAccessDenied 46 ) 47 48 func (e oauthError) Error() string { 49 switch e { 50 case oauthErrorUnknown: 51 return "oauthError: unknown" 52 case oauthErrorAccessDenied: 53 return "oauthError: access denied" 54 } 55 return fmt.Sprintf("oauthError: unknown error: %d", int(e)) 56 } 57 58 func checkState(r *http.Request) error { 59 state := r.FormValue("state") 60 oauthState, err := r.Cookie("oauthState") 61 if err != nil { 62 return failure.Wrap(err, failure.Messagef("Failed to get cookie oauthState: userAgent=%v, remoteAddr=%v", 63 r.UserAgent(), getRemoteAddress(r))) 64 } 65 if state != oauthState.Value { 66 return failure.Wrap(err, failure.Messagef("state mismatch")) 67 } 68 return nil 69 } 70 71 func exchange(r *http.Request) (*oauth2.Token, string, error) { 72 if e := r.FormValue("error"); e != "" { 73 switch e { 74 case "access_denied": 75 return nil, "", oauthErrorAccessDenied 76 default: 77 return nil, "", oauthErrorUnknown 78 } 79 } 80 code := r.FormValue("code") 81 c := getGoogleOAuthConfig(r) 82 token, err := c.Exchange(context.Background(), code) 83 if err != nil { 84 return nil, "", failure.Wrap(err, failure.Messagef("failed to exchange")) 85 } 86 idToken, ok := token.Extra("id_token").(string) 87 if !ok { 88 return nil, "", failure.New(errors.Internal, failure.Messagef("failed to get id_token")) 89 } 90 return token, idToken, nil 91 } 92 93 // Returns userId, name, email, error 94 func getGoogleUserInfo(token *oauth2.Token, idToken string) (string, string, string, error) { 95 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 96 defer cancel() 97 oauth2Client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)) 98 service, err := google_auth2.NewService( 99 context.Background(), 100 // TODO: Not sure which is correct 101 //option.WithTokenSource(oauth2.StaticTokenSource(token)), 102 option.WithHTTPClient(oauth2Client), 103 ) 104 if err != nil { 105 return "", "", "", failure.Wrap(err, failure.Messagef("failed to create oauth2.client")) 106 } 107 108 userinfo, err := service.Userinfo.V2.Me.Get().Do() 109 if err != nil { 110 return "", "", "", failure.Wrap(err, failure.Messagef("failed to get userinfo")) 111 } 112 113 return userinfo.Id, userinfo.Name, userinfo.Email, nil 114 } 115 116 func getGoogleOAuthConfig(r *http.Request) oauth2.Config { 117 c := googleOAuthConfig 118 host := r.Header.Get("X-Original-Host") // For ngrok 119 if host == "" { 120 host = r.Host 121 } 122 c.RedirectURL = fmt.Sprintf("%s://%s/oauth/google/callback", config.DefaultVars.WebURLScheme(r), host) 123 return c 124 } 125 126 type OAuthServer struct { 127 appLogger *zap.Logger 128 errorRecorder *usecase.ErrorRecorder 129 gaMeasurementClient ga_measurement.Client 130 gaMeasurementUsecase *usecase.GAMeasurement 131 senderHTTPClient *http.Client 132 userUsecase *usecase.User 133 userAPITokenUsecase *usecase.UserAPIToken 134 } 135 136 func NewOAuthServer( 137 appLogger *zap.Logger, 138 errorRecorder *usecase.ErrorRecorder, 139 gaMeasurementClient ga_measurement.Client, 140 gaMeasurementUsecase *usecase.GAMeasurement, 141 senderHTTPClient *http.Client, 142 userUsecase *usecase.User, 143 userAPITokenUsecase *usecase.UserAPIToken, 144 ) *OAuthServer { 145 return &OAuthServer{ 146 appLogger: appLogger, 147 errorRecorder: errorRecorder, 148 gaMeasurementClient: gaMeasurementClient, 149 gaMeasurementUsecase: gaMeasurementUsecase, 150 senderHTTPClient: senderHTTPClient, 151 userUsecase: userUsecase, 152 userAPITokenUsecase: userAPITokenUsecase, 153 } 154 } 155 156 func (s *OAuthServer) Setup(mux *goji.Mux) { 157 mux.HandleFunc(pat.Get("/oauth/google"), s.oauthGoogle) 158 mux.HandleFunc(pat.Get("/oauth/google/callback"), s.oauthGoogleCallback) 159 } 160 161 func (s *OAuthServer) oauthGoogle(w http.ResponseWriter, r *http.Request) { 162 state := randoms.MustNewString(32) 163 cookie := &http.Cookie{ 164 Name: "oauthState", 165 Value: state, 166 Path: "/", 167 Expires: time.Now().Add(time.Minute * 30), 168 HttpOnly: true, 169 // TODO: Secure: true 170 } 171 http.SetCookie(w, cookie) 172 c := getGoogleOAuthConfig(r) 173 http.Redirect(w, r, c.AuthCodeURL(state), http.StatusFound) 174 } 175 176 func (s *OAuthServer) oauthGoogleCallback(w http.ResponseWriter, r *http.Request) { 177 if err := checkState(r); err != nil { 178 internalServerError(r.Context(), s.errorRecorder, w, err, 0) 179 return 180 } 181 token, idToken, err := exchange(r) 182 if err != nil { 183 if err == oauthErrorAccessDenied { 184 http.Redirect(w, r, "/", http.StatusFound) 185 return 186 } 187 internalServerError(r.Context(), s.errorRecorder, w, err, 0) 188 return 189 } 190 googleID, name, email, err := getGoogleUserInfo(token, idToken) 191 if err != nil { 192 internalServerError(r.Context(), s.errorRecorder, w, err, 0) 193 return 194 } 195 196 ctx := r.Context() 197 user, err := s.userUsecase.FindByGoogleID(ctx, googleID) 198 userCreated := false 199 if err == nil { 200 go func() { 201 if err := s.gaMeasurementUsecase.SendEvent( 202 r.Context(), 203 context_data.MustGAMeasurementEvent(ctx), 204 model2.GAMeasurementEventCategoryUser, 205 "login", 206 fmt.Sprint(user.ID), 207 0, 208 uint32(user.ID), 209 ); err != nil { 210 s.errorRecorder.Record(ctx, err, fmt.Sprint(user.ID)) 211 } 212 }() 213 } else { 214 if !errors.IsNotFound(err) { 215 internalServerError(r.Context(), s.errorRecorder, w, err, 0) 216 return 217 } 218 u, _, err := s.userUsecase.CreateWithGoogle(ctx, name, email, googleID) 219 if err != nil { 220 if strings.Contains(err.Error(), "Error 1062: Duplicate entry") { 221 s.appLogger.Error("duplicate entry from CreateWithGoogle", zap.Error(err), zap.String("googleID", googleID)) 222 } 223 internalServerError(r.Context(), s.errorRecorder, w, err, 0) 224 return 225 } 226 userCreated = true 227 user = u 228 go func() { 229 if err := s.gaMeasurementUsecase.SendEvent( 230 r.Context(), 231 context_data.MustGAMeasurementEvent(ctx), 232 model2.GAMeasurementEventCategoryUser, 233 "create", 234 fmt.Sprint(user.ID), 235 0, 236 uint32(user.ID), 237 ); err != nil { 238 s.errorRecorder.Record(ctx, err, fmt.Sprint(user.ID)) 239 } 240 }() 241 } 242 243 userAPIToken, err := s.userAPITokenUsecase.Create(ctx, user.ID) 244 if err != nil { 245 internalServerError(r.Context(), s.errorRecorder, w, err, uint32(user.ID)) 246 return 247 } 248 s.appLogger.Debug(fmt.Sprintf("userCreated = %v", userCreated)) 249 250 if userCreated { 251 // TODO: Move to usecase layer 252 // Record registration email 253 go func() { 254 sender := registration_email.NewEmailSender(s.senderHTTPClient, s.appLogger) 255 if err := sender.Send(r.Context(), user); err != nil { 256 s.appLogger.Error( 257 "Failed to send registration email", 258 zap.String("email", user.Email), zap.Error(err), 259 ) 260 s.errorRecorder.Record(r.Context(), err, fmt.Sprint(user.ID)) 261 } 262 }() 263 } 264 265 cookie := &http.Cookie{ 266 Name: APITokenCookieName, 267 Value: userAPIToken.Token, 268 Path: "/", 269 Expires: time.Now().Add(model2.UserAPITokenExpiration), 270 HttpOnly: false, 271 } 272 http.SetCookie(w, cookie) 273 http.Redirect(w, r, "/me", http.StatusFound) 274 }