github.com/cs3org/reva/v2@v2.27.7/pkg/cbox/utils/tokenmanagement.go (about)

     1  // Copyright 2018-2021 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package utils
    20  
    21  import (
    22  	"context"
    23  	"encoding/json"
    24  	"errors"
    25  	"io"
    26  	"net/http"
    27  	"net/url"
    28  	"strings"
    29  	"sync"
    30  	"time"
    31  
    32  	"github.com/cs3org/reva/v2/pkg/rhttp"
    33  )
    34  
    35  // APITokenManager stores config related to api management
    36  type APITokenManager struct {
    37  	oidcToken OIDCToken
    38  	conf      *config
    39  	client    *http.Client
    40  }
    41  
    42  // OIDCToken stores the OIDC token used to authenticate requests to the REST API service
    43  type OIDCToken struct {
    44  	sync.Mutex          // concurrent access to apiToken and tokenExpirationTime
    45  	apiToken            string
    46  	tokenExpirationTime time.Time
    47  }
    48  
    49  type config struct {
    50  	TargetAPI         string
    51  	OIDCTokenEndpoint string
    52  	ClientID          string
    53  	ClientSecret      string
    54  }
    55  
    56  // InitAPITokenManager initializes a new APITokenManager
    57  func InitAPITokenManager(targetAPI, oidcTokenEndpoint, clientID, clientSecret string) *APITokenManager {
    58  	return &APITokenManager{
    59  		conf: &config{
    60  			TargetAPI:         targetAPI,
    61  			OIDCTokenEndpoint: oidcTokenEndpoint,
    62  			ClientID:          clientID,
    63  			ClientSecret:      clientSecret,
    64  		},
    65  		client: rhttp.GetHTTPClient(
    66  			rhttp.Timeout(10*time.Second),
    67  			rhttp.Insecure(true),
    68  		),
    69  	}
    70  }
    71  
    72  func (a *APITokenManager) renewAPIToken(ctx context.Context, forceRenewal bool) error {
    73  	// Received tokens have an expiration time of 20 minutes.
    74  	// Take a couple of seconds as buffer time for the API call to complete
    75  	if forceRenewal || a.oidcToken.tokenExpirationTime.Before(time.Now().Add(time.Second*time.Duration(2))) {
    76  		token, expiration, err := a.getAPIToken(ctx)
    77  		if err != nil {
    78  			return err
    79  		}
    80  
    81  		a.oidcToken.Lock()
    82  		defer a.oidcToken.Unlock()
    83  
    84  		a.oidcToken.apiToken = token
    85  		a.oidcToken.tokenExpirationTime = expiration
    86  	}
    87  	return nil
    88  }
    89  
    90  func (a *APITokenManager) getAPIToken(ctx context.Context) (string, time.Time, error) {
    91  
    92  	params := url.Values{
    93  		"grant_type": {"client_credentials"},
    94  		"audience":   {a.conf.TargetAPI},
    95  	}
    96  
    97  	httpReq, err := http.NewRequest("POST", a.conf.OIDCTokenEndpoint, strings.NewReader(params.Encode()))
    98  	if err != nil {
    99  		return "", time.Time{}, err
   100  	}
   101  	httpReq.SetBasicAuth(a.conf.ClientID, a.conf.ClientSecret)
   102  	httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
   103  
   104  	httpRes, err := a.client.Do(httpReq)
   105  	if err != nil {
   106  		return "", time.Time{}, err
   107  	}
   108  	defer httpRes.Body.Close()
   109  
   110  	body, err := io.ReadAll(httpRes.Body)
   111  	if err != nil {
   112  		return "", time.Time{}, err
   113  	}
   114  	if httpRes.StatusCode < 200 || httpRes.StatusCode > 299 {
   115  		return "", time.Time{}, errors.New("rest: get token endpoint returned " + httpRes.Status)
   116  	}
   117  
   118  	var result map[string]interface{}
   119  	err = json.Unmarshal(body, &result)
   120  	if err != nil {
   121  		return "", time.Time{}, err
   122  	}
   123  
   124  	expirationSecs := result["expires_in"].(float64)
   125  	expirationTime := time.Now().Add(time.Second * time.Duration(expirationSecs))
   126  	return result["access_token"].(string), expirationTime, nil
   127  }
   128  
   129  // SendAPIGetRequest makes an API GET Request to the passed URL
   130  func (a *APITokenManager) SendAPIGetRequest(ctx context.Context, url string, forceRenewal bool) (map[string]interface{}, error) {
   131  	err := a.renewAPIToken(ctx, forceRenewal)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	httpReq, err := http.NewRequest("GET", url, nil)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  
   141  	// We don't need to take the lock when reading apiToken, because if we reach here,
   142  	// the token is valid at least for a couple of seconds. Even if another request modifies
   143  	// the token and expiration time while this request is in progress, the current token will still be valid.
   144  	httpReq.Header.Set("Authorization", "Bearer "+a.oidcToken.apiToken)
   145  
   146  	httpRes, err := a.client.Do(httpReq)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  	defer httpRes.Body.Close()
   151  
   152  	if httpRes.StatusCode == http.StatusUnauthorized {
   153  		// The token is no longer valid, try renewing it
   154  		return a.SendAPIGetRequest(ctx, url, true)
   155  	}
   156  	if httpRes.StatusCode < 200 || httpRes.StatusCode > 299 {
   157  		return nil, errors.New("rest: API request returned " + httpRes.Status)
   158  	}
   159  
   160  	body, err := io.ReadAll(httpRes.Body)
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  
   165  	var result map[string]interface{}
   166  	err = json.Unmarshal(body, &result)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	return result, nil
   172  }