github.com/rudderlabs/rudder-go-kit@v0.30.0/filemanager/s3manager.go (about)

     1  package filemanager
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net/url"
     8  	"os"
     9  	"path"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/aws/aws-sdk-go/aws"
    15  	"github.com/aws/aws-sdk-go/aws/session"
    16  	"github.com/aws/aws-sdk-go/service/s3"
    17  	awsS3Manager "github.com/aws/aws-sdk-go/service/s3/s3manager"
    18  	"github.com/mitchellh/mapstructure"
    19  	"github.com/samber/lo"
    20  
    21  	"github.com/rudderlabs/rudder-go-kit/awsutil"
    22  	appConfig "github.com/rudderlabs/rudder-go-kit/config"
    23  	"github.com/rudderlabs/rudder-go-kit/logger"
    24  )
    25  
    26  type S3Config struct {
    27  	Bucket           string  `mapstructure:"bucketName"`
    28  	Prefix           string  `mapstructure:"Prefix"`
    29  	Region           *string `mapstructure:"region"`
    30  	Endpoint         *string `mapstructure:"endpoint"`
    31  	S3ForcePathStyle *bool   `mapstructure:"s3ForcePathStyle"`
    32  	DisableSSL       *bool   `mapstructure:"disableSSL"`
    33  	EnableSSE        bool    `mapstructure:"enableSSE"`
    34  	RegionHint       string  `mapstructure:"regionHint"`
    35  	UseGlue          bool    `mapstructure:"useGlue"`
    36  }
    37  
    38  // NewS3Manager creates a new file manager for S3
    39  func NewS3Manager(
    40  	config map[string]interface{}, log logger.Logger, defaultTimeout func() time.Duration,
    41  ) (*S3Manager, error) {
    42  	var s3Config S3Config
    43  	if err := mapstructure.Decode(config, &s3Config); err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	sessionConfig, err := awsutil.NewSimpleSessionConfig(config, s3.ServiceName)
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	s3Config.RegionHint = appConfig.GetString("AWS_S3_REGION_HINT", "us-east-1")
    53  
    54  	return &S3Manager{
    55  		baseManager: &baseManager{
    56  			logger:         log,
    57  			defaultTimeout: defaultTimeout,
    58  		},
    59  		config:        &s3Config,
    60  		sessionConfig: sessionConfig,
    61  	}, nil
    62  }
    63  
    64  func (m *S3Manager) ListFilesWithPrefix(ctx context.Context, startAfter, prefix string, maxItems int64) ListSession {
    65  	return &s3ListSession{
    66  		baseListSession: &baseListSession{
    67  			ctx:        ctx,
    68  			startAfter: startAfter,
    69  			prefix:     prefix,
    70  			maxItems:   maxItems,
    71  		},
    72  		manager:     m,
    73  		isTruncated: true,
    74  	}
    75  }
    76  
    77  // Download downloads a file from S3
    78  func (m *S3Manager) Download(ctx context.Context, output *os.File, key string) error {
    79  	sess, err := m.GetSession(ctx)
    80  	if err != nil {
    81  		return fmt.Errorf("error starting S3 session: %w", err)
    82  	}
    83  
    84  	downloader := awsS3Manager.NewDownloader(sess)
    85  
    86  	ctx, cancel := context.WithTimeout(ctx, m.getTimeout())
    87  	defer cancel()
    88  
    89  	_, err = downloader.DownloadWithContext(ctx, output,
    90  		&s3.GetObjectInput{
    91  			Bucket: aws.String(m.config.Bucket),
    92  			Key:    aws.String(key),
    93  		})
    94  	if err != nil {
    95  		if codeErr, ok := err.(codeError); ok && codeErr.Code() == "NoSuchKey" {
    96  			return ErrKeyNotFound
    97  		}
    98  		return err
    99  	}
   100  	return nil
   101  }
   102  
   103  // Upload uploads a file to S3
   104  func (m *S3Manager) Upload(ctx context.Context, file *os.File, prefixes ...string) (UploadedFile, error) {
   105  	fileName := path.Join(m.config.Prefix, path.Join(prefixes...), path.Base(file.Name()))
   106  
   107  	uploadInput := &awsS3Manager.UploadInput{
   108  		ACL:    aws.String("bucket-owner-full-control"),
   109  		Bucket: aws.String(m.config.Bucket),
   110  		Key:    aws.String(fileName),
   111  		Body:   file,
   112  	}
   113  	if m.config.EnableSSE {
   114  		uploadInput.ServerSideEncryption = aws.String("AES256")
   115  	}
   116  
   117  	uploadSession, err := m.GetSession(ctx)
   118  	if err != nil {
   119  		return UploadedFile{}, fmt.Errorf("error starting S3 session: %w", err)
   120  	}
   121  	s3manager := awsS3Manager.NewUploader(uploadSession)
   122  
   123  	ctx, cancel := context.WithTimeout(ctx, m.getTimeout())
   124  	defer cancel()
   125  
   126  	output, err := s3manager.UploadWithContext(ctx, uploadInput)
   127  	if err != nil {
   128  		if codeErr, ok := err.(codeError); ok && codeErr.Code() == "MissingRegion" {
   129  			err = fmt.Errorf(fmt.Sprintf(`Bucket '%s' not found.`, m.config.Bucket))
   130  		}
   131  		return UploadedFile{}, err
   132  	}
   133  
   134  	return UploadedFile{Location: output.Location, ObjectName: fileName}, err
   135  }
   136  
   137  func (m *S3Manager) Delete(ctx context.Context, keys []string) (err error) {
   138  	sess, err := m.GetSession(ctx)
   139  	if err != nil {
   140  		return fmt.Errorf("error starting S3 session: %w", err)
   141  	}
   142  
   143  	var objects []*s3.ObjectIdentifier
   144  	for _, key := range keys {
   145  		objects = append(objects, &s3.ObjectIdentifier{Key: aws.String(key)})
   146  	}
   147  
   148  	svc := s3.New(sess)
   149  
   150  	batchSize := 1000 // max accepted by DeleteObjects API
   151  	chunks := lo.Chunk(objects, batchSize)
   152  	for _, chunk := range chunks {
   153  		input := &s3.DeleteObjectsInput{
   154  			Bucket: aws.String(m.config.Bucket),
   155  			Delete: &s3.Delete{
   156  				Objects: chunk,
   157  			},
   158  		}
   159  
   160  		deleteCtx, cancel := context.WithTimeout(ctx, m.getTimeout())
   161  		_, err := svc.DeleteObjectsWithContext(deleteCtx, input)
   162  		cancel()
   163  
   164  		if err != nil {
   165  			if codeErr, ok := err.(codeError); ok {
   166  				m.logger.Errorf(`Error while deleting S3 objects: %v, error code: %v`, err.Error(), codeErr.Code())
   167  			} else {
   168  				m.logger.Errorf(`Error while deleting S3 objects: %v`, err.Error())
   169  			}
   170  			return err
   171  		}
   172  	}
   173  	return nil
   174  }
   175  
   176  func (m *S3Manager) Prefix() string {
   177  	return m.config.Prefix
   178  }
   179  
   180  func (m *S3Manager) Bucket() string {
   181  	return m.config.Bucket
   182  }
   183  
   184  /*
   185  GetObjectNameFromLocation gets the object name/key name from the object location url
   186  
   187  	https://bucket-name.s3.amazonaws.com/key - >> key
   188  */
   189  func (m *S3Manager) GetObjectNameFromLocation(location string) (string, error) {
   190  	parsedUrl, err := url.Parse(location)
   191  	if err != nil {
   192  		return "", err
   193  	}
   194  	trimmedURL := strings.TrimLeft(parsedUrl.Path, "/")
   195  	if (m.config.S3ForcePathStyle != nil && *m.config.S3ForcePathStyle) ||
   196  		(!strings.Contains(parsedUrl.Host, m.config.Bucket)) {
   197  		return strings.TrimPrefix(trimmedURL, fmt.Sprintf(`%s/`, m.config.Bucket)), nil
   198  	}
   199  	return trimmedURL, nil
   200  }
   201  
   202  func (m *S3Manager) GetDownloadKeyFromFileLocation(location string) string {
   203  	parsedURL, err := url.Parse(location)
   204  	if err != nil {
   205  		fmt.Println("error while parsing location url: ", err)
   206  	}
   207  	trimmedURL := strings.TrimLeft(parsedURL.Path, "/")
   208  	if (m.config.S3ForcePathStyle != nil && *m.config.S3ForcePathStyle) ||
   209  		(!strings.Contains(parsedURL.Host, m.config.Bucket)) {
   210  		return strings.TrimPrefix(trimmedURL, fmt.Sprintf(`%s/`, m.config.Bucket))
   211  	}
   212  	return trimmedURL
   213  }
   214  
   215  func (m *S3Manager) GetSession(ctx context.Context) (*session.Session, error) {
   216  	m.sessionMu.Lock()
   217  	defer m.sessionMu.Unlock()
   218  
   219  	if m.session != nil {
   220  		return m.session, nil
   221  	}
   222  
   223  	if m.config.Bucket == "" {
   224  		return nil, errors.New("no storage bucket configured to downloader")
   225  	}
   226  
   227  	if !m.config.UseGlue || m.config.Region == nil {
   228  		getRegionSession, err := session.NewSession()
   229  		if err != nil {
   230  			return nil, err
   231  		}
   232  
   233  		ctx, cancel := context.WithTimeout(ctx, m.getTimeout())
   234  		defer cancel()
   235  
   236  		region, err := awsS3Manager.GetBucketRegion(ctx, getRegionSession, m.config.Bucket, m.config.RegionHint)
   237  		if err != nil {
   238  			m.logger.Errorf("Failed to fetch AWS region for bucket %s. Error %v", m.config.Bucket, err)
   239  			// Failed to get Region probably due to VPC restrictions
   240  			// Will proceed to try with AccessKeyID and AccessKey
   241  		}
   242  		m.config.Region = aws.String(region)
   243  		m.sessionConfig.Region = region
   244  	}
   245  
   246  	var err error
   247  	m.session, err = awsutil.CreateSession(m.sessionConfig)
   248  	if err != nil {
   249  		return nil, err
   250  	}
   251  	return m.session, err
   252  }
   253  
   254  type S3Manager struct {
   255  	*baseManager
   256  	config *S3Config
   257  
   258  	sessionConfig *awsutil.SessionConfig
   259  	session       *session.Session
   260  	sessionMu     sync.Mutex
   261  }
   262  
   263  func (m *S3Manager) getTimeout() time.Duration {
   264  	if m.timeout > 0 {
   265  		return m.timeout
   266  	}
   267  	if m.defaultTimeout != nil {
   268  		return m.defaultTimeout()
   269  	}
   270  	return defaultTimeout
   271  }
   272  
   273  type s3ListSession struct {
   274  	*baseListSession
   275  	manager *S3Manager
   276  
   277  	continuationToken *string
   278  	isTruncated       bool
   279  }
   280  
   281  func (l *s3ListSession) Next() (fileObjects []*FileInfo, err error) {
   282  	manager := l.manager
   283  	if !l.isTruncated {
   284  		manager.logger.Infof("Manager is truncated: %v so returning here", l.isTruncated)
   285  		return
   286  	}
   287  	fileObjects = make([]*FileInfo, 0)
   288  
   289  	sess, err := manager.GetSession(l.ctx)
   290  	if err != nil {
   291  		return []*FileInfo{}, fmt.Errorf("error starting S3 session: %w", err)
   292  	}
   293  	// Create S3 service client
   294  	svc := s3.New(sess)
   295  	listObjectsV2Input := s3.ListObjectsV2Input{
   296  		Bucket:  aws.String(manager.config.Bucket),
   297  		Prefix:  aws.String(l.prefix),
   298  		MaxKeys: &l.maxItems,
   299  		// Delimiter: aws.String("/"),
   300  	}
   301  	// startAfter is to resume a paused task.
   302  	if l.startAfter != "" {
   303  		listObjectsV2Input.StartAfter = aws.String(l.startAfter)
   304  	}
   305  
   306  	if l.continuationToken != nil {
   307  		listObjectsV2Input.ContinuationToken = l.continuationToken
   308  	}
   309  
   310  	ctx, cancel := context.WithTimeout(l.ctx, manager.getTimeout())
   311  	defer cancel()
   312  
   313  	// Get the list of items
   314  	resp, err := svc.ListObjectsV2WithContext(ctx, &listObjectsV2Input)
   315  	if err != nil {
   316  		manager.logger.Errorf("Error while listing S3 objects: %v", err)
   317  		return
   318  	}
   319  	if resp.IsTruncated != nil {
   320  		l.isTruncated = *resp.IsTruncated
   321  	}
   322  	l.continuationToken = resp.NextContinuationToken
   323  	for _, item := range resp.Contents {
   324  		fileObjects = append(fileObjects, &FileInfo{*item.Key, *item.LastModified})
   325  	}
   326  	return
   327  }
   328  
   329  type codeError interface {
   330  	Code() string
   331  }