github.com/yaegashi/msgraph.go@v0.1.4/msauth/msauth.go (about)

// Package msauth implements a library for authorizing against the Microsoft identity platform:
// https://docs.microsoft.com/en-us/azure/active-directory/develop/
//
// It uses the v2.0 endpoint,
// so it can authorize users with both personal (Microsoft) and organizational (Azure AD) accounts.
     6  package msauth
     7  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"golang.org/x/oauth2"
)
    21  
    22  const (
    23  	// DefaultMSGraphScope is the default scope for MS Graph API
    24  	DefaultMSGraphScope = "https://graph.microsoft.com/.default"
    25  	endpointURLFormat   = "https://login.microsoftonline.com/%s/oauth2/v2.0/%s"
    26  )
    27  
    28  // TokenError is returned on failed authentication
    29  type TokenError struct {
    30  	ErrorObject      string `json:"error"`
    31  	ErrorDescription string `json:"error_description"`
    32  }
    33  
    34  // Error implements error interface
    35  func (t *TokenError) Error() string {
    36  	return fmt.Sprintf("%s: %s", t.ErrorObject, t.ErrorDescription)
    37  }
    38  
    39  func deviceCodeURL(tenantID string) string {
    40  	return fmt.Sprintf(endpointURLFormat, tenantID, "devicecode")
    41  }
    42  
    43  func tokenURL(tenantID string) string {
    44  	return fmt.Sprintf(endpointURLFormat, tenantID, "token")
    45  }
    46  
    47  type tokenJSON struct {
    48  	AccessToken  string `json:"access_token"`
    49  	TokenType    string `json:"token_type"`
    50  	RefreshToken string `json:"refresh_token"`
    51  	ExpiresIn    int    `json:"expires_in"`
    52  }
    53  
    54  func (e *tokenJSON) expiry() (t time.Time) {
    55  	if v := e.ExpiresIn; v != 0 {
    56  		return time.Now().Add(time.Duration(v) * time.Second)
    57  	}
    58  	return
    59  }
    60  
    61  // Manager is oauth2 token cache manager
    62  type Manager struct {
    63  	mu         sync.Mutex
    64  	Dirty      bool
    65  	TokenCache map[string]*oauth2.Token
    66  }
    67  
    68  // NewManager returns a new Manager instance
    69  func NewManager() *Manager {
    70  	return &Manager{TokenCache: map[string]*oauth2.Token{}}
    71  }
    72  
    73  // LoadBytes loads token cache from opaque bytes (it's actually JSON)
    74  func (m *Manager) LoadBytes(b []byte) error {
    75  	m.mu.Lock()
    76  	defer m.mu.Unlock()
    77  	return json.Unmarshal(b, &m.TokenCache)
    78  }
    79  
    80  // SaveBytes saves token cache to opaque bytes (it's actually JSON)
    81  func (m *Manager) SaveBytes() ([]byte, error) {
    82  	m.mu.Lock()
    83  	defer m.mu.Unlock()
    84  	return json.Marshal(m.TokenCache)
    85  }
    86  
    87  // LoadFile loads token cache from file with dirty state control
    88  func (m *Manager) LoadFile(path string) error {
    89  	m.mu.Lock()
    90  	defer m.mu.Unlock()
    91  	b, err := ReadLocation(path)
    92  	if err != nil {
    93  		return err
    94  	}
    95  	err = json.Unmarshal(b, &m.TokenCache)
    96  	if err != nil {
    97  		return err
    98  	}
    99  	m.Dirty = false
   100  	return nil
   101  }
   102  
   103  // SaveFile saves token cache to file with dirty state control
   104  func (m *Manager) SaveFile(path string) error {
   105  	m.mu.Lock()
   106  	defer m.mu.Unlock()
   107  	if !m.Dirty {
   108  		return nil
   109  	}
   110  	b, err := json.Marshal(m.TokenCache)
   111  	if err != nil {
   112  		return err
   113  	}
   114  	err = WriteLocation(path, b, 0644)
   115  	if err != nil {
   116  		return err
   117  	}
   118  	m.Dirty = false
   119  	return nil
   120  }
   121  
   122  // CacheKey generates a token cache key from tenantID/clientID
   123  func CacheKey(tenantID, clientID string) string {
   124  	return fmt.Sprintf("%s:%s", tenantID, clientID)
   125  }
   126  
   127  // GetToken gets a token from token cache
   128  func (m *Manager) GetToken(cacheKey string) (*oauth2.Token, bool) {
   129  	m.mu.Lock()
   130  	defer m.mu.Unlock()
   131  	token, ok := m.TokenCache[cacheKey]
   132  	return token, ok
   133  }
   134  
   135  // PutToken puts a token into token cache
   136  func (m *Manager) PutToken(cacheKey string, token *oauth2.Token) {
   137  	m.mu.Lock()
   138  	defer m.mu.Unlock()
   139  	oldToken, ok := m.TokenCache[cacheKey]
   140  	if ok && *oldToken == *token {
   141  		return
   142  	}
   143  	m.TokenCache[cacheKey] = token
   144  	m.Dirty = true
   145  }
   146  
   147  // requestToken requests a token from the token endpoint
   148  // TODO(ctx): use http client from ctx
   149  func (m *Manager) requestToken(ctx context.Context, tenantID, clientID string, values url.Values) (*oauth2.Token, error) {
   150  	res, err := http.PostForm(tokenURL(tenantID), values)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  	defer res.Body.Close()
   155  	b, err := ioutil.ReadAll(res.Body)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  	if res.StatusCode != http.StatusOK {
   160  		var terr *TokenError
   161  		err = json.Unmarshal(b, &terr)
   162  		if err != nil {
   163  			return nil, err
   164  		}
   165  		return nil, terr
   166  	}
   167  	var tj *tokenJSON
   168  	err = json.Unmarshal(b, &tj)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  	token := &oauth2.Token{
   173  		AccessToken:  tj.AccessToken,
   174  		TokenType:    tj.TokenType,
   175  		RefreshToken: tj.RefreshToken,
   176  		Expiry:       tj.expiry(),
   177  	}
   178  	if token.AccessToken == "" {
   179  		return nil, errors.New("msauth: server response missing access_token")
   180  	}
   181  	return token, nil
   182  }