vitess.io/vitess@v0.16.2/go/vt/mysqlctl/s3backupstorage/s3.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package s3backupstorage implements the BackupStorage interface for AWS S3.
    18  //
    19  // AWS access credentials are configured via standard AWS means, such as:
    20  // - AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables
    21  // - credentials file at ~/.aws/credentials
    22  // - if running on an EC2 instance, an IAM role
    23  // See details at http://blogs.aws.amazon.com/security/post/Tx3D6U6WSFGOK2H/A-New-and-Standardized-Way-to-Manage-Credentials-in-the-AWS-SDKs
    24  package s3backupstorage
    25  
    26  import (
    27  	"context"
    28  	"crypto/md5"
    29  	"crypto/tls"
    30  	"encoding/base64"
    31  	"fmt"
    32  	"io"
    33  	"math"
    34  	"net/http"
    35  	"os"
    36  	"sort"
    37  	"strings"
    38  	"sync"
    39  
    40  	"github.com/aws/aws-sdk-go/aws"
    41  	"github.com/aws/aws-sdk-go/aws/client"
    42  	"github.com/aws/aws-sdk-go/aws/request"
    43  	"github.com/aws/aws-sdk-go/aws/session"
    44  	"github.com/aws/aws-sdk-go/service/s3"
    45  	"github.com/aws/aws-sdk-go/service/s3/s3iface"
    46  	"github.com/aws/aws-sdk-go/service/s3/s3manager"
    47  	"github.com/spf13/pflag"
    48  
    49  	"vitess.io/vitess/go/vt/concurrency"
    50  	"vitess.io/vitess/go/vt/log"
    51  	"vitess.io/vitess/go/vt/mysqlctl/backupstorage"
    52  	"vitess.io/vitess/go/vt/servenv"
    53  )
    54  
var (
	// region is the AWS API region, set by --s3_backup_aws_region.
	region string

	// retryCount is the maximum number of AWS request retries, set by
	// --s3_backup_aws_retries. A negative value (the default) leaves the
	// SDK's own retry behaviour untouched; see client().
	retryCount int

	// endpoint is the AWS endpoint, defaults to amazonaws.com but appliances may use a different location.
	endpoint string

	// bucket is where the backups will go.
	bucket string

	// root is a prefix added to all object names.
	root string

	// forcePath is used to ensure that the certificate and path used match the endpoint + region.
	forcePath bool

	// tlsSkipVerifyCert disables TLS certificate verification for the
	// HTTP client used to talk to S3 (it is passed as InsecureSkipVerify).
	tlsSkipVerifyCert bool

	// requiredLogLevel is the name of the AWS SDK log level to use. It is
	// looked up in logNameMap by getLogLevel; unknown names fall back to LogOff.
	requiredLogLevel string

	// sse is the server-side encryption algorithm used when storing this object in S3.
	// A value starting with sseCustomerPrefix names a file holding a customer-provided key.
	sse string

	// delimiter is the path component delimiter used when building object names.
	delimiter = "/"
)
    85  
// registerFlags binds every S3 backup storage command-line flag to its
// package-level variable on the given flag set.
func registerFlags(fs *pflag.FlagSet) {
	fs.StringVar(&region, "s3_backup_aws_region", "us-east-1", "AWS region to use.")
	fs.IntVar(&retryCount, "s3_backup_aws_retries", -1, "AWS request retries.")
	fs.StringVar(&endpoint, "s3_backup_aws_endpoint", "", "endpoint of the S3 backend (region must be provided).")
	fs.StringVar(&bucket, "s3_backup_storage_bucket", "", "S3 bucket to use for backups.")
	fs.StringVar(&root, "s3_backup_storage_root", "", "root prefix for all backup-related object names.")
	fs.BoolVar(&forcePath, "s3_backup_force_path_style", false, "force the s3 path style.")
	fs.BoolVar(&tlsSkipVerifyCert, "s3_backup_tls_skip_verify_cert", false, "skip the 'certificate is valid' check for SSL connections.")
	fs.StringVar(&requiredLogLevel, "s3_backup_log_level", "LogOff", "determine the S3 loglevel to use from LogOff, LogDebug, LogDebugWithSigning, LogDebugWithHTTPBody, LogDebugWithRequestRetries, LogDebugWithRequestErrors.")
	fs.StringVar(&sse, "s3_backup_server_side_encryption", "", "server-side encryption algorithm (e.g., AES256, aws:kms, sse_c:/path/to/key/file).")
}
    97  
    98  func init() {
    99  	servenv.OnParseFor("vtbackup", registerFlags)
   100  	servenv.OnParseFor("vtctl", registerFlags)
   101  	servenv.OnParseFor("vtctld", registerFlags)
   102  	servenv.OnParseFor("vttablet", registerFlags)
   103  }
   104  
