github.com/uber/kraken@v0.1.4/lib/backend/s3backend/client.go (about)

     1  // Copyright (c) 2016-2019 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package s3backend
    15  
    16  import (
    17  	"errors"
    18  	"fmt"
    19  	"io"
    20  	"path"
    21  
    22  	"github.com/uber/kraken/core"
    23  	"github.com/uber/kraken/lib/backend"
    24  	"github.com/uber/kraken/lib/backend/backenderrors"
    25  	"github.com/uber/kraken/lib/backend/namepath"
    26  	"github.com/uber/kraken/utils/log"
    27  	"github.com/uber/kraken/utils/rwutil"
    28  
    29  	"github.com/aws/aws-sdk-go/aws"
    30  	"github.com/aws/aws-sdk-go/aws/awserr"
    31  	"github.com/aws/aws-sdk-go/aws/credentials"
    32  	"github.com/aws/aws-sdk-go/aws/session"
    33  	"github.com/aws/aws-sdk-go/service/s3"
    34  	"github.com/aws/aws-sdk-go/service/s3/s3manager"
    35  	"gopkg.in/yaml.v2"
    36  )
    37  
    38  const _s3 = "s3"
    39  
    40  func init() {
    41  	backend.Register(_s3, &factory{})
    42  }
    43  
    44  type factory struct{}
    45  
    46  func (f *factory) Create(
    47  	confRaw interface{}, authConfRaw interface{}) (backend.Client, error) {
    48  
    49  	confBytes, err := yaml.Marshal(confRaw)
    50  	if err != nil {
    51  		return nil, errors.New("marshal s3 config")
    52  	}
    53  	authConfBytes, err := yaml.Marshal(authConfRaw)
    54  	if err != nil {
    55  		return nil, errors.New("marshal s3 auth config")
    56  	}
    57  
    58  	var config Config
    59  	if err := yaml.Unmarshal(confBytes, &config); err != nil {
    60  		return nil, errors.New("unmarshal s3 config")
    61  	}
    62  	var userAuth UserAuthConfig
    63  	if err := yaml.Unmarshal(authConfBytes, &userAuth); err != nil {
    64  		return nil, errors.New("unmarshal s3 auth config")
    65  	}
    66  
    67  	return NewClient(config, userAuth)
    68  }
    69  
    70  // Client implements a backend.Client for S3.
    71  type Client struct {
    72  	config Config
    73  	pather namepath.Pather
    74  	s3     S3
    75  }
    76  
    77  // Option allows setting optional Client parameters.
    78  type Option func(*Client)
    79  
    80  // WithS3 configures a Client with a custom S3 implementation.
    81  func WithS3(s3 S3) Option {
    82  	return func(c *Client) { c.s3 = s3 }
    83  }
    84  
    85  // NewClient creates a new Client for S3.
    86  func NewClient(
    87  	config Config, userAuth UserAuthConfig, opts ...Option) (*Client, error) {
    88  
    89  	config.applyDefaults()
    90  	if config.Username == "" {
    91  		return nil, errors.New("invalid config: username required")
    92  	}
    93  	if config.Region == "" {
    94  		return nil, errors.New("invalid config: region required")
    95  	}
    96  	if config.Bucket == "" {
    97  		return nil, errors.New("invalid config: bucket required")
    98  	}
    99  	if !path.IsAbs(config.RootDirectory) {
   100  		return nil, errors.New("invalid config: root_directory must be absolute path")
   101  	}
   102  
   103  	pather, err := namepath.New(config.RootDirectory, config.NamePath)
   104  	if err != nil {
   105  		return nil, fmt.Errorf("namepath: %s", err)
   106  	}
   107  
   108  	auth, ok := userAuth[config.Username]
   109  	if !ok {
   110  		return nil, errors.New("auth not configured for username")
   111  	}
   112  	creds := credentials.NewStaticCredentials(
   113  		auth.S3.AccessKeyID, auth.S3.AccessSecretKey, auth.S3.SessionToken)
   114  
   115  	awsConfig := aws.NewConfig().WithRegion(config.Region).WithCredentials(creds)
   116  
   117  	if config.Endpoint != "" {
   118  		awsConfig = awsConfig.WithEndpoint(config.Endpoint)
   119  	}
   120  
   121  	if config.DisableSSL {
   122  		awsConfig = awsConfig.WithDisableSSL(config.DisableSSL)
   123  	}
   124  
   125  	if config.S3ForcePathStyle {
   126  		awsConfig = awsConfig.WithS3ForcePathStyle(config.S3ForcePathStyle)
   127  	}
   128  
   129  	api := s3.New(session.New(), awsConfig)
   130  
   131  	downloader := s3manager.NewDownloaderWithClient(api, func(d *s3manager.Downloader) {
   132  		d.PartSize = config.DownloadPartSize
   133  		d.Concurrency = config.DownloadConcurrency
   134  	})
   135  
   136  	uploader := s3manager.NewUploaderWithClient(api, func(u *s3manager.Uploader) {
   137  		u.PartSize = config.UploadPartSize
   138  		u.Concurrency = config.UploadConcurrency
   139  	})
   140  
   141  	client := &Client{config, pather, join{api, downloader, uploader}}
   142  	for _, opt := range opts {
   143  		opt(client)
   144  	}
   145  	return client, nil
   146  }
   147  
   148  // Stat returns blob info for name.
   149  func (c *Client) Stat(namespace, name string) (*core.BlobInfo, error) {
   150  	path, err := c.pather.BlobPath(name)
   151  	if err != nil {
   152  		return nil, fmt.Errorf("blob path: %s", err)
   153  	}
   154  	output, err := c.s3.HeadObject(&s3.HeadObjectInput{
   155  		Bucket: aws.String(c.config.Bucket),
   156  		Key:    aws.String(path),
   157  	})
   158  	if err != nil {
   159  		if isNotFound(err) {
   160  			return nil, backenderrors.ErrBlobNotFound
   161  		}
   162  		return nil, err
   163  	}
   164  	var size int64
   165  	if output.ContentLength != nil {
   166  		size = *output.ContentLength
   167  	}
   168  	return core.NewBlobInfo(size), nil
   169  }
   170  
   171  // Download downloads the content from a configured bucket and writes the
   172  // data to dst.
   173  func (c *Client) Download(namespace, name string, dst io.Writer) error {
   174  	path, err := c.pather.BlobPath(name)
   175  	if err != nil {
   176  		return fmt.Errorf("blob path: %s", err)
   177  	}
   178  
   179  	// The S3 download API uses io.WriterAt to perform concurrent chunked download.
   180  	// We attempt to upcast dst to io.WriterAt for this purpose, else we download into
   181  	// in-memory buffer and drain it into dst after the download is finished.
   182  	writerAt, ok := dst.(io.WriterAt)
   183  	if !ok {
   184  		writerAt = rwutil.NewCappedBuffer(int(c.config.BufferGuard))
   185  	}
   186  
   187  	input := &s3.GetObjectInput{
   188  		Bucket: aws.String(c.config.Bucket),
   189  		Key:    aws.String(path),
   190  	}
   191  	if _, err := c.s3.Download(writerAt, input); err != nil {
   192  		if isNotFound(err) {
   193  			return backenderrors.ErrBlobNotFound
   194  		}
   195  		return err
   196  	}
   197  
   198  	if capBuf, ok := writerAt.(*rwutil.CappedBuffer); ok {
   199  		if err = capBuf.DrainInto(dst); err != nil {
   200  			return err
   201  		}
   202  	}
   203  
   204  	return nil
   205  }
   206  
   207  // Upload uploads src to a configured bucket.
   208  func (c *Client) Upload(namespace, name string, src io.Reader) error {
   209  	path, err := c.pather.BlobPath(name)
   210  	if err != nil {
   211  		return fmt.Errorf("blob path: %s", err)
   212  	}
   213  	input := &s3manager.UploadInput{
   214  		Bucket: aws.String(c.config.Bucket),
   215  		Key:    aws.String(path),
   216  		Body:   src,
   217  	}
   218  	_, err = c.s3.Upload(input, func(u *s3manager.Uploader) {
   219  		u.LeavePartsOnError = false // Delete the parts if the upload fails.
   220  	})
   221  	return err
   222  }
   223  
   224  func isNotFound(err error) bool {
   225  	awsErr, ok := err.(awserr.Error)
   226  	return ok && (awsErr.Code() == s3.ErrCodeNoSuchKey || awsErr.Code() == "NotFound")
   227  }
   228  
   229  // List lists names with start with prefix.
   230  func (c *Client) List(prefix string, opts ...backend.ListOption) (*backend.ListResult, error) {
   231  	// For whatever reason, the S3 list API does not accept an absolute path
   232  	// for prefix. Thus, the root is stripped from the input and added manually
   233  	// to each output key.
   234  	options := backend.DefaultListOptions()
   235  	for _, opt := range opts {
   236  		opt(options)
   237  	}
   238  
   239  	// If paginiated is enabled use the maximum number of keys requests from thhe options,
   240  	// otherwise fall back to the configuration's max keys
   241  	maxKeys := int64(c.config.ListMaxKeys)
   242  	var continuationToken *string
   243  	if options.Paginated {
   244  		maxKeys = int64(options.MaxKeys)
   245  		// An empty continuationToken should be left as nil when sending paginated list
   246  		// requests to s3
   247  		if options.ContinuationToken != "" {
   248  			continuationToken = aws.String(options.ContinuationToken)
   249  		}
   250  	}
   251  
   252  	var names []string
   253  	nextContinuationToken := ""
   254  	err := c.s3.ListObjectsV2Pages(&s3.ListObjectsV2Input{
   255  		Bucket:            aws.String(c.config.Bucket),
   256  		MaxKeys:           aws.Int64(maxKeys),
   257  		Prefix:            aws.String(path.Join(c.pather.BasePath(), prefix)[1:]),
   258  		ContinuationToken: continuationToken,
   259  	}, func(page *s3.ListObjectsV2Output, last bool) bool {
   260  		for _, object := range page.Contents {
   261  			if object.Key == nil {
   262  				log.With(
   263  					"prefix", prefix,
   264  					"object", object).Error("List encountered nil S3 object key")
   265  				continue
   266  			}
   267  			name, err := c.pather.NameFromBlobPath(path.Join("/", *object.Key))
   268  			if err != nil {
   269  				log.With("key", *object.Key).Errorf("Error converting blob path into name: %s", err)
   270  				continue
   271  			}
   272  			names = append(names, name)
   273  		}
   274  
   275  
   276  		if int64(len(names)) < maxKeys {
   277  			// Continue iterating pages to get more keys
   278  			return true
   279  		}
   280  
   281  		// Attempt to capture the continuation token before we stop iterating pages
   282  		if page.IsTruncated != nil && *page.IsTruncated && page.NextContinuationToken != nil {
   283  			nextContinuationToken = *page.NextContinuationToken
   284  		}
   285  
   286  		return false
   287  	})
   288  
   289  	if err != nil {
   290  		return nil, err
   291  	}
   292  
   293  	return &backend.ListResult{
   294  		Names:             names,
   295  		ContinuationToken: nextContinuationToken,
   296  	}, nil
   297  }