github.com/rclone/rclone@v1.66.1-0.20240517100346-7b89735ae726/lib/jwtutil/jwtutil.go (about)

     1  // Package jwtutil provides JWT utilities.
     2  package jwtutil
     3  
     4  import (
     5  	"bytes"
     6  	"crypto/rand"
     7  	"crypto/rsa"
     8  	"encoding/hex"
     9  	"encoding/json"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net/http"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/golang-jwt/jwt/v4"
    18  	"github.com/rclone/rclone/fs"
    19  	"github.com/rclone/rclone/fs/config/configmap"
    20  	"github.com/rclone/rclone/lib/oauthutil"
    21  
    22  	"golang.org/x/oauth2"
    23  )
    24  
    25  // RandomHex creates a random string of the given length
    26  func RandomHex(n int) (string, error) {
    27  	bytes := make([]byte, n)
    28  	if _, err := rand.Read(bytes); err != nil {
    29  		return "", err
    30  	}
    31  	return hex.EncodeToString(bytes), nil
    32  }
    33  
    34  // Config configures rclone using JWT
    35  func Config(id, name, url string, claims jwt.Claims, headerParams map[string]interface{}, queryParams map[string]string, privateKey *rsa.PrivateKey, m configmap.Mapper, client *http.Client) (err error) {
    36  	jwtToken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
    37  	for key, value := range headerParams {
    38  		jwtToken.Header[key] = value
    39  	}
    40  	payload, err := jwtToken.SignedString(privateKey)
    41  	if err != nil {
    42  		return fmt.Errorf("jwtutil: failed to encode payload: %w", err)
    43  	}
    44  	req, err := http.NewRequest("POST", url, nil)
    45  	if err != nil {
    46  		return fmt.Errorf("jwtutil: failed to create new request: %w", err)
    47  	}
    48  	q := req.URL.Query()
    49  	q.Add("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
    50  	q.Add("assertion", payload)
    51  	for key, value := range queryParams {
    52  		q.Add(key, value)
    53  	}
    54  	queryString := q.Encode()
    55  
    56  	req, err = http.NewRequest("POST", url, bytes.NewBuffer([]byte(queryString)))
    57  	if err != nil {
    58  		return fmt.Errorf("jwtutil: failed to create new request: %w", err)
    59  	}
    60  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
    61  
    62  	resp, err := client.Do(req)
    63  	if err != nil {
    64  		return fmt.Errorf("jwtutil: failed making auth request: %w", err)
    65  	}
    66  
    67  	s, err := bodyToString(resp.Body)
    68  	if err != nil {
    69  		fs.Debugf(nil, "jwtutil: failed to get response body")
    70  	}
    71  	if resp.StatusCode != 200 {
    72  		err = errors.New(resp.Status)
    73  		return fmt.Errorf("jwtutil: failed making auth request: %w", err)
    74  	}
    75  	defer func() {
    76  		deferredErr := resp.Body.Close()
    77  		if deferredErr != nil {
    78  			err = fmt.Errorf("jwtutil: failed to close resp.Body: %w", err)
    79  		}
    80  	}()
    81  
    82  	result := &response{}
    83  	err = json.NewDecoder(strings.NewReader(s)).Decode(result)
    84  	if result.AccessToken == "" && err == nil {
    85  		err = errors.New("no AccessToken in Response")
    86  	}
    87  	if err != nil {
    88  		return fmt.Errorf("jwtutil: failed to get token: %w", err)
    89  	}
    90  	token := &oauth2.Token{
    91  		AccessToken: result.AccessToken,
    92  		TokenType:   result.TokenType,
    93  	}
    94  	e := result.ExpiresIn
    95  	if e != 0 {
    96  		token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
    97  	}
    98  	return oauthutil.PutToken(name, m, token, true)
    99  }
   100  
   101  func bodyToString(responseBody io.Reader) (bodyString string, err error) {
   102  	bodyBytes, err := io.ReadAll(responseBody)
   103  	if err != nil {
   104  		return "", err
   105  	}
   106  	bodyString = string(bodyBytes)
   107  	fs.Debugf(nil, "jwtutil: Response Body: "+bodyString)
   108  	return bodyString, nil
   109  }
   110  
   111  type response struct {
   112  	AccessToken string `json:"access_token"`
   113  	TokenType   string `json:"token_type"`
   114  	ExpiresIn   int    `json:"expires_in"`
   115  }