// logNameToLogLevel maps a log level name, as accepted by the
// --s3_backup_log_level flag, to the corresponding AWS SDK log level.
type logNameToLogLevel map[string]aws.LogLevelType

// logNameMap holds the supported log level names; it is consulted by getLogLevel.
var logNameMap logNameToLogLevel

// sseCustomerPrefix marks an --s3_backup_server_side_encryption value as a
// customer-provided key; the remainder of the value is the path to the key file.
const sseCustomerPrefix = "sse_c:"
   110  
// S3BackupHandle implements the backupstorage.BackupHandle interface.
type S3BackupHandle struct {
	// client is the S3 API used for uploads and downloads.
	client s3iface.S3API
	// bs is the storage that created this handle; it supplies the
	// server-side encryption settings and performs RemoveBackup on abort.
	bs *S3BackupStorage
	// dir and name identify the backup; they form the object name prefix.
	dir  string
	name string
	// readOnly is true for handles returned by ListBackups and false for
	// handles returned by StartBackup.
	readOnly bool
	// errors collects errors reported by the background upload goroutines.
	errors concurrency.AllErrorRecorder
	// waitGroup tracks the upload goroutines started by AddFile.
	waitGroup sync.WaitGroup
}
   121  
// Directory is part of the backupstorage.BackupHandle interface.
// It returns the directory this handle was created with.
func (bh *S3BackupHandle) Directory() string {
	return bh.dir
}
   126  
// Name is part of the backupstorage.BackupHandle interface.
// It returns the backup name this handle was created with.
func (bh *S3BackupHandle) Name() string {
	return bh.name
}
   131  
// RecordError is part of the concurrency.ErrorRecorder interface.
// It forwards err to the handle's internal error recorder.
func (bh *S3BackupHandle) RecordError(err error) {
	bh.errors.RecordError(err)
}
   136  
// HasErrors is part of the concurrency.ErrorRecorder interface.
// It reports whether the handle's internal error recorder holds any error.
func (bh *S3BackupHandle) HasErrors() bool {
	return bh.errors.HasErrors()
}
   141  
