github.com/clysto/awgo@v0.15.0/config_bind.go (about)

     1  //
     2  // Copyright (c) 2018 Dean Jackson <deanishe@deanishe.net>
     3  //
     4  // MIT Licence. See http://opensource.org/licenses/MIT
     5  //
     6  // Created on 2018-06-30
     7  //
     8  
     9  package aw
    10  
    11  import (
    12  	"fmt"
    13  	"log"
    14  	"reflect"
    15  	"regexp"
    16  	"strconv"
    17  	"strings"
    18  	"time"
    19  )
    20  
    21  // To populates (tagged) struct v with values from the environment.
    22  func (cfg *Config) To(v interface{}) error {
    23  
    24  	binds, err := extract(v)
    25  	if err != nil {
    26  		return err
    27  	}
    28  
    29  	for _, bind := range binds {
    30  		if err := bind.Import(cfg); err != nil {
    31  			return err
    32  		}
    33  	}
    34  
    35  	return nil
    36  }
    37  
    38  // From saves the fields of (tagged) struct v to the workflow's settings in Alfred.
    39  //
    40  // All supported and unignored fields are saved, although empty variables
    41  // (i.e. "") are not overwritten with Go zero values, e.g. "0" or "false".
    42  func (cfg *Config) From(v interface{}) error {
    43  
    44  	variables, err := cfg.bindVars(v)
    45  	if err != nil {
    46  		return err
    47  	}
    48  
    49  	return cfg.setMulti(variables, false)
    50  }
    51  
    52  // extract binding values as {ENVVAR: value} map.
    53  func (cfg *Config) bindVars(v interface{}) (map[string]string, error) {
    54  
    55  	variables := map[string]string{}
    56  
    57  	binds, err := extract(v)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	for _, bind := range binds {
    63  		if k, v, ok := bind.GetVar(cfg); ok {
    64  			variables[k] = v
    65  		}
    66  	}
    67  
    68  	return variables, nil
    69  }
    70  
    71  // setMulti batches the saving of multiple variables.
    72  func (cfg *Config) setMulti(variables map[string]string, export bool) error {
    73  
    74  	for k, v := range variables {
    75  		cfg.Set(k, v, export)
    76  	}
    77  
    78  	return cfg.Do()
    79  }
    80  
// binding links an environment variable to the field of a struct.
//
// Bindings are created by extract, one per bindable struct field.
type binding struct {
	// Name is the name of the struct field.
	Name string
	// EnvVar is the name of the environment variable bound to the field.
	EnvVar string
	// FieldNum is the index of the field within the target struct.
	FieldNum int
	// Target is the value the binding was extracted from (the struct, or a
	// pointer to it).
	Target interface{}
	// Kind is the reflect.Kind of the field's type.
	Kind reflect.Kind
}
    89  
// bindSource is the set of typed getters a binding needs in order to read
// the value of an environment variable when populating a struct field
// (see binding.Import and binding.setValue).
type bindSource interface {
	GetBool(key string, fallback ...bool) bool
	GetInt(key string, fallback ...int) int
	GetFloat(key string, fallback ...float64) float64
	GetString(key string, fallback ...string) string
}
    96  
