github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/dbfactory/aws.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dbfactory
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"net/url"
    21  	"os"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/aws/aws-sdk-go/aws"
    26  	"github.com/aws/aws-sdk-go/aws/credentials"
    27  	"github.com/aws/aws-sdk-go/aws/session"
    28  	"github.com/aws/aws-sdk-go/service/dynamodb"
    29  	"github.com/aws/aws-sdk-go/service/s3"
    30  
    31  	"github.com/dolthub/dolt/go/libraries/utils/awsrefreshcreds"
    32  	"github.com/dolthub/dolt/go/store/chunks"
    33  	"github.com/dolthub/dolt/go/store/datas"
    34  	"github.com/dolthub/dolt/go/store/nbs"
    35  	"github.com/dolthub/dolt/go/store/prolly/tree"
    36  	"github.com/dolthub/dolt/go/store/types"
    37  )
    38  
const (
	// AWSRegionParam is a creation parameter that can be used to set the AWS region.
	AWSRegionParam = "aws-region"

	// AWSCredsTypeParam is a creation parameter that can be used to set the type of credentials that should be used.
	// Valid values are role, env, auto, and file; the value is parsed with AWSCredentialSourceFromStr.
	AWSCredsTypeParam = "aws-creds-type"

	// AWSCredsFileParam is a creation parameter that can be used to specify a credential file to use.
	AWSCredsFileParam = "aws-creds-file"

	// AWSCredsProfile is a creation parameter that can be used to specify which AWS profile to use.
	AWSCredsProfile = "aws-creds-profile"
)
    53  
// AWSFileCredsRefreshDuration is the duration handed to the refreshing credentials provider that wraps
// file-based credentials in awsConfigFromParams.
// NOTE(review): presumably this is how often the credentials file is re-read; confirm against awsrefreshcreds.
var AWSFileCredsRefreshDuration = time.Minute

// AWSCredTypes holds the string names of the role, env, and file credential sources. The auto source is
// not part of this list.
var AWSCredTypes = []string{RoleCS.String(), EnvCS.String(), FileCS.String()}
    57  
// AWSCredentialSource is an enum type representing the different credential sources (auto, role, env, file, or invalid)
type AWSCredentialSource int

const (
	// InvalidCS is what AWSCredentialSourceFromStr returns for a name it does not recognize. It is defined
	// as iota - 1 (that is, -1) so that AutoCS is the zero value.
	InvalidCS AWSCredentialSource = iota - 1

	// AutoCS tries the environment credentials first, then a credentials file if one was supplied and
	// exists, and otherwise configures no explicit credentials (same as RoleCS). An empty credential
	// type string parses to AutoCS.
	AutoCS

	// RoleCS uses the AWS IAM role of the instance for auth. No explicit credentials are configured for it.
	RoleCS

	// EnvCS uses the credentials stored in the environment variables AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY
	EnvCS

	// FileCS uses credentials stored in a file
	FileCS
)
    76  
    77  // String returns the string representation of the of an AWSCredentialSource
    78  func (ct AWSCredentialSource) String() string {
    79  	switch ct {
    80  	case RoleCS:
    81  		return "role"
    82  	case EnvCS:
    83  		return "env"
    84  	case AutoCS:
    85  		return "auto"
    86  	case FileCS:
    87  		return "file"
    88  	default:
    89  		return "invalid"
    90  	}
    91  }
    92  
    93  // AWSCredentialSourceFromStr converts a string to an AWSCredentialSource
    94  func AWSCredentialSourceFromStr(str string) AWSCredentialSource {
    95  	strlwr := strings.TrimSpace(strings.ToLower(str))
    96  	switch strlwr {
    97  	case "", "auto":
    98  		return AutoCS
    99  	case "role":
   100  		return RoleCS
   101  	case "env":
   102  		return EnvCS
   103  	case "file":
   104  		return FileCS
   105  	default:
   106  		return InvalidCS
   107  	}
   108  }
   109  
   110  // AWSFactory is a DBFactory implementation for creating AWS backed databases
   111  type AWSFactory struct {
   112  }
   113  