// Error is part of the concurrency.ErrorRecorder interface.
// It returns whatever the handle's internal error recorder reports.
func (bh *S3BackupHandle) Error() error {
	return bh.errors.Error()
}
   146  
   147  // AddFile is part of the backupstorage.BackupHandle interface.
   148  func (bh *S3BackupHandle) AddFile(ctx context.Context, filename string, filesize int64) (io.WriteCloser, error) {
   149  	if bh.readOnly {
   150  		return nil, fmt.Errorf("AddFile cannot be called on read-only backup")
   151  	}
   152  
   153  	// Calculate s3 upload part size using the source filesize
   154  	partSizeBytes := s3manager.DefaultUploadPartSize
   155  	if filesize > 0 {
   156  		minimumPartSize := float64(filesize) / float64(s3manager.MaxUploadParts)
   157  		// Round up to ensure large enough partsize
   158  		calculatedPartSizeBytes := int64(math.Ceil(minimumPartSize))
   159  		if calculatedPartSizeBytes > partSizeBytes {
   160  			partSizeBytes = calculatedPartSizeBytes
   161  		}
   162  	}
   163  
   164  	reader, writer := io.Pipe()
   165  	bh.waitGroup.Add(1)
   166  
   167  	go func() {
   168  		defer bh.waitGroup.Done()
   169  		uploader := s3manager.NewUploaderWithClient(bh.client, func(u *s3manager.Uploader) {
   170  			u.PartSize = partSizeBytes
   171  		})
   172  		object := objName(bh.dir, bh.name, filename)
   173  
   174  		_, err := uploader.Upload(&s3manager.UploadInput{
   175  			Bucket:               &bucket,
   176  			Key:                  object,
   177  			Body:                 reader,
   178  			ServerSideEncryption: bh.bs.s3SSE.awsAlg,
   179  			SSECustomerAlgorithm: bh.bs.s3SSE.customerAlg,
   180  			SSECustomerKey:       bh.bs.s3SSE.customerKey,
   181  			SSECustomerKeyMD5:    bh.bs.s3SSE.customerMd5,
   182  		})
   183  		if err != nil {
   184  			reader.CloseWithError(err)
   185  			bh.RecordError(err)
   186  		}
   187  	}()
   188  
   189  	return writer, nil
   190  }
   191  
   192  // EndBackup is part of the backupstorage.BackupHandle interface.
   193  func (bh *S3BackupHandle) EndBackup(ctx context.Context) error {
   194  	if bh.readOnly {
   195  		return fmt.Errorf("EndBackup cannot be called on read-only backup")
   196  	}
   197  	bh.waitGroup.Wait()
   198  	return bh.Error()
   199  }
   200  
   201  // AbortBackup is part of the backupstorage.BackupHandle interface.
   202  func (bh *S3BackupHandle) AbortBackup(ctx context.Context) error {
   203  	if bh.readOnly {
   204  		return fmt.Errorf("AbortBackup cannot be called on read-only backup")
   205  	}
   206  	return bh.bs.RemoveBackup(ctx, bh.dir, bh.name)
   207  }
   208  
   209  // ReadFile is part of the backupstorage.BackupHandle interface.
   210  func (bh *S3BackupHandle) ReadFile(ctx context.Context, filename string) (io.ReadCloser, error) {
   211  	if !bh.readOnly {
   212  		return nil, fmt.Errorf("ReadFile cannot be called on read-write backup")
   213  	}
   214  	object := objName(bh.dir, bh.name, filename)
   215  	out, err := bh.client.GetObject(&s3.GetObjectInput{
   216  		Bucket:               &bucket,
   217  		Key:                  object,
   218  		SSECustomerAlgorithm: bh.bs.s3SSE.customerAlg,
   219  		SSECustomerKey:       bh.bs.s3SSE.customerKey,
   220  		SSECustomerKeyMD5:    bh.bs.s3SSE.customerMd5,
   221  	})
   222  	if err != nil {
   223  		return nil, err
   224  	}
   225  	return out.Body, nil
   226  }
   227  
// Compile-time check that *S3BackupHandle satisfies backupstorage.BackupHandle.
var _ backupstorage.BackupHandle = (*S3BackupHandle)(nil)
   229  
// S3ServerSideEncryption holds the server-side encryption parameters that
// are attached to S3 requests. All fields are nil when no encryption is
// configured.
type S3ServerSideEncryption struct {
	// awsAlg is the AWS-managed encryption algorithm, taken verbatim from
	// the sse flag when it does not use the customer-key prefix.
	awsAlg *string
	// customerAlg, customerKey and customerMd5 describe a customer-provided
	// key: the algorithm, the raw key, and the base64-encoded MD5 of the key.
	customerAlg *string
	customerKey *string
	customerMd5 *string
}
   236  
   237  func (s3ServerSideEncryption *S3ServerSideEncryption) init() error {
   238  	s3ServerSideEncryption.reset()
   239  
   240  	if strings.HasPrefix(sse, sseCustomerPrefix) {
   241  		sseCustomerKeyFile := strings.TrimPrefix(sse, sseCustomerPrefix)
   242  		base64CodedKey, err := os.ReadFile(sseCustomerKeyFile)
   243  		if err != nil {
   244  			log.Errorf(err.Error())
   245  			return err
   246  		}
   247  
   248  		decodedKey, err := base64.StdEncoding.DecodeString(string(base64CodedKey))
   249  		if err != nil {
   250  			decodedKey = base64CodedKey
   251  		}
   252  
   253  		md5Hash := md5.Sum(decodedKey)
   254  		s3ServerSideEncryption.customerAlg = aws.String("AES256")
   255  		s3ServerSideEncryption.customerKey = aws.String(string(decodedKey))
   256  		s3ServerSideEncryption.customerMd5 = aws.String(base64.StdEncoding.EncodeToString(md5Hash[:]))
   257  	} else if sse != "" {
   258  		s3ServerSideEncryption.awsAlg = &sse
   259  	}
   260  	return nil
   261  }
   262  
   263  func (s3ServerSideEncryption *S3ServerSideEncryption) reset() {
   264  	s3ServerSideEncryption.awsAlg = nil
   265  	s3ServerSideEncryption.customerAlg = nil
   266  	s3ServerSideEncryption.customerKey = nil
   267  	s3ServerSideEncryption.customerMd5 = nil
   268  }
   269  
