github.com/hyperledger/burrow@v0.34.5-0.20220512172541-77f09336001d/config/source/source.go (about)

     1  package source
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"os"
    10  	"reflect"
    11  	"strings"
    12  
    13  	"regexp"
    14  
    15  	"github.com/BurntSushi/toml"
    16  	"github.com/cep21/xdgbasedir"
    17  	"github.com/imdario/mergo"
    18  )
    19  
    20  // If passed this identifier try to read config from STDIN
    21  const STDINFileIdentifier = "-"
    22  
    23  type ConfigProvider interface {
    24  	// Description of where this provider sources its config from
    25  	From() string
    26  	// Get the config values to the passed in baseConfig
    27  	Apply(baseConfig interface{}) error
    28  	// Return a copy of the provider that does nothing if skip is true
    29  	SetSkip(skip bool) ConfigProvider
    30  	// Whether to skip this provider
    31  	Skip() bool
    32  }
    33  
    34  var _ ConfigProvider = &configSource{}
    35  
    36  type Format string
    37  
    38  const (
    39  	JSON    Format = "JSON"
    40  	TOML    Format = "TOML"
    41  	Unknown Format = ""
    42  )
    43  
    44  var LogWriter io.Writer
    45  
    46  var jsonRegex = regexp.MustCompile(`^\s*{`)
    47  
    48  type configSource struct {
    49  	from  string
    50  	skip  bool
    51  	apply func(baseConfig interface{}) error
    52  }
    53  
    54  func init() {
    55  	LogWriter = os.Stderr
    56  }
    57  
    58  func NewConfigProvider(from string, skip bool, apply func(baseConfig interface{}) error) *configSource {
    59  	return &configSource{
    60  		from:  from,
    61  		skip:  skip,
    62  		apply: apply,
    63  	}
    64  }
    65  
    66  func (cs *configSource) From() string {
    67  	return cs.from
    68  }
    69  
    70  func (cs *configSource) Apply(baseConfig interface{}) error {
    71  	return cs.apply(baseConfig)
    72  }
    73  
    74  func (cs *configSource) Skip() bool {
    75  	return cs.skip
    76  }
    77  
    78  // Returns a copy of the configSource with skip set as passed in
    79  func (cs *configSource) SetSkip(skip bool) ConfigProvider {
    80  	return &configSource{
    81  		skip:  skip,
    82  		from:  cs.from,
    83  		apply: cs.apply,
    84  	}
    85  }
    86  
    87  // Builds a ConfigProvider by iterating over a cascade of ConfigProvider sources. Can be used
    88  // in two distinct modes: with shortCircuit true the first successful ConfigProvider source
    89  // is returned. With shortCircuit false sources appearing later are used to possibly override
    90  // those appearing earlier
    91  func Cascade(shortCircuit bool, providers ...ConfigProvider) *configSource {
    92  	var fromStrings []string
    93  	skip := true
    94  	for _, provider := range providers {
    95  		if !provider.Skip() {
    96  			skip = false
    97  			fromStrings = append(fromStrings, provider.From())
    98  		}
    99  	}
   100  	fromPrefix := "each of"
   101  	if shortCircuit {
   102  		fromPrefix = "first of"
   103  
   104  	}
   105  	return &configSource{
   106  		skip: skip,
   107  		from: fmt.Sprintf("%s: %s", fromPrefix, strings.Join(fromStrings, " then ")),
   108  		apply: func(baseConfig interface{}) error {
   109  			if baseConfig == nil {
   110  				return fmt.Errorf("baseConfig passed to Cascade(...).Get() must not be nil")
   111  			}
   112  			for _, provider := range providers {
   113  				if !provider.Skip() {
   114  					writeLog(LogWriter, fmt.Sprintf("Sourcing config from %s", provider.From()))
   115  					err := provider.Apply(baseConfig)
   116  					if err != nil {
   117  						return err
   118  					}
   119  					if shortCircuit {
   120  						return nil
   121  					}
   122  				}
   123  			}
   124  			return nil
   125  		},
   126  	}
   127  }
   128  
   129  func FirstOf(providers ...ConfigProvider) *configSource {
   130  	return Cascade(true, providers...)
   131  }
   132  
   133  func EachOf(providers ...ConfigProvider) *configSource {
   134  	return Cascade(false, providers...)
   135  }
   136  
   137  // Try to source config from provided file detecting the file format, is skipNonExistent is true then the provider will
   138  // fall-through (skip) when the file doesn't exist, rather than returning an error
   139  func File(configFile string, skipNonExistent bool) *configSource {
   140  	var from string
   141  	if configFile == STDINFileIdentifier {
   142  		from = "Config from STDIN"
   143  	} else {
   144  		from = fmt.Sprintf("Config file at '%s'", configFile)
   145  	}
   146  	return &configSource{
   147  		skip: ShouldSkipFile(configFile, skipNonExistent),
   148  		from: from,
   149  		apply: func(baseConfig interface{}) error {
   150  			return FromFile(configFile, baseConfig)
   151  		},
   152  	}
   153  }
   154  
   155  // Try to find config by using XDG base dir spec
   156  func XDGBaseDir(configFileName string) *configSource {
   157  	skip := false
   158  	// Look for config in standard XDG specified locations
   159  	configFile, err := xdgbasedir.GetConfigFileLocation(configFileName)
   160  	if err == nil {
   161  		_, err := os.Stat(configFile)
   162  		// Skip if config  file does not exist at default location
   163  		skip = os.IsNotExist(err)
   164  	}
   165  	return &configSource{
   166  		skip: skip,
   167  		from: fmt.Sprintf("XDG base dir"),
   168  		apply: func(baseConfig interface{}) error {
   169  			if err != nil {
   170  				return err
   171  			}
   172  			return FromFile(configFile, baseConfig)
   173  		},
   174  	}
   175  }
   176  
   177  // Source from a single environment variable with config embedded in JSON
   178  func Environment(key string) *configSource {
   179  	configString := os.Getenv(key)
   180  	return &configSource{
   181  		skip: configString == "",
   182  		from: fmt.Sprintf("'%s' environment variable", key),
   183  		apply: func(baseConfig interface{}) error {
   184  			return FromString(configString, baseConfig)
   185  		},
   186  	}
   187  }
   188  
   189  func Default(defaultConfig interface{}) *configSource {
   190  	return &configSource{
   191  		from: "defaults",
   192  		apply: func(baseConfig interface{}) error {
   193  			return mergo.MergeWithOverwrite(baseConfig, defaultConfig)
   194  		},
   195  	}
   196  }
   197  
   198  func FromFile(configFile string, conf interface{}) error {
   199  	bs, err := ReadFile(configFile)
   200  	if err != nil {
   201  		return err
   202  	}
   203  
   204  	return FromString(string(bs), conf)
   205  }
   206  
   207  func FromTOMLString(tomlString string, conf interface{}) error {
   208  	_, err := toml.Decode(tomlString, conf)
   209  	if err != nil {
   210  		return err
   211  	}
   212  	return nil
   213  }
   214  
   215  func FromString(configString string, conf interface{}) error {
   216  	switch DetectFormat(configString) {
   217  	case JSON:
   218  		return FromJSONString(configString, conf)
   219  	case TOML:
   220  		return FromTOMLString(configString, conf)
   221  	default:
   222  		return fmt.Errorf("unknown configuration format:\n%s", configString)
   223  	}
   224  }
   225  
   226  func DetectFormat(configString string) Format {
   227  	if jsonRegex.MatchString(configString) {
   228  		return JSON
   229  	}
   230  	return TOML
   231  }
   232  
   233  func FromJSONString(jsonString string, conf interface{}) error {
   234  	err := json.Unmarshal(([]byte)(jsonString), conf)
   235  	if err != nil {
   236  		return err
   237  	}
   238  	return nil
   239  }
   240  
   241  func TOMLString(conf interface{}) string {
   242  	buf := new(bytes.Buffer)
   243  	encoder := toml.NewEncoder(buf)
   244  	err := encoder.Encode(conf)
   245  	if err != nil {
   246  		return fmt.Sprintf("<Could not serialise config: %v>", err)
   247  	}
   248  	return buf.String()
   249  }
   250  
   251  func JSONString(conf interface{}) string {
   252  	bs, err := json.MarshalIndent(conf, "", "  ")
   253  	if err != nil {
   254  		return fmt.Sprintf("<Could not serialise config: %v>", err)
   255  	}
   256  	return string(bs)
   257  }
   258  
   259  func Merge(base, override interface{}) (interface{}, error) {
   260  	merged, err := DeepCopy(base)
   261  	if err != nil {
   262  		return nil, err
   263  	}
   264  	err = mergo.MergeWithOverwrite(merged, override)
   265  	if err != nil {
   266  		return nil, err
   267  	}
   268  	return merged, nil
   269  }
   270  
   271  // Passed a pointer to struct creates a deep copy of the struct
   272  func DeepCopy(conf interface{}) (interface{}, error) {
   273  	// Create a zero value
   274  	confCopy := reflect.New(reflect.TypeOf(conf).Elem()).Interface()
   275  	// Perform a merge into that value to effect the copy
   276  	err := mergo.Merge(confCopy, conf)
   277  	if err != nil {
   278  		return nil, err
   279  	}
   280  	return confCopy, nil
   281  }
   282  
   283  func writeLog(writer io.Writer, msg string) {
   284  	if writer != nil {
   285  		writer.Write(([]byte)(msg))
   286  		writer.Write(([]byte)("\n"))
   287  	}
   288  }
   289  
   290  func ReadFile(file string) ([]byte, error) {
   291  	if file == STDINFileIdentifier {
   292  		return ioutil.ReadAll(os.Stdin)
   293  	}
   294  	return ioutil.ReadFile(file)
   295  }
   296  
   297  func ShouldSkipFile(file string, skipNonExistent bool) bool {
   298  	skip := file == ""
   299  	if !skip && skipNonExistent {
   300  		_, err := os.Stat(file)
   301  		skip = os.IsNotExist(err)
   302  	}
   303  	return skip
   304  }