// PrepareDB is a no-op for AWS backed databases. It ignores all of its arguments and always returns nil.
func (fact AWSFactory) PrepareDB(ctx context.Context, nbf *types.NomsBinFormat, urlObj *url.URL, params map[string]interface{}) error {
	// nothing to prepare
	return nil
}
   118  
   119  // CreateDB creates an AWS backed database
   120  func (fact AWSFactory) CreateDB(ctx context.Context, nbf *types.NomsBinFormat, urlObj *url.URL, params map[string]interface{}) (datas.Database, types.ValueReadWriter, tree.NodeStore, error) {
   121  	var db datas.Database
   122  	cs, err := fact.newChunkStore(ctx, nbf, urlObj, params)
   123  
   124  	if err != nil {
   125  		return nil, nil, nil, err
   126  	}
   127  
   128  	vrw := types.NewValueStore(cs)
   129  	ns := tree.NewNodeStore(cs)
   130  	db = datas.NewTypesDatabase(vrw, ns)
   131  
   132  	return db, vrw, ns, nil
   133  }
   134  
   135  func (fact AWSFactory) newChunkStore(ctx context.Context, nbf *types.NomsBinFormat, urlObj *url.URL, params map[string]interface{}) (chunks.ChunkStore, error) {
   136  	parts := strings.SplitN(urlObj.Hostname(), ":", 2) // [table]:[bucket]
   137  	if len(parts) != 2 {
   138  		return nil, errors.New("aws url has an invalid format")
   139  	}
   140  
   141  	opts, err := awsConfigFromParams(params)
   142  
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	dbName, err := validatePath(urlObj.Path)
   148  
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	sess := session.Must(session.NewSessionWithOptions(opts))
   154  	_, err = sess.Config.Credentials.Get()
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  
   159  	q := nbs.NewUnlimitedMemQuotaProvider()
   160  	return nbs.NewAWSStore(ctx, nbf.VersionString(), parts[0], dbName, parts[1], s3.New(sess), dynamodb.New(sess), defaultMemTableSize, q)
   161  }
   162  
   163  func validatePath(path string) (string, error) {
   164  	for len(path) > 0 && path[0] == '/' {
   165  		path = path[1:]
   166  	}
   167  
   168  	pathLen := len(path)
   169  	for pathLen > 0 && path[pathLen-1] == '/' {
   170  		path = path[:pathLen-1]
   171  		pathLen--
   172  	}
   173  
   174  	if len(path) == 0 {
   175  		return "", errors.New("invalid database name")
   176  	}
   177  
   178  	return path, nil
   179  }
   180  
   181  func awsConfigFromParams(params map[string]interface{}) (session.Options, error) {
   182  	awsConfig := aws.NewConfig()
   183  	if val, ok := params[AWSRegionParam]; ok {
   184  		awsConfig = awsConfig.WithRegion(val.(string))
   185  	}
   186  
   187  	awsCredsSource := RoleCS
   188  	if val, ok := params[AWSCredsTypeParam]; ok {
   189  		awsCredsSource = AWSCredentialSourceFromStr(val.(string))
   190  		if awsCredsSource == InvalidCS {
   191  			return session.Options{}, errors.New("invalid value for aws-creds-source")
   192  		}
   193  	}
   194  
   195  	opts := session.Options{
   196  		SharedConfigState: session.SharedConfigEnable,
   197  	}
   198  
   199  	profile := ""
   200  	if val, ok := params[AWSCredsProfile]; ok {
   201  		profile = val.(string)
   202  		opts.Profile = val.(string)
   203  	}
   204  
   205  	filePath, ok := params[AWSCredsFileParam]
   206  	if ok && len(filePath.(string)) != 0 && awsCredsSource == RoleCS {
   207  		awsCredsSource = FileCS
   208  	}
   209  
   210  	switch awsCredsSource {
   211  	case EnvCS:
   212  		awsConfig = awsConfig.WithCredentials(credentials.NewEnvCredentials())
   213  	case FileCS:
   214  		if filePath, ok := params[AWSCredsFileParam]; !ok {
   215  			return opts, os.ErrNotExist
   216  		} else {
   217  			provider := &credentials.SharedCredentialsProvider{
   218  				Filename: filePath.(string),
   219  				Profile:  profile,
   220  			}
   221  			creds := credentials.NewCredentials(awsrefreshcreds.NewRefreshingCredentialsProvider(provider, AWSFileCredsRefreshDuration))
   222  			awsConfig = awsConfig.WithCredentials(creds)
   223  		}
   224  	case AutoCS:
   225  		// start by trying to get the credentials from the environment
   226  		envCreds := credentials.NewEnvCredentials()
   227  		if _, err := envCreds.Get(); err == nil {
   228  			awsConfig = awsConfig.WithCredentials(envCreds)
   229  		} else {
   230  			// if env credentials don't exist try looking for a credentials file
   231  			if filePath, ok := params[AWSCredsFileParam]; ok {
   232  				if _, err := os.Stat(filePath.(string)); err == nil {
   233  					creds := credentials.NewSharedCredentials(filePath.(string), profile)
   234  					awsConfig = awsConfig.WithCredentials(creds)
   235  				}
   236  			}
   237  
   238  			// if file and env do not return valid credentials use the default credentials of the box (same as role)
   239  		}
   240  	case RoleCS:
   241  	default:
   242  	}
   243  
   244  	opts.Config.MergeIn(awsConfig)
   245  
   246  	return opts, nil
   247  }