// S3BackupStorage implements the backupstorage.BackupStorage interface.
type S3BackupStorage struct {
	// _client is the cached S3 client. It is created on first use by
	// client() and dropped by Close(); access it only through client().
	_client *s3.S3
	// mu is held by client() and Close() while they touch _client and s3SSE.
	mu sync.Mutex
	// s3SSE holds the server-side encryption parameters applied to requests.
	s3SSE S3ServerSideEncryption
}
   276  
   277  // ListBackups is part of the backupstorage.BackupStorage interface.
   278  func (bs *S3BackupStorage) ListBackups(ctx context.Context, dir string) ([]backupstorage.BackupHandle, error) {
   279  	log.Infof("ListBackups: [s3] dir: %v, bucket: %v", dir, bucket)
   280  	c, err := bs.client()
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  
   285  	var searchPrefix *string
   286  	if dir == "/" {
   287  		searchPrefix = objName("")
   288  	} else {
   289  		searchPrefix = objName(dir, "")
   290  	}
   291  	log.Infof("objName: %v", *searchPrefix)
   292  
   293  	query := &s3.ListObjectsV2Input{
   294  		Bucket:    &bucket,
   295  		Delimiter: &delimiter,
   296  		Prefix:    searchPrefix,
   297  	}
   298  
   299  	var subdirs []string
   300  	for {
   301  		objs, err := c.ListObjectsV2(query)
   302  		if err != nil {
   303  			return nil, err
   304  		}
   305  		for _, prefix := range objs.CommonPrefixes {
   306  			subdir := strings.TrimPrefix(*prefix.Prefix, *searchPrefix)
   307  			subdir = strings.TrimSuffix(subdir, delimiter)
   308  			subdirs = append(subdirs, subdir)
   309  		}
   310  
   311  		if objs.NextContinuationToken == nil {
   312  			break
   313  		}
   314  		query.ContinuationToken = objs.NextContinuationToken
   315  	}
   316  
   317  	// Backups must be returned in order, oldest first.
   318  	sort.Strings(subdirs)
   319  
   320  	result := make([]backupstorage.BackupHandle, 0, len(subdirs))
   321  	for _, subdir := range subdirs {
   322  		result = append(result, &S3BackupHandle{
   323  			client:   c,
   324  			bs:       bs,
   325  			dir:      dir,
   326  			name:     subdir,
   327  			readOnly: true,
   328  		})
   329  	}
   330  	return result, nil
   331  }
   332  
   333  // StartBackup is part of the backupstorage.BackupStorage interface.
   334  func (bs *S3BackupStorage) StartBackup(ctx context.Context, dir, name string) (backupstorage.BackupHandle, error) {
   335  	log.Infof("StartBackup: [s3] dir: %v, name: %v, bucket: %v", dir, name, bucket)
   336  	c, err := bs.client()
   337  	if err != nil {
   338  		return nil, err
   339  	}
   340  
   341  	return &S3BackupHandle{
   342  		client:   c,
   343  		bs:       bs,
   344  		dir:      dir,
   345  		name:     name,
   346  		readOnly: false,
   347  	}, nil
   348  }
   349  
   350  // RemoveBackup is part of the backupstorage.BackupStorage interface.
   351  func (bs *S3BackupStorage) RemoveBackup(ctx context.Context, dir, name string) error {
   352  	log.Infof("RemoveBackup: [s3] dir: %v, name: %v, bucket: %v", dir, name, bucket)
   353  
   354  	c, err := bs.client()
   355  	if err != nil {
   356  		return err
   357  	}
   358  
   359  	query := &s3.ListObjectsV2Input{
   360  		Bucket: &bucket,
   361  		Prefix: objName(dir, name),
   362  	}
   363  
   364  	for {
   365  		objs, err := c.ListObjectsV2(query)
   366  		if err != nil {
   367  			return err
   368  		}
   369  
   370  		objIds := make([]*s3.ObjectIdentifier, 0, len(objs.Contents))
   371  		for _, obj := range objs.Contents {
   372  			objIds = append(objIds, &s3.ObjectIdentifier{
   373  				Key: obj.Key,
   374  			})
   375  		}
   376  
   377  		quiet := true // return less in the Delete response
   378  		out, err := c.DeleteObjects(&s3.DeleteObjectsInput{
   379  			Bucket: &bucket,
   380  			Delete: &s3.Delete{
   381  				Objects: objIds,
   382  				Quiet:   &quiet,
   383  			},
   384  		})
   385  
   386  		if err != nil {
   387  			return err
   388  		}
   389  
   390  		for _, objError := range out.Errors {
   391  			return fmt.Errorf(objError.String())
   392  		}
   393  
   394  		if objs.NextContinuationToken == nil {
   395  			break
   396  		}
   397  
   398  		query.ContinuationToken = objs.NextContinuationToken
   399  	}
   400  
   401  	return nil
   402  }
   403  
   404  // Close is part of the backupstorage.BackupStorage interface.
   405  func (bs *S3BackupStorage) Close() error {
   406  	bs.mu.Lock()
   407  	defer bs.mu.Unlock()
   408  	bs._client = nil
   409  	bs.s3SSE.reset()
   410  	return nil
   411  }
   412  