// bindDest is the destination side of a binding. binding.GetVar uses
// GetString to read the variable's current value before deciding whether
// the field's value should be saved.
type bindDest interface {
	GetString(key string, fallback ...string) string
	// SetConfig(key, value string, export bool, bundleID ...string) *Config
	setMulti(variables map[string]string, export bool) error
}
   102  
   103  // Import populates the target struct from src.
   104  func (bind *binding) Import(src bindSource) error {
   105  
   106  	rv := reflect.Indirect(reflect.ValueOf(bind.Target))
   107  
   108  	if bind.FieldNum > rv.NumField() {
   109  		return fmt.Errorf("invalid FieldNum (%d) for %s (%v)", bind.FieldNum, bind.Name, rv)
   110  	}
   111  
   112  	value := rv.Field(bind.FieldNum)
   113  
   114  	// Ignore empty/unset variables
   115  	if src.GetString(bind.EnvVar) == "" {
   116  		return nil
   117  	}
   118  
   119  	return bind.setValue(&value, src)
   120  }
   121  
   122  // GetVar populates dst from target struct.
   123  func (bind *binding) GetVar(dst bindDest) (key, value string, ok bool) {
   124  
   125  	rv := reflect.Indirect(reflect.ValueOf(bind.Target))
   126  
   127  	if bind.FieldNum > rv.NumField() {
   128  		return
   129  	}
   130  
   131  	var (
   132  		val     = rv.Field(bind.FieldNum)
   133  		cur     = dst.GetString(bind.EnvVar)
   134  		curZero = isZeroString(cur, val.Kind())
   135  		newZero = isZeroValue(val)
   136  	)
   137  
   138  	// field key & value
   139  	key = bind.EnvVar
   140  	value = fmt.Sprintf("%v", val)
   141  
   142  	// Don't pull zero-value fields into empty variables.
   143  	if curZero && newZero {
   144  		// log.Printf("[bind] %s: both empty", field.Name)
   145  		return
   146  	}
   147  
   148  	ok = true
   149  
   150  	return
   151  }
   152  
   153  func (bind *binding) setValue(rv *reflect.Value, src bindSource) error {
   154  
   155  	switch bind.Kind {
   156  
   157  	case reflect.Bool:
   158  		b := src.GetBool(bind.EnvVar)
   159  		reflect.Indirect(*rv).SetBool(b)
   160  		// log.Printf("[%s] value=%v", bind.Name, b)
   161  
   162  	case reflect.String:
   163  
   164  		s := src.GetString(bind.EnvVar)
   165  		reflect.Indirect(*rv).SetString(s)
   166  		// log.Printf("[%s] value=%s", bind.Name, s)
   167  
   168  	// Special-case int64, as it may also be a duration.
   169  	case reflect.Int64:
   170  
   171  		// Try to parse value as an int, and if that fails, try
   172  		// to parse it as a duration.
   173  		s := src.GetString(bind.EnvVar)
   174  
   175  		if _, err := strconv.ParseInt(s, 10, 64); err == nil {
   176  
   177  			i := src.GetInt(bind.EnvVar)
   178  			reflect.Indirect(*rv).SetInt(int64(i))
   179  
   180  		} else {
   181  
   182  			if d, err := time.ParseDuration(s); err == nil {
   183  				reflect.Indirect(*rv).SetInt(int64(d))
   184  			}
   185  
   186  		}
   187  
   188  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
   189  
   190  		i := src.GetInt(bind.EnvVar)
   191  		reflect.Indirect(*rv).SetInt(int64(i))
   192  		// log.Printf("[%s] value=%d", bind.Name, i)
   193  
   194  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   195  
   196  		i := src.GetInt(bind.EnvVar)
   197  		reflect.Indirect(*rv).SetUint(uint64(i))
   198  		// log.Printf("[%s] value=%d", bind.Name, i)
   199  
   200  	case reflect.Float32, reflect.Float64:
   201  
   202  		n := src.GetFloat(bind.EnvVar)
   203  		reflect.Indirect(*rv).SetFloat(n)
   204  		// log.Printf("[%s] value=%f", bind.Name, n)
   205  
   206  	default:
   207  		return fmt.Errorf("unsupported type: %v", bind.Kind.String())
   208  	}
   209  
   210  	return nil
   211  
   212  }
   213  
   214  func extract(v interface{}) ([]*binding, error) {
   215  
   216  	var binds []*binding
   217  
   218  	rv := reflect.ValueOf(v)
   219  
   220  	if rv.Kind() == reflect.Ptr {
   221  		rv = rv.Elem()
   222  	}
   223  
   224  	if rv.Kind() != reflect.Struct {
   225  		return nil, fmt.Errorf("need struct, not %s: %v", rv.Kind(), v)
   226  	}
   227  
   228  	typ := rv.Type()
   229  
   230  	for i := 0; i < rv.NumField(); i++ {
   231  
   232  		var (
   233  			ok      bool
   234  			field   reflect.StructField
   235  			name    string
   236  			tag     string
   237  			varname string
   238  		)
   239  
   240  		field = typ.Field(i)
   241  		name = field.Name
   242  
   243  		if tag, ok = field.Tag.Lookup("env"); ok {
   244  			if tag == "-" { // Ignore this field
   245  				continue
   246  			}
   247  		}
   248  
   249  		if !isBindable(field.Type.Kind()) {
   250  			log.Printf("[bind] unbindable kind: %s", field.Type.Kind())
   251  			continue
   252  		}
   253  
   254  		if tag != "" {
   255  			varname = tag
   256  		} else {
   257  			varname = EnvVarForField(name)
   258  		}
   259  
   260  		bind := &binding{
   261  			Name:     name,
   262  			EnvVar:   varname,
   263  			FieldNum: i,
   264  			Target:   v,
   265  			Kind:     field.Type.Kind(),
   266  		}
   267  		binds = append(binds, bind)
   268  
   269  	}
   270  
   271  	return binds, nil
   272  }
   273  
