github.com/kyleu/dbaudit@v0.0.2-0.20240321155047-ff2f2c940496/app/lib/auth/msfix/provider.go (about)

     1  // Package msfix - Content managed by Project Forge, see [projectforge.md] for details.
     2  package msfix
     3  
     4  import (
     5  	"bytes"
     6  	"context"
     7  	"encoding/json"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"strings"
    12  
    13  	"github.com/markbates/going/defaults"
    14  	"github.com/markbates/goth"
    15  	"github.com/pkg/errors"
    16  	"golang.org/x/oauth2"
    17  )
    18  
    19  const (
    20  	authURL         string = "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize"
    21  	tokenURL        string = "https://login.microsoftonline.com/%s/oauth2/v2.0/token" //nolint:gosec
    22  	endpointProfile string = "https://graph.microsoft.com/v1.0/me"
    23  )
    24  
    25  var defaultScopes = []string{"openid", "offline_access", "user.read"}
    26  
    27  // Note that this is a copy of the `microsoftonline` provider, but accepts a tenant.
    28  func New(clientKey, secret, callbackURL string, tenant string, scopes ...string) *Provider {
    29  	if tenant == "" {
    30  		tenant = "common"
    31  	}
    32  	p := &Provider{ClientKey: clientKey, Secret: secret, CallbackURL: callbackURL, Tenant: tenant, providerName: "microsoft"}
    33  	p.config = newConfig(p, scopes)
    34  	return p
    35  }
    36  
    37  type Provider struct {
    38  	ClientKey    string
    39  	Secret       string
    40  	CallbackURL  string
    41  	Tenant       string
    42  	HTTPClient   *http.Client
    43  	config       *oauth2.Config
    44  	providerName string
    45  }
    46  
    47  func (p *Provider) Name() string {
    48  	return p.providerName
    49  }
    50  
    51  func (p *Provider) SetName(name string) {
    52  	p.providerName = name
    53  }
    54  
    55  func (p *Provider) Client() *http.Client {
    56  	return goth.HTTPClientWithFallBack(p.HTTPClient)
    57  }
    58  
    59  func (p *Provider) Debug(_ bool) {}
    60  
    61  func (p *Provider) BeginAuth(state string) (goth.Session, error) {
    62  	au := p.config.AuthCodeURL(state)
    63  	return &Session{AuthURL: au}, nil
    64  }
    65  
    66  func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
    67  	msSession, ok := session.(*Session)
    68  	if !ok {
    69  		return goth.User{}, errors.Errorf("invalid session of type [%T]", session)
    70  	}
    71  	user := goth.User{
    72  		AccessToken: msSession.AccessToken,
    73  		Provider:    p.Name(),
    74  		ExpiresAt:   msSession.ExpiresAt,
    75  	}
    76  
    77  	if user.AccessToken == "" {
    78  		return user, errors.Errorf("%s cannot get user information without accessToken", p.providerName)
    79  	}
    80  
    81  	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, endpointProfile, http.NoBody)
    82  	if err != nil {
    83  		return user, err
    84  	}
    85  
    86  	req.Header.Set(authorizationHeader(msSession))
    87  
    88  	response, err := p.Client().Do(req)
    89  	if err != nil {
    90  		return user, err
    91  	}
    92  	defer func() { _ = response.Body.Close() }()
    93  
    94  	if response.StatusCode != http.StatusOK {
    95  		return user, errors.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode)
    96  	}
    97  
    98  	user.AccessToken = msSession.AccessToken
    99  	if len(user.AccessToken) > 1024 {
   100  		user.AccessToken = ""
   101  	}
   102  
   103  	err = userFromReader(response.Body, &user)
   104  	return user, err
   105  }
   106  
   107  func (p *Provider) RefreshTokenAvailable() bool {
   108  	return false
   109  }
   110  
   111  func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
   112  	if refreshToken == "" {
   113  		return nil, errors.Errorf("no refresh token provided")
   114  	}
   115  
   116  	token := &oauth2.Token{RefreshToken: refreshToken}
   117  	ts := p.config.TokenSource(goth.ContextForClient(p.Client()), token)
   118  	newToken, err := ts.Token()
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	return newToken, err
   123  }
   124  
   125  func newConfig(provider *Provider, scopes []string) *oauth2.Config {
   126  	c := &oauth2.Config{
   127  		ClientID:     provider.ClientKey,
   128  		ClientSecret: provider.Secret,
   129  		RedirectURL:  provider.CallbackURL,
   130  		Endpoint: oauth2.Endpoint{
   131  			AuthURL:  fmt.Sprintf(authURL, provider.Tenant),
   132  			TokenURL: fmt.Sprintf(tokenURL, provider.Tenant),
   133  		},
   134  		Scopes: []string{},
   135  	}
   136  
   137  	c.Scopes = append(c.Scopes, scopes...)
   138  	if len(scopes) == 0 {
   139  		c.Scopes = append(c.Scopes, defaultScopes...)
   140  	}
   141  
   142  	return c
   143  }
   144  
   145  func userFromReader(r io.Reader, user *goth.User) error {
   146  	buf := &bytes.Buffer{}
   147  	tee := io.TeeReader(r, buf)
   148  
   149  	u := struct {
   150  		ID                string `json:"id"`
   151  		Name              string `json:"displayName"`
   152  		Email             string `json:"mail"`
   153  		FirstName         string `json:"givenName"`
   154  		LastName          string `json:"surname"`
   155  		UserPrincipalName string `json:"userPrincipalName"`
   156  	}{}
   157  
   158  	if err := json.NewDecoder(tee).Decode(&u); err != nil {
   159  		return err
   160  	}
   161  
   162  	raw := map[string]any{}
   163  	if err := json.NewDecoder(buf).Decode(&raw); err != nil {
   164  		return err
   165  	}
   166  
   167  	user.UserID = u.ID
   168  	user.Email = defaults.String(u.Email, u.UserPrincipalName)
   169  	user.Name = u.Name
   170  	user.NickName = u.Name
   171  	user.FirstName = u.FirstName
   172  	user.LastName = u.LastName
   173  	user.RawData = raw
   174  
   175  	return nil
   176  }
   177  
   178  func authorizationHeader(session *Session) (string, string) {
   179  	return "Authorization", fmt.Sprintf("Bearer %s", session.AccessToken)
   180  }
   181  
   182  func (p *Provider) UnmarshalSession(data string) (goth.Session, error) {
   183  	session := &Session{}
   184  	err := json.NewDecoder(strings.NewReader(data)).Decode(session)
   185  	return session, err
   186  }