github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/clients/hadoopfs/src/main/java/io/lakefs/auth/AWSLakeFSTokenProvider.java (about)

     1  package io.lakefs.auth;
     2  
     3  import com.amazonaws.auth.AWSCredentialsProvider;
     4  import io.lakefs.Constants;
     5  import io.lakefs.FSConfiguration;
     6  import io.lakefs.clients.sdk.ApiClient;
     7  import io.lakefs.clients.sdk.AuthApi;
     8  import io.lakefs.clients.sdk.model.ExternalLoginInformation;
     9  import io.lakefs.clients.sdk.model.AuthenticationToken;
    10  import org.apache.commons.codec.binary.Base64;
    11  
    12  import java.io.IOException;
    13  
    14  import java.net.URI;
    15  import java.net.URL;
    16  import java.util.Arrays;
    17  import java.util.HashMap;
    18  import java.util.Map;
    19  import java.util.Optional;
    20  
    21  import org.apache.hadoop.conf.Configuration;
    22  
    23  
    24  public class AWSLakeFSTokenProvider implements LakeFSTokenProvider {
    25      STSGetCallerIdentityPresigner stsPresigner;
    26      AWSCredentialsProvider awsProvider;
    27      AuthenticationToken lakeFSAuthToken = null;
    28      String stsEndpoint;
    29      Map<String, String> stsAdditionalHeaders;
    30      int stsExpirationInSeconds;
    31      Optional<Integer> lakeFSTokenTTLSeconds = Optional.empty();
    32      ApiClient lakeFSApi;
    33  
    34      AWSLakeFSTokenProvider() {
    35      }
    36  
    37      public AWSLakeFSTokenProvider(AWSCredentialsProvider awsProvider, ApiClient lakeFSClient, STSGetCallerIdentityPresigner stsPresigner, String stsEndpoint, Map<String, String> stsAdditionalHeaders, int stsExpirationInSeconds) {
    38          this.awsProvider = awsProvider;
    39          this.stsPresigner = stsPresigner;
    40          this.lakeFSApi = lakeFSClient;
    41          this.stsEndpoint = stsEndpoint;
    42          this.stsAdditionalHeaders = stsAdditionalHeaders;
    43          this.stsExpirationInSeconds = stsExpirationInSeconds;
    44      }
    45  
    46      protected void initialize(AWSCredentialsProvider awsProvider, String scheme, Configuration conf) throws IOException {
    47          // aws credentials provider
    48          this.awsProvider = awsProvider;
    49  
    50          // sts endpoint to call STS
    51          this.stsEndpoint = FSConfiguration.get(conf, scheme, Constants.TOKEN_AWS_STS_ENDPOINT);
    52  
    53          if (this.stsEndpoint == null) {
    54              throw new IOException("Missing sts endpoint");
    55          }
    56  
    57          // Expiration for each identity token generated (they are very short-lived and only used for exchange, the value is part of the signature)
    58          this.stsExpirationInSeconds = FSConfiguration.getInt(conf, scheme, Constants.TOKEN_AWS_CREDENTIALS_PROVIDER_TOKEN_DURATION_SECONDS, 60);
    59  
    60          // initialize the presigner
    61          this.stsPresigner = new GetCallerIdentityV4Presigner();
    62  
    63          // initialize a lakeFS api client
    64  
    65          this.lakeFSApi = io.lakefs.clients.sdk.Configuration.getDefaultApiClient();
    66          this.lakeFSApi.addDefaultHeader("X-Lakefs-Client", "lakefs-hadoopfs/" + getClass().getPackage().getImplementationVersion());
    67          String endpoint = FSConfiguration.get(conf, scheme, Constants.ENDPOINT_KEY_SUFFIX, Constants.DEFAULT_CLIENT_ENDPOINT);
    68          if (endpoint.endsWith(Constants.SEPARATOR)) {
    69              endpoint = endpoint.substring(0, endpoint.length() - 1);
    70          }
    71          String sessionId = FSConfiguration.get(conf, scheme, Constants.SESSION_ID);
    72          if (sessionId != null) {
    73              this.lakeFSApi.addDefaultCookie("sessionId", sessionId);
    74          }
    75          this.lakeFSApi.setBasePath(endpoint);
    76  
    77          // optional timeout for lakeFS token
    78          int tokenTTL = FSConfiguration.getInt(conf, scheme, Constants.LAKEFS_AUTH_TOKEN_TTL_KEY_SUFFIX, -1);
    79          if (tokenTTL != -1) {
    80              this.lakeFSTokenTTLSeconds = Optional.of(tokenTTL);
    81          }
    82  
    83          // set additional headers (non-canonical) to sign with each request to STS
    84          // non-canonical headers are signed by the presigner and sent to STS for verification in the requests by lakeFS to exchange the token
    85          Map<String, String> additionalHeaders = FSConfiguration.getMap(conf, scheme, Constants.TOKEN_AWS_CREDENTIALS_PROVIDER_ADDITIONAL_HEADERS);
    86          if (additionalHeaders == null) {
    87              additionalHeaders = new HashMap<String, String>() {{
    88                  put(Constants.DEFAULT_AUTH_PROVIDER_SERVER_ID_HEADER, new URL(lakeFSApi.getBasePath()).getHost());
    89              }};
    90              // default header to sign is the lakeFS server host name
    91              additionalHeaders.put(Constants.DEFAULT_AUTH_PROVIDER_SERVER_ID_HEADER, new URL(endpoint).getHost());
    92          }
    93          this.stsAdditionalHeaders = additionalHeaders;
    94      }
    95  
    96      @Override
    97      public String getToken() {
    98          if (needsNewToken()) {
    99              refresh();
   100          }
   101          return this.lakeFSAuthToken.getToken();
   102      }
   103  
   104      private boolean needsNewToken() {
   105          return this.lakeFSAuthToken == null || this.lakeFSAuthToken.getTokenExpiration() < System.currentTimeMillis();
   106      }
   107  
   108      public GeneratePresignGetCallerIdentityResponse newPresignedRequest() throws Exception {
   109          GeneratePresignGetCallerIdentityRequest stsReq = new GeneratePresignGetCallerIdentityRequest(new URI(this.stsEndpoint), this.awsProvider.getCredentials(), this.stsAdditionalHeaders, this.stsExpirationInSeconds);
   110          return this.stsPresigner.presignRequest(stsReq);
   111      }
   112  
   113      public String newPresignedGetCallerIdentityToken() throws Exception {
   114          GeneratePresignGetCallerIdentityResponse signedRequest = this.newPresignedRequest();
   115  
   116          // generate token parameters object
   117          LakeFSExternalPrincipalIdentityRequest identityTokenParams = new LakeFSExternalPrincipalIdentityRequest(signedRequest.getHTTPMethod(), signedRequest.getHost(), signedRequest.getRegion(), signedRequest.getAction(), signedRequest.getDate(), signedRequest.getExpires(), signedRequest.getAccessKeyId(), signedRequest.getSignature(), Arrays.asList(signedRequest.getSignedHeadersParam().split(";")), signedRequest.getVersion(), signedRequest.getAlgorithm(), signedRequest.getSecurityToken());
   118  
   119          // base64 encode
   120          return Base64.encodeBase64String(identityTokenParams.toJSON().getBytes());
   121      }
   122  
   123      private void newToken() throws Exception {
   124          // created identity token to exchange for lakeFS token
   125          String identityToken = this.newPresignedGetCallerIdentityToken();
   126  
   127          // build lakeFS login request
   128          ExternalLoginInformation req = new ExternalLoginInformation();
   129  
   130          // set lakeFS token expiration if provided by the configuration
   131          this.lakeFSTokenTTLSeconds.ifPresent(req::setTokenExpirationDuration);
   132  
   133          // set identity request
   134          IdentityRequestRequestWrapper t = new IdentityRequestRequestWrapper(identityToken);
   135          req.setIdentityRequest(t);
   136  
   137          // call lakeFS to exchange the identity token for a lakeFS token
   138          AuthApi auth = new AuthApi(this.lakeFSApi);
   139          this.lakeFSAuthToken = auth.externalPrincipalLogin().externalLoginInformation(req).execute();
   140      }
   141  
   142      // refresh can be called to create a new token regardless if the current token is expired or not or does not exist.
   143      @Override
   144      public void refresh() {
   145          synchronized (this) {
   146              try {
   147                  newToken();
   148              } catch (Exception e) {
   149                  throw new RuntimeException("Failed to refresh token", e);
   150              }
   151          }
   152      }
   153  }