github.com/gaukas/goofys100m@v0.24.0/api/common/conf_s3.go (about)

     1  // Copyright 2019 Databricks
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package common
    16  
    17  import (
    18  	"crypto/md5"
    19  	"encoding/base64"
    20  	"fmt"
    21  	"net/http"
    22  
    23  	"github.com/aws/aws-sdk-go/aws"
    24  	"github.com/aws/aws-sdk-go/aws/client"
    25  	"github.com/aws/aws-sdk-go/aws/credentials"
    26  	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
    27  	"github.com/aws/aws-sdk-go/aws/session"
    28  )
    29  
    30  type S3Config struct {
    31  	Profile         string
    32  	AccessKey       string
    33  	SecretKey       string
    34  	RoleArn         string
    35  	RoleExternalId  string
    36  	RoleSessionName string
    37  	StsEndpoint     string
    38  
    39  	RequesterPays bool
    40  	Region        string
    41  	RegionSet     bool
    42  
    43  	StorageClass string
    44  
    45  	UseSSE     bool
    46  	UseKMS     bool
    47  	KMSKeyID   string
    48  	SseC       string
    49  	SseCDigest string
    50  	ACL        string
    51  
    52  	Subdomain bool
    53  
    54  	Credentials *credentials.Credentials
    55  	Session     *session.Session
    56  
    57  	BucketOwner string
    58  }
    59  
// s3Session is the package-wide AWS session. ToAwsConfig creates it the first
// time it is called with a config whose Session is nil, using that config's
// Profile, and then assigns the same session to every later config that has
// no Session of its own. Nothing in this file guards it with a lock.
var s3Session *session.Session
    61  
    62  func (c *S3Config) Init() *S3Config {
    63  	if c.Region == "" {
    64  		c.Region = "us-east-1"
    65  	}
    66  	if c.StorageClass == "" {
    67  		c.StorageClass = "STANDARD"
    68  	}
    69  	return c
    70  }
    71  
    72  func (c *S3Config) ToAwsConfig(flags *FlagStorage) (*aws.Config, error) {
    73  	awsConfig := (&aws.Config{
    74  		Region: &c.Region,
    75  		Logger: GetLogger("s3"),
    76  	}).WithHTTPClient(&http.Client{
    77  		Transport: &defaultHTTPTransport,
    78  		Timeout:   flags.HTTPTimeout,
    79  	})
    80  	if flags.DebugS3 {
    81  		awsConfig.LogLevel = aws.LogLevel(aws.LogDebug | aws.LogDebugWithRequestErrors)
    82  	}
    83  
    84  	if c.Credentials == nil {
    85  		if c.AccessKey != "" {
    86  			c.Credentials = credentials.NewStaticCredentials(c.AccessKey, c.SecretKey, "")
    87  		}
    88  	}
    89  	if flags.Endpoint != "" {
    90  		awsConfig.Endpoint = &flags.Endpoint
    91  	}
    92  
    93  	awsConfig.S3ForcePathStyle = aws.Bool(!c.Subdomain)
    94  
    95  	if c.Session == nil {
    96  		if s3Session == nil {
    97  			var err error
    98  			s3Session, err = session.NewSessionWithOptions(session.Options{
    99  				Profile:           c.Profile,
   100  				SharedConfigState: session.SharedConfigEnable,
   101  			})
   102  			if err != nil {
   103  				return nil, err
   104  			}
   105  		}
   106  		c.Session = s3Session
   107  	}
   108  
   109  	if c.RoleArn != "" {
   110  		c.Credentials = stscreds.NewCredentials(stsConfigProvider{c}, c.RoleArn,
   111  			func(p *stscreds.AssumeRoleProvider) {
   112  				if c.RoleExternalId != "" {
   113  					p.ExternalID = &c.RoleExternalId
   114  				}
   115  				p.RoleSessionName = c.RoleSessionName
   116  			})
   117  	}
   118  
   119  	if c.Credentials != nil {
   120  		awsConfig.Credentials = c.Credentials
   121  	}
   122  
   123  	if c.SseC != "" {
   124  		key, err := base64.StdEncoding.DecodeString(c.SseC)
   125  		if err != nil {
   126  			return nil, fmt.Errorf("sse-c is not base64-encoded: %v", err)
   127  		}
   128  
   129  		c.SseC = string(key)
   130  		m := md5.Sum(key)
   131  		c.SseCDigest = base64.StdEncoding.EncodeToString(m[:])
   132  	}
   133  
   134  	return awsConfig, nil
   135  }
   136  
   137  type stsConfigProvider struct {
   138  	*S3Config
   139  }
   140  
   141  func (c stsConfigProvider) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config {
   142  	config := c.Session.ClientConfig(serviceName, cfgs...)
   143  	if c.Credentials != nil {
   144  		config.Config.Credentials = c.Credentials
   145  	}
   146  	if c.StsEndpoint != "" {
   147  		config.Endpoint = c.StsEndpoint
   148  	}
   149  
   150  	return config
   151  }