github.com/kyleu/dbaudit@v0.0.2-0.20240321155047-ff2f2c940496/app/lib/auth/msfix/provider.go (about) 1 // Package msfix - Content managed by Project Forge, see [projectforge.md] for details. 2 package msfix 3 4 import ( 5 "bytes" 6 "context" 7 "encoding/json" 8 "fmt" 9 "io" 10 "net/http" 11 "strings" 12 13 "github.com/markbates/going/defaults" 14 "github.com/markbates/goth" 15 "github.com/pkg/errors" 16 "golang.org/x/oauth2" 17 ) 18 19 const ( 20 authURL string = "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize" 21 tokenURL string = "https://login.microsoftonline.com/%s/oauth2/v2.0/token" //nolint:gosec 22 endpointProfile string = "https://graph.microsoft.com/v1.0/me" 23 ) 24 25 var defaultScopes = []string{"openid", "offline_access", "user.read"} 26 27 // Note that this is a copy of the `microsoftonline` provider, but accepts a tenant. 28 func New(clientKey, secret, callbackURL string, tenant string, scopes ...string) *Provider { 29 if tenant == "" { 30 tenant = "common" 31 } 32 p := &Provider{ClientKey: clientKey, Secret: secret, CallbackURL: callbackURL, Tenant: tenant, providerName: "microsoft"} 33 p.config = newConfig(p, scopes) 34 return p 35 } 36 37 type Provider struct { 38 ClientKey string 39 Secret string 40 CallbackURL string 41 Tenant string 42 HTTPClient *http.Client 43 config *oauth2.Config 44 providerName string 45 } 46 47 func (p *Provider) Name() string { 48 return p.providerName 49 } 50 51 func (p *Provider) SetName(name string) { 52 p.providerName = name 53 } 54 55 func (p *Provider) Client() *http.Client { 56 return goth.HTTPClientWithFallBack(p.HTTPClient) 57 } 58 59 func (p *Provider) Debug(_ bool) {} 60 61 func (p *Provider) BeginAuth(state string) (goth.Session, error) { 62 au := p.config.AuthCodeURL(state) 63 return &Session{AuthURL: au}, nil 64 } 65 66 func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { 67 msSession, ok := session.(*Session) 68 if !ok { 69 return goth.User{}, errors.Errorf("invalid session of type [%T]", session) 70 } 71 user := goth.User{ 72 AccessToken: msSession.AccessToken, 73 Provider: p.Name(), 74 ExpiresAt: msSession.ExpiresAt, 75 } 76 77 if user.AccessToken == "" { 78 return user, errors.Errorf("%s cannot get user information without accessToken", p.providerName) 79 } 80 81 req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, endpointProfile, http.NoBody) 82 if err != nil { 83 return user, err 84 } 85 86 req.Header.Set(authorizationHeader(msSession)) 87 88 response, err := p.Client().Do(req) 89 if err != nil { 90 return user, err 91 } 92 defer func() { _ = response.Body.Close() }() 93 94 if response.StatusCode != http.StatusOK { 95 return user, errors.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode) 96 } 97 98 user.AccessToken = msSession.AccessToken 99 if len(user.AccessToken) > 1024 { 100 user.AccessToken = "" 101 } 102 103 err = userFromReader(response.Body, &user) 104 return user, err 105 } 106 107 func (p *Provider) RefreshTokenAvailable() bool { 108 return false 109 } 110 111 func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { 112 if refreshToken == "" { 113 return nil, errors.Errorf("no refresh token provided") 114 } 115 116 token := &oauth2.Token{RefreshToken: refreshToken} 117 ts := p.config.TokenSource(goth.ContextForClient(p.Client()), token) 118 newToken, err := ts.Token() 119 if err != nil { 120 return nil, err 121 } 122 return newToken, err 123 } 124 125 func newConfig(provider *Provider, scopes []string) *oauth2.Config { 126 c := &oauth2.Config{ 127 ClientID: provider.ClientKey, 128 ClientSecret: provider.Secret, 129 RedirectURL: provider.CallbackURL, 130 Endpoint: oauth2.Endpoint{ 131 AuthURL: fmt.Sprintf(authURL, provider.Tenant), 132 TokenURL: fmt.Sprintf(tokenURL, provider.Tenant), 133 }, 134 Scopes: []string{}, 135 } 136 137 c.Scopes = append(c.Scopes, scopes...) 138 if len(scopes) == 0 { 139 c.Scopes = append(c.Scopes, defaultScopes...) 140 } 141 142 return c 143 } 144 145 func userFromReader(r io.Reader, user *goth.User) error { 146 buf := &bytes.Buffer{} 147 tee := io.TeeReader(r, buf) 148 149 u := struct { 150 ID string `json:"id"` 151 Name string `json:"displayName"` 152 Email string `json:"mail"` 153 FirstName string `json:"givenName"` 154 LastName string `json:"surname"` 155 UserPrincipalName string `json:"userPrincipalName"` 156 }{} 157 158 if err := json.NewDecoder(tee).Decode(&u); err != nil { 159 return err 160 } 161 162 raw := map[string]any{} 163 if err := json.NewDecoder(buf).Decode(&raw); err != nil { 164 return err 165 } 166 167 user.UserID = u.ID 168 user.Email = defaults.String(u.Email, u.UserPrincipalName) 169 user.Name = u.Name 170 user.NickName = u.Name 171 user.FirstName = u.FirstName 172 user.LastName = u.LastName 173 user.RawData = raw 174 175 return nil 176 } 177 178 func authorizationHeader(session *Session) (string, string) { 179 return "Authorization", fmt.Sprintf("Bearer %s", session.AccessToken) 180 } 181 182 func (p *Provider) UnmarshalSession(data string) (goth.Session, error) { 183 session := &Session{} 184 err := json.NewDecoder(strings.NewReader(data)).Decode(session) 185 return session, err 186 }