// bindableKinds is the set of reflect.Kinds that can be bound to an
// environment variable: strings, bools, and the sized and unsized signed
// integer, unsigned integer and float kinds. isBindable tests membership by
// key presence.
var bindableKinds = map[reflect.Kind]bool{
	reflect.String:  true,
	reflect.Bool:    true,
	reflect.Int:     true,
	reflect.Int8:    true,
	reflect.Int16:   true,
	reflect.Int32:   true,
	reflect.Int64:   true,
	reflect.Uint:    true,
	reflect.Uint8:   true,
	reflect.Uint16:  true,
	reflect.Uint32:  true,
	reflect.Uint64:  true,
	reflect.Float32: true,
	reflect.Float64: true,
}
   290  
   291  func isBindable(kind reflect.Kind) bool {
   292  
   293  	_, ok := bindableKinds[kind]
   294  	return ok
   295  }
   296  
   297  func isZeroValue(rv reflect.Value) bool {
   298  
   299  	if !rv.IsValid() {
   300  		return true
   301  	}
   302  
   303  	typ := rv.Type()
   304  	kind := typ.Kind()
   305  
   306  	switch kind {
   307  
   308  	case reflect.Ptr:
   309  		return rv.IsNil()
   310  
   311  	case reflect.String:
   312  		return rv.String() == ""
   313  
   314  	case reflect.Bool:
   315  		return rv.Bool() == false
   316  
   317  	case reflect.Float32, reflect.Float64:
   318  		return rv.Float() == 0.0
   319  
   320  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   321  		return rv.Int() == 0
   322  
   323  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   324  		return rv.Uint() == 0
   325  
   326  	default:
   327  		log.Printf("[pull] unknown kind: %v", kind)
   328  
   329  	}
   330  
   331  	return false
   332  }
   333  
   334  func isZeroString(s string, kind reflect.Kind) bool {
   335  
   336  	if s == "" {
   337  		return true
   338  	}
   339  
   340  	switch kind {
   341  	case reflect.Bool:
   342  
   343  		v, err := strconv.ParseBool(s)
   344  		if err != nil {
   345  			log.Printf("couldn't convert %s to bool: %v", s, err)
   346  			return false
   347  		}
   348  		return v == false
   349  
   350  	case reflect.Float32, reflect.Float64:
   351  		v, err := strconv.ParseFloat(s, 64)
   352  		if err != nil {
   353  			log.Printf("couldn't convert %s to float: %v", s, err)
   354  			return false
   355  		}
   356  		return v == 0.0
   357  
   358  	case reflect.Int64: // special-case int64: may also be duration
   359  
   360  		// Try to parse as int64 first
   361  		if v, err := strconv.ParseInt(s, 10, 64); err == nil {
   362  			return v == 0
   363  		}
   364  
   365  		// try to parse as a duration
   366  		d, err := time.ParseDuration(s)
   367  		if err != nil {
   368  			log.Printf("couldn't convert %s to int64/duration: %s", s, err)
   369  			return false
   370  		}
   371  		return d.Nanoseconds() == 0
   372  
   373  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
   374  		v, err := strconv.ParseInt(s, 10, 32)
   375  		if err != nil {
   376  			log.Println(err)
   377  			return false
   378  		}
   379  		return v == 0
   380  
   381  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   382  		v, err := strconv.ParseUint(s, 10, 64)
   383  		if err != nil {
   384  			log.Println(err)
   385  			return false
   386  		}
   387  		return v == 0
   388  
   389  	}
   390  
   391  	return false
   392  }
   393  
   394  // EnvVarForField generates an environment variable name from a struct field name.
   395  // This is documented to show how the automatic names are generated.
   396  func EnvVarForField(name string) string {
   397  	if !isCamelCase(name) {
   398  		return strings.ToUpper(name)
   399  	}
   400  	return splitCamelCase(name)
   401  }
   402  
   403  func isCamelCase(s string) bool {
   404  	if ok, _ := regexp.MatchString(".*[a-z]+[0-9_]*[A-Z]+.*", s); ok {
   405  		return true
   406  	}
   407  	if ok, _ := regexp.MatchString("[A-Z][A-Z][A-Z]+[0-9_]*[a-z]+.*", s); ok {
   408  		return true
   409  	}
   410  	return false
   411  }
   412  
   413  func splitCamelCase(name string) string {
   414  
   415  	var (
   416  		i     int
   417  		re    *regexp.Regexp
   418  		rest  string
   419  		words []string
   420  	)
   421  
   422  	rest = name
   423  
   424  	// Start with 3 or more capital letters.
   425  	re = regexp.MustCompile("[A-Z]+([A-Z])[0-9]*[a-z]")
   426  	for {
   427  		if idx := re.FindStringSubmatchIndex(rest); idx != nil {
   428  			i = idx[2]
   429  			s := rest[:i]
   430  			rest = rest[i:]
   431  			words = append(words, s)
   432  		} else {
   433  			break
   434  		}
   435  	}
   436  
   437  	re = regexp.MustCompile("[a-z][0-9_]*([A-Z])")
   438  	for {
   439  
   440  		if idx := re.FindStringSubmatchIndex(rest); idx != nil {
   441  			i = idx[2]
   442  			s := rest[:i]
   443  			rest = rest[i:]
   444  			words = append(words, s)
   445  		} else {
   446  			break
   447  		}
   448  	}
   449  
   450  	if rest != "" {
   451  		words = append(words, rest)
   452  	}
   453  
   454  	if len(words) > 0 {
   455  		s := strings.ToUpper(strings.Join(words, "_"))
   456  		return s
   457  	}
   458  
   459  	return ""
   460  }