github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/config/config.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // This file incorporates work covered by the following copyright and
    16  // permission notice:
    17  //
    18  // Copyright 2016 Attic Labs, Inc. All rights reserved.
    19  // Licensed under the Apache License, version 2.0:
    20  // http://www.apache.org/licenses/LICENSE-2.0
    21  
    22  package config
    23  
import (
	"bytes"
	"fmt"
	"os"
	"path/filepath"
	"sort"

	"github.com/BurntSushi/toml"

	"github.com/dolthub/dolt/go/store/spec"
)
    34  
    35  // All configuration
    36  type Config struct {
    37  	File string
    38  	Db   map[string]DbConfig
    39  	AWS  AWSConfig
    40  }
    41  
    42  // Configuration for a specific database
    43  type DbConfig struct {
    44  	Url     string
    45  	Options map[string]string
    46  }
    47  
    48  // Global AWS Config
    49  type AWSConfig struct {
    50  	Region     string
    51  	CredSource string `toml:"cred_source"`
    52  	CredFile   string `toml:"cred_file"`
    53  }
    54  
    55  const (
    56  	NomsConfigFile = ".nomsconfig"
    57  	DefaultDbAlias = "default"
    58  
    59  	awsRegionParam     = "aws_region"
    60  	awsCredSourceParam = "aws_cred_source"
    61  	awsCredFileParam   = "aws_cred_file"
    62  	authParam          = "authorization"
    63  )
    64  
    65  var ErrNoConfig = fmt.Errorf("no %s found", NomsConfigFile)
    66  
    67  // Find the closest directory containing .nomsconfig starting
    68  // in cwd and then searching up ancestor tree.
    69  // Look first looking in cwd and then up through its ancestors
    70  func FindNomsConfig() (*Config, error) {
    71  	curDir, err := os.Getwd()
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	for {
    76  		nomsConfig := filepath.Join(curDir, NomsConfigFile)
    77  		info, err := os.Stat(nomsConfig)
    78  		if err == nil && !info.IsDir() {
    79  			// found
    80  			return ReadConfig(nomsConfig)
    81  		} else if err != nil && !os.IsNotExist(err) {
    82  			// can't read
    83  			return nil, err
    84  		}
    85  		nextDir := filepath.Dir(curDir)
    86  		if nextDir == curDir {
    87  			// stop at root
    88  			return nil, ErrNoConfig
    89  		}
    90  		curDir = nextDir
    91  	}
    92  }
    93  
    94  func ReadConfig(name string) (*Config, error) {
    95  	data, err := os.ReadFile(name)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	c, err := NewConfig(string(data))
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	c.File = name
   104  	return qualifyPaths(name, c)
   105  }
   106  
   107  func NewConfig(data string) (*Config, error) {
   108  	c := new(Config)
   109  	if _, err := toml.Decode(data, c); err != nil {
   110  		return nil, err
   111  	}
   112  	return c, nil
   113  }
   114  
   115  func (c *Config) WriteTo(configHome string) (string, error) {
   116  	file := filepath.Join(configHome, NomsConfigFile)
   117  	if err := os.MkdirAll(filepath.Dir(file), os.ModePerm); err != nil {
   118  		return "", err
   119  	}
   120  	if err := os.WriteFile(file, []byte(c.writeableString()), os.ModePerm); err != nil {
   121  		return "", err
   122  	}
   123  	return file, nil
   124  }
   125  
   126  // Replace relative directory in path part of spec with an absolute
   127  // directory. Assumes the path is relative to the location of the config file
   128  func absDbSpec(configHome string, url string) string {
   129  	dbSpec, err := spec.ForDatabase(url)
   130  	if err != nil {
   131  		return url
   132  	}
   133  	if dbSpec.Protocol != "nbs" {
   134  		return url
   135  	}
   136  	dbName := dbSpec.DatabaseName
   137  	if !filepath.IsAbs(dbName) {
   138  		dbName = filepath.Join(configHome, dbName)
   139  	}
   140  	return "nbs:" + dbName
   141  }
   142  
   143  func qualifyPaths(configPath string, c *Config) (*Config, error) {
   144  	file, err := filepath.Abs(configPath)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	dir := filepath.Dir(file)
   149  	qc := *c
   150  	qc.File = file
   151  	for k, r := range c.Db {
   152  		qc.Db[k] = DbConfig{absDbSpec(dir, r.Url), r.Options}
   153  	}
   154  	return &qc, nil
   155  }
   156  
   157  func (c *Config) String() string {
   158  	var buffer bytes.Buffer
   159  	if c.File != "" {
   160  		buffer.WriteString(fmt.Sprintf("file = %s\n", c.File))
   161  	}
   162  	buffer.WriteString(c.writeableString())
   163  	return buffer.String()
   164  }
   165  
   166  func (c *Config) writeableString() string {
   167  	var buffer bytes.Buffer
   168  	for k, r := range c.Db {
   169  		buffer.WriteString(fmt.Sprintf("[db.%s]\n", k))
   170  		buffer.WriteString(fmt.Sprintf("\t"+`url = "%s"`+"\n", r.Url))
   171  
   172  		for optKey, optVal := range r.Options {
   173  			buffer.WriteString(fmt.Sprintf("\t[db.%s.options]\n", k))
   174  			buffer.WriteString(fmt.Sprintf("\t\t%s = \"%s\"\n", optKey, optVal))
   175  		}
   176  	}
   177  
   178  	buffer.WriteString("[aws]\n")
   179  
   180  	if c.AWS.Region != "" {
   181  		buffer.WriteString(fmt.Sprintf("\tregion = \"%s\"\n", c.AWS.Region))
   182  	}
   183  
   184  	if c.AWS.CredSource != "" {
   185  		buffer.WriteString(fmt.Sprintf("\tcred_source = \"%s\"\n", c.AWS.CredSource))
   186  	}
   187  
   188  	if c.AWS.CredFile != "" {
   189  		buffer.WriteString(fmt.Sprintf("\tcred_file = \"%s\"\n", c.AWS.CredFile))
   190  	}
   191  
   192  	return buffer.String()
   193  }
   194  
   195  func (c *Config) getAWSRegion(dbParams map[string]string) string {
   196  	if dbParams != nil {
   197  		if val, ok := dbParams[awsRegionParam]; ok {
   198  			return val
   199  		}
   200  	}
   201  
   202  	if c.AWS.Region != "" {
   203  		return c.AWS.Region
   204  	}
   205  
   206  	return ""
   207  }
   208  
   209  func (c *Config) getAuthorization(dbParams map[string]string) string {
   210  	if dbParams != nil {
   211  		if val, ok := dbParams[authParam]; ok {
   212  			return val
   213  		}
   214  	}
   215  
   216  	return ""
   217  }
   218  
   219  func (c *Config) getAWSCredentialSource(dbParams map[string]string) spec.AWSCredentialSource {
   220  	set := false
   221  	credSourceStr := ""
   222  	if dbParams != nil {
   223  		if val, ok := dbParams[awsCredSourceParam]; ok {
   224  			set = true
   225  			credSourceStr = val
   226  		}
   227  	}
   228  
   229  	if !set {
   230  		credSourceStr = c.AWS.CredSource
   231  	}
   232  
   233  	ct := spec.AWSCredentialSourceFromStr(credSourceStr)
   234  
   235  	if ct == spec.InvalidCS {
   236  		panic(credSourceStr + " is not a valid aws credential source")
   237  	}
   238  
   239  	return ct
   240  }
   241  
   242  func (c *Config) getAWSCredFile(dbParams map[string]string) string {
   243  	if dbParams != nil {
   244  		if val, ok := dbParams[awsCredFileParam]; ok {
   245  			return val
   246  		}
   247  	}
   248  
   249  	return ""
   250  }
   251  
   252  // specOptsForConfig Uses config data from the global config and db configuration to
   253  // generate the spec.SpecOptions which should be used in calls to spec.For*opts()
   254  func specOptsForConfig(c *Config, dbc *DbConfig) spec.SpecOptions {
   255  	dbParams := dbc.Options
   256  
   257  	if c == nil {
   258  		return spec.SpecOptions{}
   259  	} else {
   260  		return spec.SpecOptions{
   261  			Authorization: c.getAuthorization(dbParams),
   262  			AWSRegion:     c.getAWSRegion(dbParams),
   263  			AWSCredSource: c.getAWSCredentialSource(dbParams),
   264  			AWSCredFile:   c.getAWSCredFile(dbParams),
   265  		}
   266  	}
   267  }