github.com/yaegashi/msgraph.go@v0.1.4/msauth/device_authorization_grant.go (about)

     1  package msauth
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/url"
    10  	"os"
    11  	"strings"
    12  	"time"
    13  
    14  	"golang.org/x/oauth2"
    15  	"golang.org/x/oauth2/microsoft"
    16  )
    17  
    18  const (
    19  	deviceCodeGrantType       = "urn:ietf:params:oauth:grant-type:device_code"
    20  	authorizationPendingError = "authorization_pending"
    21  )
    22  
    23  // DeviceCode is returned on device auth initiation
    24  type DeviceCode struct {
    25  	DeviceCode      string `json:"device_code"`
    26  	UserCode        string `json:"user_code"`
    27  	VerificationURL string `json:"verification_url"`
    28  	ExpiresIn       int    `json:"expires_in"`
    29  	Interval        int    `json:"interval"`
    30  	Message         string `json:"message"`
    31  }
    32  
    33  // DeviceAuthorizationGrant performs OAuth 2.0 device authorization grant and returns auto-refreshing TokenSource
    34  func (m *Manager) DeviceAuthorizationGrant(ctx context.Context, tenantID, clientID string, scopes []string, callback func(*DeviceCode) error) (oauth2.TokenSource, error) {
    35  	endpoint := microsoft.AzureADEndpoint(tenantID)
    36  	endpoint.AuthStyle = oauth2.AuthStyleInParams
    37  	config := &oauth2.Config{
    38  		ClientID: clientID,
    39  		Endpoint: endpoint,
    40  		Scopes:   scopes,
    41  	}
    42  	if t, ok := m.GetToken(CacheKey(tenantID, clientID)); ok {
    43  		tt, err := config.TokenSource(ctx, t).Token()
    44  		if err == nil {
    45  			m.PutToken(CacheKey(tenantID, clientID), tt)
    46  			return config.TokenSource(ctx, tt), nil
    47  		}
    48  		if _, ok := err.(*oauth2.RetrieveError); !ok {
    49  			return nil, err
    50  		}
    51  	}
    52  	scope := strings.Join(scopes, " ")
    53  	res, err := http.PostForm(deviceCodeURL(tenantID), url.Values{"client_id": {clientID}, "scope": {scope}})
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	defer res.Body.Close()
    58  	if res.StatusCode != http.StatusOK {
    59  		b, _ := ioutil.ReadAll(res.Body)
    60  		return nil, fmt.Errorf("%s: %s", res.Status, string(b))
    61  	}
    62  	dc := &DeviceCode{}
    63  	dec := json.NewDecoder(res.Body)
    64  	err = dec.Decode(&dc)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	if callback != nil {
    69  		err = callback(dc)
    70  		if err != nil {
    71  			return nil, err
    72  		}
    73  	} else {
    74  		fmt.Fprintln(os.Stderr, dc.Message)
    75  	}
    76  	values := url.Values{
    77  		"client_id":   {clientID},
    78  		"grant_type":  {deviceCodeGrantType},
    79  		"device_code": {dc.DeviceCode},
    80  	}
    81  	interval := dc.Interval
    82  	if interval == 0 {
    83  		interval = 5
    84  	}
    85  	for {
    86  		time.Sleep(time.Second * time.Duration(interval))
    87  		token, err := m.requestToken(ctx, tenantID, clientID, values)
    88  		if err == nil {
    89  			m.PutToken(CacheKey(tenantID, clientID), token)
    90  			return config.TokenSource(ctx, token), nil
    91  		}
    92  		tokenError, ok := err.(*TokenError)
    93  		if !ok || tokenError.ErrorObject != authorizationPendingError {
    94  			return nil, err
    95  		}
    96  	}
    97  }