// Compile-time check that *S3BackupStorage satisfies backupstorage.BackupStorage.
var _ backupstorage.BackupStorage = (*S3BackupStorage)(nil)
   414  
   415  // getLogLevel converts the string loglevel to an aws.LogLevelType
   416  func getLogLevel() *aws.LogLevelType {
   417  	l := new(aws.LogLevelType)
   418  	*l = aws.LogOff // default setting
   419  	if level, found := logNameMap[requiredLogLevel]; found {
   420  		*l = level // adjust as required
   421  	}
   422  	return l
   423  }
   424  
   425  func (bs *S3BackupStorage) client() (*s3.S3, error) {
   426  	bs.mu.Lock()
   427  	defer bs.mu.Unlock()
   428  	if bs._client == nil {
   429  		logLevel := getLogLevel()
   430  
   431  		tlsClientConf := &tls.Config{InsecureSkipVerify: tlsSkipVerifyCert}
   432  		httpTransport := &http.Transport{TLSClientConfig: tlsClientConf}
   433  		httpClient := &http.Client{Transport: httpTransport}
   434  
   435  		session, err := session.NewSession()
   436  		if err != nil {
   437  			return nil, err
   438  		}
   439  
   440  		awsConfig := aws.Config{
   441  			HTTPClient:       httpClient,
   442  			LogLevel:         logLevel,
   443  			Endpoint:         aws.String(endpoint),
   444  			Region:           aws.String(region),
   445  			S3ForcePathStyle: aws.Bool(forcePath),
   446  		}
   447  
   448  		if retryCount >= 0 {
   449  			awsConfig = *request.WithRetryer(&awsConfig, &ClosedConnectionRetryer{
   450  				awsRetryer: &client.DefaultRetryer{
   451  					NumMaxRetries: retryCount,
   452  				},
   453  			})
   454  		}
   455  
   456  		bs._client = s3.New(session, &awsConfig)
   457  
   458  		if len(bucket) == 0 {
   459  			return nil, fmt.Errorf("--s3_backup_storage_bucket required")
   460  		}
   461  
   462  		if _, err := bs._client.HeadBucket(&s3.HeadBucketInput{Bucket: &bucket}); err != nil {
   463  			return nil, err
   464  		}
   465  
   466  		if err := bs.s3SSE.init(); err != nil {
   467  			return nil, err
   468  		}
   469  	}
   470  	return bs._client, nil
   471  }
   472  
   473  func objName(parts ...string) *string {
   474  	res := ""
   475  	if root != "" {
   476  		res += root + delimiter
   477  	}
   478  	res += strings.Join(parts, delimiter)
   479  	return &res
   480  }
   481  
   482  func init() {
   483  	backupstorage.BackupStorageMap["s3"] = &S3BackupStorage{}
   484  
   485  	logNameMap = logNameToLogLevel{
   486  		"LogOff":                     aws.LogOff,
   487  		"LogDebug":                   aws.LogDebug,
   488  		"LogDebugWithSigning":        aws.LogDebugWithSigning,
   489  		"LogDebugWithHTTPBody":       aws.LogDebugWithHTTPBody,
   490  		"LogDebugWithRequestRetries": aws.LogDebugWithRequestRetries,
   491  		"LogDebugWithRequestErrors":  aws.LogDebugWithRequestErrors,
   492  	}
   493  }