github.com/lyft/flytestdlib@v0.3.12-0.20210213045714-8cdd111ecda1/cli/pflags/api/templates.go (about) 1 package api 2 3 import ( 4 "bytes" 5 "text/template" 6 ) 7 8 func GenerateCodeFile(buffer *bytes.Buffer, info TypeInfo) error { 9 return mainTmpl.Execute(buffer, info) 10 } 11 12 func GenerateTestFile(buffer *bytes.Buffer, info TypeInfo) error { 13 return testTmpl.Execute(buffer, info) 14 } 15 16 var mainTmpl = template.Must(template.New("MainFile").Parse( 17 `// Code generated by go generate; DO NOT EDIT. 18 // This file was generated by robots. 19 20 package {{ .Package }} 21 22 import ( 23 "encoding/json" 24 25 "github.com/spf13/pflag" 26 "fmt" 27 {{range $path, $name := .Imports}} 28 {{$name}} "{{$path}}"{{end}} 29 ) 30 31 // If v is a pointer, it will get its element value or the zero value of the element type. 32 // If v is not a pointer, it will return it as is. 33 func ({{ .Name }}) elemValueOrNil(v interface{}) interface{} { 34 if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { 35 if reflect.ValueOf(v).IsNil() { 36 return reflect.Zero(t.Elem()).Interface() 37 } else { 38 return reflect.ValueOf(v).Interface() 39 } 40 } else if v == nil { 41 return reflect.Zero(t).Interface() 42 } 43 44 return v 45 } 46 47 func ({{ .Name }}) mustMarshalJSON(v json.Marshaler) string { 48 raw, err := v.MarshalJSON() 49 if err != nil { 50 panic(err) 51 } 52 53 return string(raw) 54 } 55 56 // GetPFlagSet will return strongly types pflags for all fields in {{ .Name }} and its nested types. The format of the 57 // flags is json-name.json-sub-name... etc. 58 func (cfg {{ .Name }}) GetPFlagSet(prefix string) *pflag.FlagSet { 59 cmdFlags := pflag.NewFlagSet("{{ .Name }}", pflag.ExitOnError) 60 {{- range .Fields }} 61 {{- if .ShouldBindDefault }} 62 cmdFlags.{{ .FlagMethodName }}Var(&{{ .DefaultValue }}, fmt.Sprintf("%v%v", prefix, "{{ .Name }}"), {{ .DefaultValue }}, {{ .UsageString }}) 63 {{- else }} 64 cmdFlags.{{ .FlagMethodName }}(fmt.Sprintf("%v%v", prefix, "{{ .Name }}"), {{ .DefaultValue }}, {{ .UsageString }}) 65 {{- end }} 66 {{- end }} 67 return cmdFlags 68 } 69 `)) 70 71 var testTmpl = template.Must(template.New("TestFile").Parse( 72 `// Code generated by go generate; DO NOT EDIT. 73 // This file was generated by robots. 74 75 package {{ .Package }} 76 77 import ( 78 "encoding/json" 79 "fmt" 80 "reflect" 81 "strings" 82 "testing" 83 84 "github.com/mitchellh/mapstructure" 85 "github.com/stretchr/testify/assert" 86 {{- range $path, $name := .Imports}} 87 {{$name}} "{{$path}}" 88 {{- end}} 89 ) 90 91 var dereferencableKinds{{ .Name }} = map[reflect.Kind]struct{}{ 92 reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, 93 } 94 95 // Checks if t is a kind that can be dereferenced to get its underlying type. 96 func canGetElement{{ .Name }}(t reflect.Kind) bool { 97 _, exists := dereferencableKinds{{ .Name }}[t] 98 return exists 99 } 100 101 // This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the 102 // object. Otherwise, it'll just pass on the original data. 103 func jsonUnmarshalerHook{{ .Name }}(_, to reflect.Type, data interface{}) (interface{}, error) { 104 unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() 105 if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || 106 (canGetElement{{ .Name }}(to.Kind()) && to.Elem().Implements(unmarshalerType)) { 107 108 raw, err := json.Marshal(data) 109 if err != nil { 110 fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) 111 return data, nil 112 } 113 114 res := reflect.New(to).Interface() 115 err = json.Unmarshal(raw, &res) 116 if err != nil { 117 fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) 118 return data, nil 119 } 120 121 return res, nil 122 } 123 124 return data, nil 125 } 126 127 func decode_{{ .Name }}(input, result interface{}) error { 128 config := &mapstructure.DecoderConfig{ 129 TagName: "json", 130 WeaklyTypedInput: true, 131 Result: result, 132 DecodeHook: mapstructure.ComposeDecodeHookFunc( 133 mapstructure.StringToTimeDurationHookFunc(), 134 mapstructure.StringToSliceHookFunc(","), 135 jsonUnmarshalerHook{{ .Name }}, 136 ), 137 } 138 139 decoder, err := mapstructure.NewDecoder(config) 140 if err != nil { 141 return err 142 } 143 144 return decoder.Decode(input) 145 } 146 147 func join_{{ .Name }}(arr interface{}, sep string) string { 148 listValue := reflect.ValueOf(arr) 149 strs := make([]string, 0, listValue.Len()) 150 for i := 0; i < listValue.Len(); i++ { 151 strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) 152 } 153 154 return strings.Join(strs, sep) 155 } 156 157 func testDecodeJson_{{ .Name }}(t *testing.T, val, result interface{}) { 158 assert.NoError(t, decode_{{ .Name }}(val, result)) 159 } 160 161 func testDecodeSlice_{{ .Name }}(t *testing.T, vStringSlice, result interface{}) { 162 assert.NoError(t, decode_{{ .Name }}(vStringSlice, result)) 163 } 164 165 func Test{{ .Name }}_GetPFlagSet(t *testing.T) { 166 val := {{ .Name }}{} 167 cmdFlags := val.GetPFlagSet("") 168 assert.True(t, cmdFlags.HasFlags()) 169 } 170 171 func Test{{ .Name }}_SetFlags(t *testing.T) { 172 actual := {{ .Name }}{} 173 cmdFlags := actual.GetPFlagSet("") 174 assert.True(t, cmdFlags.HasFlags()) 175 176 {{ $ParentName := .Name }} 177 {{- range .Fields }} 178 t.Run("Test_{{ .Name }}", func(t *testing.T) { {{ $varName := print "v" .FlagMethodName }} 179 t.Run("DefaultValue", func(t *testing.T) { 180 // Test that default value is set properly 181 if {{ $varName }}, err := cmdFlags.Get{{ .FlagMethodName }}("{{ .Name }}"); err == nil { 182 assert.Equal(t, {{ .Typ }}({{ .DefaultValue }}), {{ $varName }}) 183 } else { 184 assert.FailNow(t, err.Error()) 185 } 186 }) 187 188 t.Run("Override", func(t *testing.T) { 189 {{ if eq .TestStrategy "Json" }}testValue := {{ .TestValue }} 190 {{ else if eq .TestStrategy "SliceRaw" }}testValue := {{ .TestValue }} 191 {{ else }}testValue := join_{{ $ParentName }}({{ .TestValue }}, ",") 192 {{ end }} 193 cmdFlags.Set("{{ .Name }}", testValue) 194 if {{ $varName }}, err := cmdFlags.Get{{ .FlagMethodName }}("{{ .Name }}"); err == nil { 195 {{ if eq .TestStrategy "Json" }}testDecodeJson_{{ $ParentName }}(t, fmt.Sprintf("%v", {{ print "v" .FlagMethodName }}), &actual.{{ .GoName }}) 196 {{ else if eq .TestStrategy "SliceRaw" }}testDecodeSlice_{{ $ParentName }}(t, {{ print "v" .FlagMethodName }}, &actual.{{ .GoName }}) 197 {{ else }}testDecodeSlice_{{ $ParentName }}(t, join_{{ $ParentName }}({{ print "v" .FlagMethodName }}, ","), &actual.{{ .GoName }}) 198 {{ end }} 199 } else { 200 assert.FailNow(t, err.Error()) 201 } 202 }) 203 }) 204 {{- end }} 205 } 206 `))