github.com/rudderlabs/rudder-go-kit@v0.30.0/filemanager/s3manager.go (about) 1 package filemanager 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net/url" 8 "os" 9 "path" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/aws/aws-sdk-go/aws" 15 "github.com/aws/aws-sdk-go/aws/session" 16 "github.com/aws/aws-sdk-go/service/s3" 17 awsS3Manager "github.com/aws/aws-sdk-go/service/s3/s3manager" 18 "github.com/mitchellh/mapstructure" 19 "github.com/samber/lo" 20 21 "github.com/rudderlabs/rudder-go-kit/awsutil" 22 appConfig "github.com/rudderlabs/rudder-go-kit/config" 23 "github.com/rudderlabs/rudder-go-kit/logger" 24 ) 25 26 type S3Config struct { 27 Bucket string `mapstructure:"bucketName"` 28 Prefix string `mapstructure:"Prefix"` 29 Region *string `mapstructure:"region"` 30 Endpoint *string `mapstructure:"endpoint"` 31 S3ForcePathStyle *bool `mapstructure:"s3ForcePathStyle"` 32 DisableSSL *bool `mapstructure:"disableSSL"` 33 EnableSSE bool `mapstructure:"enableSSE"` 34 RegionHint string `mapstructure:"regionHint"` 35 UseGlue bool `mapstructure:"useGlue"` 36 } 37 38 // NewS3Manager creates a new file manager for S3 39 func NewS3Manager( 40 config map[string]interface{}, log logger.Logger, defaultTimeout func() time.Duration, 41 ) (*S3Manager, error) { 42 var s3Config S3Config 43 if err := mapstructure.Decode(config, &s3Config); err != nil { 44 return nil, err 45 } 46 47 sessionConfig, err := awsutil.NewSimpleSessionConfig(config, s3.ServiceName) 48 if err != nil { 49 return nil, err 50 } 51 52 s3Config.RegionHint = appConfig.GetString("AWS_S3_REGION_HINT", "us-east-1") 53 54 return &S3Manager{ 55 baseManager: &baseManager{ 56 logger: log, 57 defaultTimeout: defaultTimeout, 58 }, 59 config: &s3Config, 60 sessionConfig: sessionConfig, 61 }, nil 62 } 63 64 func (m *S3Manager) ListFilesWithPrefix(ctx context.Context, startAfter, prefix string, maxItems int64) ListSession { 65 return &s3ListSession{ 66 baseListSession: &baseListSession{ 67 ctx: ctx, 68 startAfter: startAfter, 69 prefix: prefix, 70 maxItems: maxItems, 71 }, 72 manager: m, 73 isTruncated: true, 74 } 75 } 76 77 // Download downloads a file from S3 78 func (m *S3Manager) Download(ctx context.Context, output *os.File, key string) error { 79 sess, err := m.GetSession(ctx) 80 if err != nil { 81 return fmt.Errorf("error starting S3 session: %w", err) 82 } 83 84 downloader := awsS3Manager.NewDownloader(sess) 85 86 ctx, cancel := context.WithTimeout(ctx, m.getTimeout()) 87 defer cancel() 88 89 _, err = downloader.DownloadWithContext(ctx, output, 90 &s3.GetObjectInput{ 91 Bucket: aws.String(m.config.Bucket), 92 Key: aws.String(key), 93 }) 94 if err != nil { 95 if codeErr, ok := err.(codeError); ok && codeErr.Code() == "NoSuchKey" { 96 return ErrKeyNotFound 97 } 98 return err 99 } 100 return nil 101 } 102 103 // Upload uploads a file to S3 104 func (m *S3Manager) Upload(ctx context.Context, file *os.File, prefixes ...string) (UploadedFile, error) { 105 fileName := path.Join(m.config.Prefix, path.Join(prefixes...), path.Base(file.Name())) 106 107 uploadInput := &awsS3Manager.UploadInput{ 108 ACL: aws.String("bucket-owner-full-control"), 109 Bucket: aws.String(m.config.Bucket), 110 Key: aws.String(fileName), 111 Body: file, 112 } 113 if m.config.EnableSSE { 114 uploadInput.ServerSideEncryption = aws.String("AES256") 115 } 116 117 uploadSession, err := m.GetSession(ctx) 118 if err != nil { 119 return UploadedFile{}, fmt.Errorf("error starting S3 session: %w", err) 120 } 121 s3manager := awsS3Manager.NewUploader(uploadSession) 122 123 ctx, cancel := context.WithTimeout(ctx, m.getTimeout()) 124 defer cancel() 125 126 output, err := s3manager.UploadWithContext(ctx, uploadInput) 127 if err != nil { 128 if codeErr, ok := err.(codeError); ok && codeErr.Code() == "MissingRegion" { 129 err = fmt.Errorf(fmt.Sprintf(`Bucket '%s' not found.`, m.config.Bucket)) 130 } 131 return UploadedFile{}, err 132 } 133 134 return UploadedFile{Location: output.Location, ObjectName: fileName}, err 135 } 136 137 func (m *S3Manager) Delete(ctx context.Context, keys []string) (err error) { 138 sess, err := m.GetSession(ctx) 139 if err != nil { 140 return fmt.Errorf("error starting S3 session: %w", err) 141 } 142 143 var objects []*s3.ObjectIdentifier 144 for _, key := range keys { 145 objects = append(objects, &s3.ObjectIdentifier{Key: aws.String(key)}) 146 } 147 148 svc := s3.New(sess) 149 150 batchSize := 1000 // max accepted by DeleteObjects API 151 chunks := lo.Chunk(objects, batchSize) 152 for _, chunk := range chunks { 153 input := &s3.DeleteObjectsInput{ 154 Bucket: aws.String(m.config.Bucket), 155 Delete: &s3.Delete{ 156 Objects: chunk, 157 }, 158 } 159 160 deleteCtx, cancel := context.WithTimeout(ctx, m.getTimeout()) 161 _, err := svc.DeleteObjectsWithContext(deleteCtx, input) 162 cancel() 163 164 if err != nil { 165 if codeErr, ok := err.(codeError); ok { 166 m.logger.Errorf(`Error while deleting S3 objects: %v, error code: %v`, err.Error(), codeErr.Code()) 167 } else { 168 m.logger.Errorf(`Error while deleting S3 objects: %v`, err.Error()) 169 } 170 return err 171 } 172 } 173 return nil 174 } 175 176 func (m *S3Manager) Prefix() string { 177 return m.config.Prefix 178 } 179 180 func (m *S3Manager) Bucket() string { 181 return m.config.Bucket 182 } 183 184 /* 185 GetObjectNameFromLocation gets the object name/key name from the object location url 186 187 https://bucket-name.s3.amazonaws.com/key - >> key 188 */ 189 func (m *S3Manager) GetObjectNameFromLocation(location string) (string, error) { 190 parsedUrl, err := url.Parse(location) 191 if err != nil { 192 return "", err 193 } 194 trimmedURL := strings.TrimLeft(parsedUrl.Path, "/") 195 if (m.config.S3ForcePathStyle != nil && *m.config.S3ForcePathStyle) || 196 (!strings.Contains(parsedUrl.Host, m.config.Bucket)) { 197 return strings.TrimPrefix(trimmedURL, fmt.Sprintf(`%s/`, m.config.Bucket)), nil 198 } 199 return trimmedURL, nil 200 } 201 202 func (m *S3Manager) GetDownloadKeyFromFileLocation(location string) string { 203 parsedURL, err := url.Parse(location) 204 if err != nil { 205 fmt.Println("error while parsing location url: ", err) 206 } 207 trimmedURL := strings.TrimLeft(parsedURL.Path, "/") 208 if (m.config.S3ForcePathStyle != nil && *m.config.S3ForcePathStyle) || 209 (!strings.Contains(parsedURL.Host, m.config.Bucket)) { 210 return strings.TrimPrefix(trimmedURL, fmt.Sprintf(`%s/`, m.config.Bucket)) 211 } 212 return trimmedURL 213 } 214 215 func (m *S3Manager) GetSession(ctx context.Context) (*session.Session, error) { 216 m.sessionMu.Lock() 217 defer m.sessionMu.Unlock() 218 219 if m.session != nil { 220 return m.session, nil 221 } 222 223 if m.config.Bucket == "" { 224 return nil, errors.New("no storage bucket configured to downloader") 225 } 226 227 if !m.config.UseGlue || m.config.Region == nil { 228 getRegionSession, err := session.NewSession() 229 if err != nil { 230 return nil, err 231 } 232 233 ctx, cancel := context.WithTimeout(ctx, m.getTimeout()) 234 defer cancel() 235 236 region, err := awsS3Manager.GetBucketRegion(ctx, getRegionSession, m.config.Bucket, m.config.RegionHint) 237 if err != nil { 238 m.logger.Errorf("Failed to fetch AWS region for bucket %s. Error %v", m.config.Bucket, err) 239 // Failed to get Region probably due to VPC restrictions 240 // Will proceed to try with AccessKeyID and AccessKey 241 } 242 m.config.Region = aws.String(region) 243 m.sessionConfig.Region = region 244 } 245 246 var err error 247 m.session, err = awsutil.CreateSession(m.sessionConfig) 248 if err != nil { 249 return nil, err 250 } 251 return m.session, err 252 } 253 254 type S3Manager struct { 255 *baseManager 256 config *S3Config 257 258 sessionConfig *awsutil.SessionConfig 259 session *session.Session 260 sessionMu sync.Mutex 261 } 262 263 func (m *S3Manager) getTimeout() time.Duration { 264 if m.timeout > 0 { 265 return m.timeout 266 } 267 if m.defaultTimeout != nil { 268 return m.defaultTimeout() 269 } 270 return defaultTimeout 271 } 272 273 type s3ListSession struct { 274 *baseListSession 275 manager *S3Manager 276 277 continuationToken *string 278 isTruncated bool 279 } 280 281 func (l *s3ListSession) Next() (fileObjects []*FileInfo, err error) { 282 manager := l.manager 283 if !l.isTruncated { 284 manager.logger.Infof("Manager is truncated: %v so returning here", l.isTruncated) 285 return 286 } 287 fileObjects = make([]*FileInfo, 0) 288 289 sess, err := manager.GetSession(l.ctx) 290 if err != nil { 291 return []*FileInfo{}, fmt.Errorf("error starting S3 session: %w", err) 292 } 293 // Create S3 service client 294 svc := s3.New(sess) 295 listObjectsV2Input := s3.ListObjectsV2Input{ 296 Bucket: aws.String(manager.config.Bucket), 297 Prefix: aws.String(l.prefix), 298 MaxKeys: &l.maxItems, 299 // Delimiter: aws.String("/"), 300 } 301 // startAfter is to resume a paused task. 302 if l.startAfter != "" { 303 listObjectsV2Input.StartAfter = aws.String(l.startAfter) 304 } 305 306 if l.continuationToken != nil { 307 listObjectsV2Input.ContinuationToken = l.continuationToken 308 } 309 310 ctx, cancel := context.WithTimeout(l.ctx, manager.getTimeout()) 311 defer cancel() 312 313 // Get the list of items 314 resp, err := svc.ListObjectsV2WithContext(ctx, &listObjectsV2Input) 315 if err != nil { 316 manager.logger.Errorf("Error while listing S3 objects: %v", err) 317 return 318 } 319 if resp.IsTruncated != nil { 320 l.isTruncated = *resp.IsTruncated 321 } 322 l.continuationToken = resp.NextContinuationToken 323 for _, item := range resp.Contents { 324 fileObjects = append(fileObjects, &FileInfo{*item.Key, *item.LastModified}) 325 } 326 return 327 } 328 329 type codeError interface { 330 Code() string 331 }