// Copyright 2019 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file incorporates work covered by the following copyright and
// permission notice:
//
// Copyright 2016 Attic Labs, Inc. All rights reserved.
// Licensed under the Apache License, version 2.0:
// http://www.apache.org/licenses/LICENSE-2.0

package config

import (
	"bytes"
	"fmt"
	"os"
	"path/filepath"

	"github.com/BurntSushi/toml"

	"github.com/dolthub/dolt/go/store/spec"
)

// Config is the complete configuration decoded from a .nomsconfig file.
type Config struct {
	// File is the path of the config file this value was read from. It is
	// not part of the serialized form; ReadConfig fills it in after decoding.
	File string
	// Db holds the per-database settings, keyed by the name used in the
	// file's [db.<name>] tables.
	Db map[string]DbConfig
	// AWS holds the AWS settings from the file's [aws] table.
	AWS AWSConfig
}

// DbConfig is the configuration for a specific database.
type DbConfig struct {
	// Url is the database spec string. ReadConfig rewrites relative "nbs"
	// paths in it to absolute ones.
	Url string
	// Options are free-form string settings for this database. This package
	// looks up the aws_region, aws_cred_source, aws_cred_file and
	// authorization keys in it.
	Options map[string]string
}

// AWSConfig is the global AWS configuration. CredSource and CredFile are
// stored in the file under the keys cred_source and cred_file.
type AWSConfig struct {
	Region     string
	CredSource string `toml:"cred_source"`
	CredFile   string `toml:"cred_file"`
}

const (
	// NomsConfigFile is the file name searched for by FindNomsConfig and
	// written by Config.WriteTo.
	NomsConfigFile = ".nomsconfig"
	// DefaultDbAlias is the database alias "default". It is not referenced
	// by the code in this file.
	DefaultDbAlias = "default"

	// Keys looked up in DbConfig.Options when building spec options.
	awsRegionParam     = "aws_region"
	awsCredSourceParam = "aws_cred_source"
	awsCredFileParam   = "aws_cred_file"
	authParam          = "authorization"
)

// ErrNoConfig is returned by FindNomsConfig when no .nomsconfig file exists in
// the working directory or any of its ancestors.
var ErrNoConfig = fmt.Errorf("no %s found", NomsConfigFile)

