github.com/datreeio/datree@v1.9.22-rc/pkg/localConfig/localConfig.go (about)

     1  package localConfig
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"path/filepath"
     7  	"strings"
     8  
     9  	"github.com/datreeio/datree/pkg/networkValidator"
    10  
    11  	"github.com/datreeio/datree/pkg/cliClient"
    12  	"github.com/lithammer/shortuuid"
    13  	"github.com/spf13/viper"
    14  )
    15  
    16  type LocalConfig struct {
    17  	Token           string
    18  	ClientId        string
    19  	SchemaVersion   string
    20  	Offline         string
    21  	PolicyConfig    string
    22  	SchemaLocations []string
    23  }
    24  
    25  type TokenClient interface {
    26  	CreateToken() (*cliClient.CreateTokenResponse, error)
    27  }
    28  
    29  type LocalConfigClient struct {
    30  	tokenClient      TokenClient
    31  	networkValidator *networkValidator.NetworkValidator
    32  }
    33  
    34  func NewLocalConfigClient(t TokenClient, nv *networkValidator.NetworkValidator) *LocalConfigClient {
    35  	return &LocalConfigClient{
    36  		tokenClient:      t,
    37  		networkValidator: nv,
    38  	}
    39  }
    40  
    41  const (
    42  	clientIdKey        = "client_id"
    43  	tokenKey           = "token"
    44  	schemaVersionKey   = "schema_version"
    45  	offlineKey         = "offline"
    46  	policyConfigKey    = "policy_config"
    47  	schemaLocationsKey = "schema_locations"
    48  )
    49  
    50  func (lc *LocalConfigClient) GetLocalConfiguration() (*LocalConfig, error) {
    51  	viper.SetEnvPrefix("datree")
    52  	viper.AutomaticEnv()
    53  
    54  	initConfigFileErr := InitLocalConfigFile()
    55  	if initConfigFileErr != nil {
    56  		return nil, initConfigFileErr
    57  	}
    58  
    59  	token := viper.GetString(tokenKey)
    60  	clientId := viper.GetString(clientIdKey)
    61  	schemaVersion := viper.GetString(schemaVersionKey)
    62  	offline := viper.GetString(offlineKey)
    63  	policyConfig := viper.GetString(policyConfigKey)
    64  	schemaLocations := viper.GetStringSlice(schemaLocationsKey)
    65  
    66  	if offline == "" {
    67  		offline = "fail"
    68  		err := setViperVariable(offlineKey, offline)
    69  		if err != nil {
    70  			return nil, err
    71  		}
    72  	}
    73  	lc.networkValidator.SetOfflineMode(offline)
    74  
    75  	if token == "" {
    76  		createTokenResponse, err := lc.tokenClient.CreateToken()
    77  		if err != nil {
    78  			return nil, err
    79  		}
    80  		token = createTokenResponse.Token
    81  		if token != "" {
    82  			err = setViperVariable(tokenKey, token)
    83  			if err != nil {
    84  				return nil, err
    85  			}
    86  		}
    87  	}
    88  
    89  	if clientId == "" {
    90  		clientId = shortuuid.New()
    91  		err := setViperVariable(clientIdKey, clientId)
    92  		if err != nil {
    93  			return nil, err
    94  		}
    95  	}
    96  
    97  	return &LocalConfig{Token: token, ClientId: clientId, SchemaVersion: schemaVersion, Offline: offline, PolicyConfig: policyConfig, SchemaLocations: schemaLocations}, nil
    98  }
    99  
   100  func (lc *LocalConfigClient) Set(key string, value string) error {
   101  	initConfigFileErr := InitLocalConfigFile()
   102  	if initConfigFileErr != nil {
   103  		return initConfigFileErr
   104  	}
   105  
   106  	err := validateKeyValueConfig(key, value)
   107  	if err != nil {
   108  		return err
   109  	}
   110  
   111  	if key == policyConfigKey {
   112  		absPath, _ := filepath.Abs(value)
   113  		viper.Set(policyConfigKey, absPath)
   114  	} else if key == schemaLocationsKey {
   115  		viper.Set(schemaLocationsKey, strings.Split(value, ","))
   116  	} else {
   117  		viper.Set(key, value)
   118  	}
   119  
   120  	writeClientIdErr := viper.WriteConfig()
   121  	if writeClientIdErr != nil {
   122  		return writeClientIdErr
   123  	}
   124  	return nil
   125  }
   126  
   127  func InitLocalConfigFile() error {
   128  	configHome, configName, configType, err := setViperConfig()
   129  	if err != nil {
   130  		return err
   131  	}
   132  	// workaround for creating config file when not exist
   133  	// open issue in viper: https://github.com/spf13/viper/issues/430
   134  	// should be fixed in pr https://github.com/spf13/viper/pull/936
   135  	configPath := filepath.Join(configHome, configName+"."+configType)
   136  
   137  	isDirExists, err := exists(configHome)
   138  	if err != nil {
   139  		return err
   140  	}
   141  	if !isDirExists {
   142  		osMkdirErr := os.Mkdir(configHome, os.ModePerm)
   143  		if osMkdirErr != nil {
   144  			return osMkdirErr
   145  		}
   146  	}
   147  
   148  	isConfigExists, err := exists(configPath)
   149  	if err != nil {
   150  		return err
   151  	}
   152  	if !isConfigExists {
   153  		_, osCreateErr := os.Create(configPath)
   154  		if osCreateErr != nil {
   155  			return osCreateErr
   156  		}
   157  	}
   158  
   159  	err = viper.ReadInConfig()
   160  	if err != nil {
   161  		return err
   162  	}
   163  
   164  	return nil
   165  }
   166  
   167  func exists(path string) (bool, error) {
   168  	_, err := os.Stat(path)
   169  	if err == nil {
   170  		return true, nil
   171  	}
   172  	if os.IsNotExist(err) {
   173  		return false, nil
   174  	}
   175  	return false, err
   176  }
   177  
   178  func (lc *LocalConfigClient) Get(key string) string {
   179  	return viper.GetString(key)
   180  }
   181  
   182  func getConfigHome() (string, error) {
   183  	homedir, err := os.UserHomeDir()
   184  	if err != nil {
   185  		return "", err
   186  	}
   187  
   188  	configHome := filepath.Join(homedir, ".datree")
   189  
   190  	return configHome, nil
   191  }
   192  
   193  func getConfigName() string {
   194  	return "config"
   195  }
   196  
   197  func getConfigType() string {
   198  	return "yaml"
   199  }
   200  
   201  func setViperConfig() (string, string, string, error) {
   202  	configHome, err := getConfigHome()
   203  	if err != nil {
   204  		return "", "", "", err
   205  	}
   206  
   207  	configName := getConfigName()
   208  	configType := getConfigType()
   209  
   210  	viper.SetConfigName(configName)
   211  	viper.SetConfigType(configType)
   212  	viper.AddConfigPath(configHome)
   213  
   214  	return configHome, configName, configType, nil
   215  }
   216  
   217  func setViperVariable(key string, value string) error {
   218  	if value == "" {
   219  		return fmt.Errorf("value is empty")
   220  	}
   221  
   222  	viper.Set(key, value)
   223  
   224  	err := viper.WriteConfig()
   225  	if err != nil {
   226  		return err
   227  	}
   228  	err = viper.ReadInConfig()
   229  	if err != nil {
   230  		return err
   231  	}
   232  
   233  	return nil
   234  }
   235  
   236  func validateKeyValueConfig(key string, value string) error {
   237  	if key == "offline" && value != "fail" && value != "local" {
   238  		return fmt.Errorf("invalid offline configuration value- %q\n"+
   239  			"Valid offline values are - fail, local", value)
   240  	}
   241  	return nil
   242  }