code.gitea.io/gitea@v1.22.3/services/auth/source/oauth2/init.go (about) 1 // Copyright 2021 The Gitea Authors. All rights reserved. 2 // SPDX-License-Identifier: MIT 3 4 package oauth2 5 6 import ( 7 "context" 8 "encoding/gob" 9 "net/http" 10 "sync" 11 12 "code.gitea.io/gitea/models/auth" 13 "code.gitea.io/gitea/models/db" 14 "code.gitea.io/gitea/modules/log" 15 "code.gitea.io/gitea/modules/optional" 16 "code.gitea.io/gitea/modules/setting" 17 18 "github.com/google/uuid" 19 "github.com/gorilla/sessions" 20 "github.com/markbates/goth/gothic" 21 ) 22 23 var gothRWMutex = sync.RWMutex{} 24 25 // UsersStoreKey is the key for the store 26 const UsersStoreKey = "gitea-oauth2-sessions" 27 28 // ProviderHeaderKey is the HTTP header key 29 const ProviderHeaderKey = "gitea-oauth2-provider" 30 31 // Init initializes the oauth source 32 func Init(ctx context.Context) error { 33 // this is for oauth2 provider 34 if setting.OAuth2.Enabled { 35 if err := InitSigningKey(); err != nil { 36 return err 37 } 38 } 39 40 // others for oauth2 clients 41 // Lock our mutex 42 gothRWMutex.Lock() 43 44 gob.Register(&sessions.Session{}) 45 46 gothic.Store = &SessionsStore{ 47 maxLength: int64(setting.OAuth2.MaxTokenLength), 48 } 49 50 gothic.SetState = func(req *http.Request) string { 51 return uuid.New().String() 52 } 53 54 gothic.GetProviderName = func(req *http.Request) (string, error) { 55 return req.Header.Get(ProviderHeaderKey), nil 56 } 57 58 // Unlock our mutex 59 gothRWMutex.Unlock() 60 61 return initOAuth2Sources(ctx) 62 } 63 64 // ResetOAuth2 clears existing OAuth2 providers and loads them from DB 65 func ResetOAuth2(ctx context.Context) error { 66 ClearProviders() 67 return initOAuth2Sources(ctx) 68 } 69 70 // initOAuth2Sources is used to load and register all active OAuth2 providers 71 func initOAuth2Sources(ctx context.Context) error { 72 authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{ 73 IsActive: optional.Some(true), 74 LoginType: auth.OAuth2, 75 }) 76 if err != nil { 77 return err 78 } 79 for _, source := range authSources { 80 oauth2Source, ok := source.Cfg.(*Source) 81 if !ok { 82 continue 83 } 84 err := oauth2Source.RegisterSource() 85 if err != nil { 86 log.Critical("Unable to register source: %s due to Error: %v.", source.Name, err) 87 } 88 } 89 return nil 90 }