// Find the closest directory containing .nomsconfig, starting
// in cwd and then searching up the ancestor tree.
69 // Look first looking in cwd and then up through its ancestors 70 func FindNomsConfig() (*Config, error) { 71 curDir, err := os.Getwd() 72 if err != nil { 73 return nil, err 74 } 75 for { 76 nomsConfig := filepath.Join(curDir, NomsConfigFile) 77 info, err := os.Stat(nomsConfig) 78 if err == nil && !info.IsDir() { 79 // found 80 return ReadConfig(nomsConfig) 81 } else if err != nil && !os.IsNotExist(err) { 82 // can't read 83 return nil, err 84 } 85 nextDir := filepath.Dir(curDir) 86 if nextDir == curDir { 87 // stop at root 88 return nil, ErrNoConfig 89 } 90 curDir = nextDir 91 } 92 } 93 94 func ReadConfig(name string) (*Config, error) { 95 data, err := os.ReadFile(name) 96 if err != nil { 97 return nil, err 98 } 99 c, err := NewConfig(string(data)) 100 if err != nil { 101 return nil, err 102 } 103 c.File = name 104 return qualifyPaths(name, c) 105 } 106 107 func NewConfig(data string) (*Config, error) { 108 c := new(Config) 109 if _, err := toml.Decode(data, c); err != nil { 110 return nil, err 111 } 112 return c, nil 113 } 114 115 func (c *Config) WriteTo(configHome string) (string, error) { 116 file := filepath.Join(configHome, NomsConfigFile) 117 if err := os.MkdirAll(filepath.Dir(file), os.ModePerm); err != nil { 118 return "", err 119 } 120 if err := os.WriteFile(file, []byte(c.writeableString()), os.ModePerm); err != nil { 121 return "", err 122 } 123 return file, nil 124 } 125 126 // Replace relative directory in path part of spec with an absolute 127 // directory. 
Assumes the path is relative to the location of the config file 128 func absDbSpec(configHome string, url string) string { 129 dbSpec, err := spec.ForDatabase(url) 130 if err != nil { 131 return url 132 } 133 if dbSpec.Protocol != "nbs" { 134 return url 135 } 136 dbName := dbSpec.DatabaseName 137 if !filepath.IsAbs(dbName) { 138 dbName = filepath.Join(configHome, dbName) 139 } 140 return "nbs:" + dbName 141 } 142 143 func qualifyPaths(configPath string, c *Config) (*Config, error) { 144 file, err := filepath.Abs(configPath) 145 if err != nil { 146 return nil, err 147 } 148 dir := filepath.Dir(file) 149 qc := *c 150 qc.File = file 151 for k, r := range c.Db { 152 qc.Db[k] = DbConfig{absDbSpec(dir, r.Url), r.Options} 153 } 154 return &qc, nil 155 } 156 157 func (c *Config) String() string { 158 var buffer bytes.Buffer 159 if c.File != "" { 160 buffer.WriteString(fmt.Sprintf("file = %s\n", c.File)) 161 } 162 buffer.WriteString(c.writeableString()) 163 return buffer.String() 164 } 165 166 func (c *Config) writeableString() string { 167 var buffer bytes.Buffer 168 for k, r := range c.Db { 169 buffer.WriteString(fmt.Sprintf("[db.%s]\n", k)) 170 buffer.WriteString(fmt.Sprintf("\t"+`url = "%s"`+"\n", r.Url)) 171 172 for optKey, optVal := range r.Options { 173 buffer.WriteString(fmt.Sprintf("\t[db.%s.options]\n", k)) 174 buffer.WriteString(fmt.Sprintf("\t\t%s = \"%s\"\n", optKey, optVal)) 175 } 176 } 177 178 buffer.WriteString("[aws]\n") 179 180 if c.AWS.Region != "" { 181 buffer.WriteString(fmt.Sprintf("\tregion = \"%s\"\n", c.AWS.Region)) 182 } 183 184 if c.AWS.CredSource != "" { 185 buffer.WriteString(fmt.Sprintf("\tcred_source = \"%s\"\n", c.AWS.CredSource)) 186 } 187 188 if c.AWS.CredFile != "" { 189 buffer.WriteString(fmt.Sprintf("\tcred_file = \"%s\"\n", c.AWS.CredFile)) 190 } 191 192 return buffer.String() 193 } 194 195 func (c *Config) getAWSRegion(dbParams map[string]string) string { 196 if dbParams != nil { 197 if val, ok := dbParams[awsRegionParam]; ok { 198 
return val 199 } 200 } 201 202 if c.AWS.Region != "" { 203 return c.AWS.Region 204 } 205 206 return "" 207 } 208 209 func (c *Config) getAuthorization(dbParams map[string]string) string { 210 if dbParams != nil { 211 if val, ok := dbParams[authParam]; ok { 212 return val 213 } 214 } 215 216 return "" 217 } 218 219 func (c *Config) getAWSCredentialSource(dbParams map[string]string) spec.AWSCredentialSource { 220 set := false 221 credSourceStr := "" 222 if dbParams != nil { 223 if val, ok := dbParams[awsCredSourceParam]; ok { 224 set = true 225 credSourceStr = val 226 } 227 } 228 229 if !set { 230 credSourceStr = c.AWS.CredSource 231 } 232 233 ct := spec.AWSCredentialSourceFromStr(credSourceStr) 234 235 if ct == spec.InvalidCS { 236 panic(credSourceStr + " is not a valid aws credential source") 237 } 238 239 return ct 240 } 241 242 func (c *Config) getAWSCredFile(dbParams map[string]string) string { 243 if dbParams != nil { 244 if val, ok := dbParams[awsCredFileParam]; ok { 245 return val 246 } 247 } 248 249 return "" 250 } 251 252 // specOptsForConfig Uses config data from the global config and db configuration to 253 // generate the spec.SpecOptions which should be used in calls to spec.For*opts() 254 func specOptsForConfig(c *Config, dbc *DbConfig) spec.SpecOptions { 255 dbParams := dbc.Options 256 257 if c == nil { 258 return spec.SpecOptions{} 259 } else { 260 return spec.SpecOptions{ 261 Authorization: c.getAuthorization(dbParams), 262 AWSRegion: c.getAWSRegion(dbParams), 263 AWSCredSource: c.getAWSCredentialSource(dbParams), 264 AWSCredFile: c.getAWSCredFile(dbParams), 265 } 266 } 267 }