github.com/lyft/flytestdlib@v0.3.12-0.20210213045714-8cdd111ecda1/cli/pflags/api/templates.go (about)

     1  package api
     2  
     3  import (
     4  	"bytes"
     5  	"text/template"
     6  )
     7  
     8  func GenerateCodeFile(buffer *bytes.Buffer, info TypeInfo) error {
     9  	return mainTmpl.Execute(buffer, info)
    10  }
    11  
    12  func GenerateTestFile(buffer *bytes.Buffer, info TypeInfo) error {
    13  	return testTmpl.Execute(buffer, info)
    14  }
    15  
    16  var mainTmpl = template.Must(template.New("MainFile").Parse(
    17  	`// Code generated by go generate; DO NOT EDIT.
    18  // This file was generated by robots.
    19  
    20  package {{ .Package }}
    21  
    22  import (
    23  	"encoding/json"
    24  
    25  	"github.com/spf13/pflag"
    26  	"fmt"
    27  {{range $path, $name := .Imports}}
    28  	{{$name}} "{{$path}}"{{end}}
    29  )
    30  
    31  // If v is a pointer, it will get its element value or the zero value of the element type.
    32  // If v is not a pointer, it will return it as is.
    33  func ({{ .Name }}) elemValueOrNil(v interface{}) interface{} {
    34  	if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr {
    35  		if reflect.ValueOf(v).IsNil() {
    36  			return reflect.Zero(t.Elem()).Interface()
    37  		} else {
    38  			return reflect.ValueOf(v).Interface()
    39  		}
    40  	} else if v == nil {
    41  		return reflect.Zero(t).Interface()
    42  	}
    43  
    44  	return v
    45  }
    46  
    47  func ({{ .Name }}) mustMarshalJSON(v json.Marshaler) string {
    48      raw, err := v.MarshalJSON()
    49      if err != nil {
    50          panic(err)
    51      }
    52  
    53      return string(raw)
    54  }
    55  
    56  // GetPFlagSet will return strongly types pflags for all fields in {{ .Name }} and its nested types. The format of the
    57  // flags is json-name.json-sub-name... etc.
    58  func (cfg {{ .Name }}) GetPFlagSet(prefix string) *pflag.FlagSet {
    59  	cmdFlags := pflag.NewFlagSet("{{ .Name }}", pflag.ExitOnError)
    60  	{{- range .Fields }}
    61  	{{- if .ShouldBindDefault }}
    62  	cmdFlags.{{ .FlagMethodName }}Var(&{{ .DefaultValue }}, fmt.Sprintf("%v%v", prefix, "{{ .Name }}"), {{ .DefaultValue }}, {{ .UsageString }})
    63  	{{- else }}
    64  	cmdFlags.{{ .FlagMethodName }}(fmt.Sprintf("%v%v", prefix, "{{ .Name }}"), {{ .DefaultValue }}, {{ .UsageString }})
    65  	{{- end }}
    66  	{{- end }}
    67  	return cmdFlags
    68  }
    69  `))
    70  
    71  var testTmpl = template.Must(template.New("TestFile").Parse(
    72  	`// Code generated by go generate; DO NOT EDIT.
    73  // This file was generated by robots.
    74  
    75  package {{ .Package }}
    76  
    77  import (
    78  	"encoding/json"
    79  	"fmt"
    80  	"reflect"
    81  	"strings"
    82  	"testing"
    83  
    84  	"github.com/mitchellh/mapstructure"
    85  	"github.com/stretchr/testify/assert"
    86  {{- range $path, $name := .Imports}}
    87  	{{$name}} "{{$path}}"
    88  {{- end}}
    89  )
    90  
    91  var dereferencableKinds{{ .Name }} = map[reflect.Kind]struct{}{
    92  	reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {},
    93  }
    94  
    95  // Checks if t is a kind that can be dereferenced to get its underlying type.
    96  func canGetElement{{ .Name }}(t reflect.Kind) bool {
    97  	_, exists := dereferencableKinds{{ .Name }}[t]
    98  	return exists
    99  }
   100  
   101  // This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the
   102  // object. Otherwise, it'll just pass on the original data.
   103  func jsonUnmarshalerHook{{ .Name }}(_, to reflect.Type, data interface{}) (interface{}, error) {
   104  	unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
   105  	if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) ||
   106  		(canGetElement{{ .Name }}(to.Kind()) && to.Elem().Implements(unmarshalerType)) {
   107  
   108  		raw, err := json.Marshal(data)
   109  		if err != nil {
   110  			fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err)
   111  			return data, nil
   112  		}
   113  
   114  		res := reflect.New(to).Interface()
   115  		err = json.Unmarshal(raw, &res)
   116  		if err != nil {
   117  			fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err)
   118  			return data, nil
   119  		}
   120  
   121  		return res, nil
   122  	}
   123  
   124  	return data, nil
   125  }
   126  
   127  func decode_{{ .Name }}(input, result interface{}) error {
   128  	config := &mapstructure.DecoderConfig{
   129  		TagName:          "json",
   130  		WeaklyTypedInput: true,
   131  		Result:           result,
   132  		DecodeHook: mapstructure.ComposeDecodeHookFunc(
   133  			mapstructure.StringToTimeDurationHookFunc(),
   134  			mapstructure.StringToSliceHookFunc(","),
   135  			jsonUnmarshalerHook{{ .Name }},
   136  		),
   137  	}
   138  
   139  	decoder, err := mapstructure.NewDecoder(config)
   140  	if err != nil {
   141  		return err
   142  	}
   143  
   144  	return decoder.Decode(input)
   145  }
   146  
   147  func join_{{ .Name }}(arr interface{}, sep string) string {
   148  	listValue := reflect.ValueOf(arr)
   149  	strs := make([]string, 0, listValue.Len())
   150  	for i := 0; i < listValue.Len(); i++ {
   151  		strs = append(strs, fmt.Sprintf("%v", listValue.Index(i)))
   152  	}
   153  
   154  	return strings.Join(strs, sep)
   155  }
   156  
   157  func testDecodeJson_{{ .Name }}(t *testing.T, val, result interface{}) {
   158  	assert.NoError(t, decode_{{ .Name }}(val, result))
   159  }
   160  
   161  func testDecodeSlice_{{ .Name }}(t *testing.T, vStringSlice, result interface{}) {
   162  	assert.NoError(t, decode_{{ .Name }}(vStringSlice, result))
   163  }
   164  
   165  func Test{{ .Name }}_GetPFlagSet(t *testing.T) {
   166  	val := {{ .Name }}{}
   167  	cmdFlags := val.GetPFlagSet("")
   168  	assert.True(t, cmdFlags.HasFlags())
   169  }
   170  
   171  func Test{{ .Name }}_SetFlags(t *testing.T) {
   172  	actual := {{ .Name }}{}
   173  	cmdFlags := actual.GetPFlagSet("")
   174  	assert.True(t, cmdFlags.HasFlags())
   175  
   176  	{{ $ParentName := .Name }}
   177  	{{- range .Fields }}
   178  	t.Run("Test_{{ .Name }}", func(t *testing.T) { {{ $varName := print "v" .FlagMethodName }}
   179  		t.Run("DefaultValue", func(t *testing.T) {
   180  			// Test that default value is set properly
   181  			if {{ $varName }}, err := cmdFlags.Get{{ .FlagMethodName }}("{{ .Name }}"); err == nil {
   182  				assert.Equal(t, {{ .Typ }}({{ .DefaultValue }}), {{ $varName }})
   183  			} else {
   184  				assert.FailNow(t, err.Error())
   185  			}
   186  		})
   187  
   188  		t.Run("Override", func(t *testing.T) {
   189  			{{ if eq .TestStrategy "Json" }}testValue := {{ .TestValue }}
   190  			{{ else if eq .TestStrategy "SliceRaw" }}testValue := {{ .TestValue }}
   191  			{{ else }}testValue := join_{{ $ParentName }}({{ .TestValue }}, ",")
   192  			{{ end }}
   193  			cmdFlags.Set("{{ .Name }}", testValue)
   194  			if {{ $varName }}, err := cmdFlags.Get{{ .FlagMethodName }}("{{ .Name }}"); err == nil {
   195  				{{ if eq .TestStrategy "Json" }}testDecodeJson_{{ $ParentName }}(t, fmt.Sprintf("%v", {{ print "v" .FlagMethodName }}), &actual.{{ .GoName }})
   196  				{{ else if eq .TestStrategy "SliceRaw" }}testDecodeSlice_{{ $ParentName }}(t, {{ print "v" .FlagMethodName }}, &actual.{{ .GoName }})
   197  				{{ else }}testDecodeSlice_{{ $ParentName }}(t, join_{{ $ParentName }}({{ print "v" .FlagMethodName }}, ","), &actual.{{ .GoName }})
   198  				{{ end }}
   199  			} else {
   200  				assert.FailNow(t, err.Error())
   201  			}
   202  		})
   203  	})
   204  	{{- end }}
   205  }
   206  `))