code.gitea.io/gitea@v1.21.7/services/auth/source/oauth2/init.go (about)

     1  // Copyright 2021 The Gitea Authors. All rights reserved.
     2  // SPDX-License-Identifier: MIT
     3  
     4  package oauth2
     5  
     6  import (
     7  	"encoding/gob"
     8  	"net/http"
     9  	"sync"
    10  
    11  	"code.gitea.io/gitea/models/auth"
    12  	"code.gitea.io/gitea/modules/log"
    13  	"code.gitea.io/gitea/modules/setting"
    14  
    15  	"github.com/google/uuid"
    16  	"github.com/gorilla/sessions"
    17  	"github.com/markbates/goth/gothic"
    18  )
    19  
    20  var gothRWMutex = sync.RWMutex{}
    21  
    22  // UsersStoreKey is the key for the store
    23  const UsersStoreKey = "gitea-oauth2-sessions"
    24  
    25  // ProviderHeaderKey is the HTTP header key
    26  const ProviderHeaderKey = "gitea-oauth2-provider"
    27  
    28  // Init initializes the oauth source
    29  func Init() error {
    30  	if err := InitSigningKey(); err != nil {
    31  		return err
    32  	}
    33  
    34  	// Lock our mutex
    35  	gothRWMutex.Lock()
    36  
    37  	gob.Register(&sessions.Session{})
    38  
    39  	gothic.Store = &SessionsStore{
    40  		maxLength: int64(setting.OAuth2.MaxTokenLength),
    41  	}
    42  
    43  	gothic.SetState = func(req *http.Request) string {
    44  		return uuid.New().String()
    45  	}
    46  
    47  	gothic.GetProviderName = func(req *http.Request) (string, error) {
    48  		return req.Header.Get(ProviderHeaderKey), nil
    49  	}
    50  
    51  	// Unlock our mutex
    52  	gothRWMutex.Unlock()
    53  
    54  	return initOAuth2Sources()
    55  }
    56  
    57  // ResetOAuth2 clears existing OAuth2 providers and loads them from DB
    58  func ResetOAuth2() error {
    59  	ClearProviders()
    60  	return initOAuth2Sources()
    61  }
    62  
    63  // initOAuth2Sources is used to load and register all active OAuth2 providers
    64  func initOAuth2Sources() error {
    65  	authSources, _ := auth.GetOAuth2ProviderSources(true)
    66  	for _, source := range authSources {
    67  		oauth2Source, ok := source.Cfg.(*Source)
    68  		if !ok {
    69  			continue
    70  		}
    71  		err := oauth2Source.RegisterSource()
    72  		if err != nil {
    73  			log.Critical("Unable to register source: %s due to Error: %v.", source.Name, err)
    74  		}
    75  	}
    76  	return nil
    77  }