github.com/yaegashi/msgraph.go@v0.1.4/msauth/msauth.go (about) 1 // Package msauth implements a library to authorize against Microsoft identity platform: 2 // https://docs.microsoft.com/en-us/azure/active-directory/develop/ 3 // 4 // It utilizes v2.0 endpoint 5 // so it can authorize users with both personal (Microsoft) and organizational (Azure AD) account. 6 package msauth 7 8 import ( 9 "context" 10 "encoding/json" 11 "errors" 12 "fmt" 13 "io/ioutil" 14 "net/http" 15 "net/url" 16 "sync" 17 "time" 18 19 "golang.org/x/oauth2" 20 ) 21 22 const ( 23 // DefaultMSGraphScope is the default scope for MS Graph API 24 DefaultMSGraphScope = "https://graph.microsoft.com/.default" 25 endpointURLFormat = "https://login.microsoftonline.com/%s/oauth2/v2.0/%s" 26 ) 27 28 // TokenError is returned on failed authentication 29 type TokenError struct { 30 ErrorObject string `json:"error"` 31 ErrorDescription string `json:"error_description"` 32 } 33 34 // Error implements error interface 35 func (t *TokenError) Error() string { 36 return fmt.Sprintf("%s: %s", t.ErrorObject, t.ErrorDescription) 37 } 38 39 func deviceCodeURL(tenantID string) string { 40 return fmt.Sprintf(endpointURLFormat, tenantID, "devicecode") 41 } 42 43 func tokenURL(tenantID string) string { 44 return fmt.Sprintf(endpointURLFormat, tenantID, "token") 45 } 46 47 type tokenJSON struct { 48 AccessToken string `json:"access_token"` 49 TokenType string `json:"token_type"` 50 RefreshToken string `json:"refresh_token"` 51 ExpiresIn int `json:"expires_in"` 52 } 53 54 func (e *tokenJSON) expiry() (t time.Time) { 55 if v := e.ExpiresIn; v != 0 { 56 return time.Now().Add(time.Duration(v) * time.Second) 57 } 58 return 59 } 60 61 // Manager is oauth2 token cache manager 62 type Manager struct { 63 mu sync.Mutex 64 Dirty bool 65 TokenCache map[string]*oauth2.Token 66 } 67 68 // NewManager returns a new Manager instance 69 func NewManager() *Manager { 70 return &Manager{TokenCache: map[string]*oauth2.Token{}} 71 } 72 73 // LoadBytes loads token cache from opaque bytes (it's actually JSON) 74 func (m *Manager) LoadBytes(b []byte) error { 75 m.mu.Lock() 76 defer m.mu.Unlock() 77 return json.Unmarshal(b, &m.TokenCache) 78 } 79 80 // SaveBytes saves token cache to opaque bytes (it's actually JSON) 81 func (m *Manager) SaveBytes() ([]byte, error) { 82 m.mu.Lock() 83 defer m.mu.Unlock() 84 return json.Marshal(m.TokenCache) 85 } 86 87 // LoadFile loads token cache from file with dirty state control 88 func (m *Manager) LoadFile(path string) error { 89 m.mu.Lock() 90 defer m.mu.Unlock() 91 b, err := ReadLocation(path) 92 if err != nil { 93 return err 94 } 95 err = json.Unmarshal(b, &m.TokenCache) 96 if err != nil { 97 return err 98 } 99 m.Dirty = false 100 return nil 101 } 102 103 // SaveFile saves token cache to file with dirty state control 104 func (m *Manager) SaveFile(path string) error { 105 m.mu.Lock() 106 defer m.mu.Unlock() 107 if !m.Dirty { 108 return nil 109 } 110 b, err := json.Marshal(m.TokenCache) 111 if err != nil { 112 return err 113 } 114 err = WriteLocation(path, b, 0644) 115 if err != nil { 116 return err 117 } 118 m.Dirty = false 119 return nil 120 } 121 122 // CacheKey generates a token cache key from tenantID/clientID 123 func CacheKey(tenantID, clientID string) string { 124 return fmt.Sprintf("%s:%s", tenantID, clientID) 125 } 126 127 // GetToken gets a token from token cache 128 func (m *Manager) GetToken(cacheKey string) (*oauth2.Token, bool) { 129 m.mu.Lock() 130 defer m.mu.Unlock() 131 token, ok := m.TokenCache[cacheKey] 132 return token, ok 133 } 134 135 // PutToken puts a token into token cache 136 func (m *Manager) PutToken(cacheKey string, token *oauth2.Token) { 137 m.mu.Lock() 138 defer m.mu.Unlock() 139 oldToken, ok := m.TokenCache[cacheKey] 140 if ok && *oldToken == *token { 141 return 142 } 143 m.TokenCache[cacheKey] = token 144 m.Dirty = true 145 } 146 147 // requestToken requests a token from the token endpoint 148 // TODO(ctx): use http client from ctx 149 func (m *Manager) requestToken(ctx context.Context, tenantID, clientID string, values url.Values) (*oauth2.Token, error) { 150 res, err := http.PostForm(tokenURL(tenantID), values) 151 if err != nil { 152 return nil, err 153 } 154 defer res.Body.Close() 155 b, err := ioutil.ReadAll(res.Body) 156 if err != nil { 157 return nil, err 158 } 159 if res.StatusCode != http.StatusOK { 160 var terr *TokenError 161 err = json.Unmarshal(b, &terr) 162 if err != nil { 163 return nil, err 164 } 165 return nil, terr 166 } 167 var tj *tokenJSON 168 err = json.Unmarshal(b, &tj) 169 if err != nil { 170 return nil, err 171 } 172 token := &oauth2.Token{ 173 AccessToken: tj.AccessToken, 174 TokenType: tj.TokenType, 175 RefreshToken: tj.RefreshToken, 176 Expiry: tj.expiry(), 177 } 178 if token.AccessToken == "" { 179 return nil, errors.New("msauth: server response missing access_token") 180 } 181 return token, nil 182 }