github.com/google/go-github/v33@v33.0.0/github/gen-stringify-test.go (about)

     1  // Copyright 2019 The go-github AUTHORS. All rights reserved.
     2  //
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  // +build ignore
     7  
     8  // gen-stringify-test generates test methods to test the String methods.
     9  //
    10  // These tests eliminate most of the code coverage problems so that real
    11  // code coverage issues can be more readily identified.
    12  //
    13  // It is meant to be used by go-github contributors in conjunction with the
    14  // go generate tool before sending a PR to GitHub.
    15  // Please see the CONTRIBUTING.md file for more information.
    16  package main
    17  
    18  import (
    19  	"bytes"
    20  	"flag"
    21  	"fmt"
    22  	"go/ast"
    23  	"go/format"
    24  	"go/parser"
    25  	"go/token"
    26  	"io/ioutil"
    27  	"log"
    28  	"os"
    29  	"strings"
    30  	"text/template"
    31  )
    32  
    33  const (
    34  	ignoreFilePrefix1 = "gen-"
    35  	ignoreFilePrefix2 = "github-"
    36  	outputFileSuffix  = "-stringify_test.go"
    37  )
    38  
    39  var (
    40  	verbose = flag.Bool("v", false, "Print verbose log messages")
    41  
    42  	// skipStructMethods lists "struct.method" combos to skip.
    43  	skipStructMethods = map[string]bool{}
    44  	// skipStructs lists structs to skip.
    45  	skipStructs = map[string]bool{
    46  		"RateLimits": true,
    47  	}
    48  
    49  	funcMap = template.FuncMap{
    50  		"isNotLast": func(index int, slice []*structField) string {
    51  			if index+1 < len(slice) {
    52  				return ", "
    53  			}
    54  			return ""
    55  		},
    56  		"processZeroValue": func(v string) string {
    57  			switch v {
    58  			case "Bool(false)":
    59  				return "false"
    60  			case "Float64(0.0)":
    61  				return "0"
    62  			case "0", "Int(0)", "Int64(0)":
    63  				return "0"
    64  			case `""`, `String("")`:
    65  				return `""`
    66  			case "Timestamp{}", "&Timestamp{}":
    67  				return "github.Timestamp{0001-01-01 00:00:00 +0000 UTC}"
    68  			case "nil":
    69  				return "map[]"
    70  			}
    71  			log.Fatalf("Unhandled zero value: %q", v)
    72  			return ""
    73  		},
    74  	}
    75  
    76  	sourceTmpl = template.Must(template.New("source").Funcs(funcMap).Parse(source))
    77  )
    78  
    79  func main() {
    80  	flag.Parse()
    81  	fset := token.NewFileSet()
    82  
    83  	pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
    84  	if err != nil {
    85  		log.Fatal(err)
    86  		return
    87  	}
    88  
    89  	for pkgName, pkg := range pkgs {
    90  		t := &templateData{
    91  			filename:     pkgName + outputFileSuffix,
    92  			Year:         2019, // No need to change this once set (even in following years).
    93  			Package:      pkgName,
    94  			Imports:      map[string]string{"testing": "testing"},
    95  			StringFuncs:  map[string]bool{},
    96  			StructFields: map[string][]*structField{},
    97  		}
    98  		for filename, f := range pkg.Files {
    99  			logf("Processing %v...", filename)
   100  			if err := t.processAST(f); err != nil {
   101  				log.Fatal(err)
   102  			}
   103  		}
   104  		if err := t.dump(); err != nil {
   105  			log.Fatal(err)
   106  		}
   107  	}
   108  	logf("Done.")
   109  }
   110  
   111  func sourceFilter(fi os.FileInfo) bool {
   112  	return !strings.HasSuffix(fi.Name(), "_test.go") &&
   113  		!strings.HasPrefix(fi.Name(), ignoreFilePrefix1) &&
   114  		!strings.HasPrefix(fi.Name(), ignoreFilePrefix2)
   115  }
   116  
   117  type templateData struct {
   118  	filename     string
   119  	Year         int
   120  	Package      string
   121  	Imports      map[string]string
   122  	StringFuncs  map[string]bool
   123  	StructFields map[string][]*structField
   124  }
   125  
   126  type structField struct {
   127  	sortVal      string // Lower-case version of "ReceiverType.FieldName".
   128  	ReceiverVar  string // The one-letter variable name to match the ReceiverType.
   129  	ReceiverType string
   130  	FieldName    string
   131  	FieldType    string
   132  	ZeroValue    string
   133  	NamedStruct  bool // Getter for named struct.
   134  }
   135  
   136  func (t *templateData) processAST(f *ast.File) error {
   137  	for _, decl := range f.Decls {
   138  		fn, ok := decl.(*ast.FuncDecl)
   139  		if ok {
   140  			if fn.Recv != nil && len(fn.Recv.List) > 0 {
   141  				id, ok := fn.Recv.List[0].Type.(*ast.Ident)
   142  				if ok && fn.Name.Name == "String" {
   143  					logf("Got FuncDecl: Name=%q, id.Name=%#v", fn.Name.Name, id.Name)
   144  					t.StringFuncs[id.Name] = true
   145  				} else {
   146  					logf("Ignoring FuncDecl: Name=%q, Type=%T", fn.Name.Name, fn.Recv.List[0].Type)
   147  				}
   148  			} else {
   149  				logf("Ignoring FuncDecl: Name=%q, fn=%#v", fn.Name.Name, fn)
   150  			}
   151  			continue
   152  		}
   153  
   154  		gd, ok := decl.(*ast.GenDecl)
   155  		if !ok {
   156  			logf("Ignoring AST decl type %T", decl)
   157  			continue
   158  		}
   159  		for _, spec := range gd.Specs {
   160  			ts, ok := spec.(*ast.TypeSpec)
   161  			if !ok {
   162  				continue
   163  			}
   164  			// Skip unexported identifiers.
   165  			if !ts.Name.IsExported() {
   166  				logf("Struct %v is unexported; skipping.", ts.Name)
   167  				continue
   168  			}
   169  			// Check if the struct should be skipped.
   170  			if skipStructs[ts.Name.Name] {
   171  				logf("Struct %v is in skip list; skipping.", ts.Name)
   172  				continue
   173  			}
   174  			st, ok := ts.Type.(*ast.StructType)
   175  			if !ok {
   176  				logf("Ignoring AST type %T, Name=%q", ts.Type, ts.Name.String())
   177  				continue
   178  			}
   179  			for _, field := range st.Fields.List {
   180  				if len(field.Names) == 0 {
   181  					continue
   182  				}
   183  
   184  				fieldName := field.Names[0]
   185  				if id, ok := field.Type.(*ast.Ident); ok {
   186  					t.addIdent(id, ts.Name.String(), fieldName.String())
   187  					continue
   188  				}
   189  
   190  				if _, ok := field.Type.(*ast.MapType); ok {
   191  					t.addMapType(ts.Name.String(), fieldName.String())
   192  					continue
   193  				}
   194  
   195  				se, ok := field.Type.(*ast.StarExpr)
   196  				if !ok {
   197  					logf("Ignoring type %T for Name=%q, FieldName=%q", field.Type, ts.Name.String(), fieldName.String())
   198  					continue
   199  				}
   200  
   201  				// Skip unexported identifiers.
   202  				if !fieldName.IsExported() {
   203  					logf("Field %v is unexported; skipping.", fieldName)
   204  					continue
   205  				}
   206  				// Check if "struct.method" should be skipped.
   207  				if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] {
   208  					logf("Method %v is in skip list; skipping.", key)
   209  					continue
   210  				}
   211  
   212  				switch x := se.X.(type) {
   213  				case *ast.ArrayType:
   214  				case *ast.Ident:
   215  					t.addIdentPtr(x, ts.Name.String(), fieldName.String())
   216  				case *ast.MapType:
   217  				case *ast.SelectorExpr:
   218  				default:
   219  					logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
   220  				}
   221  			}
   222  		}
   223  	}
   224  	return nil
   225  }
   226  
   227  func (t *templateData) addMapType(receiverType, fieldName string) {
   228  	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, "map[]", "nil", false))
   229  }
   230  
   231  func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
   232  	var zeroValue string
   233  	var namedStruct = false
   234  	switch x.String() {
   235  	case "int":
   236  		zeroValue = "0"
   237  	case "int64":
   238  		zeroValue = "0"
   239  	case "float64":
   240  		zeroValue = "0.0"
   241  	case "string":
   242  		zeroValue = `""`
   243  	case "bool":
   244  		zeroValue = "false"
   245  	case "Timestamp":
   246  		zeroValue = "Timestamp{}"
   247  	default:
   248  		zeroValue = "nil"
   249  		namedStruct = true
   250  	}
   251  
   252  	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   253  }
   254  
   255  func (t *templateData) addIdentPtr(x *ast.Ident, receiverType, fieldName string) {
   256  	var zeroValue string
   257  	var namedStruct = false
   258  	switch x.String() {
   259  	case "int":
   260  		zeroValue = "Int(0)"
   261  	case "int64":
   262  		zeroValue = "Int64(0)"
   263  	case "float64":
   264  		zeroValue = "Float64(0.0)"
   265  	case "string":
   266  		zeroValue = `String("")`
   267  	case "bool":
   268  		zeroValue = "Bool(false)"
   269  	case "Timestamp":
   270  		zeroValue = "&Timestamp{}"
   271  	default:
   272  		zeroValue = "nil"
   273  		namedStruct = true
   274  	}
   275  
   276  	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   277  }
   278  
   279  func (t *templateData) dump() error {
   280  	if len(t.StructFields) == 0 {
   281  		logf("No StructFields for %v; skipping.", t.filename)
   282  		return nil
   283  	}
   284  
   285  	// Remove unused structs.
   286  	var toDelete []string
   287  	for k := range t.StructFields {
   288  		if !t.StringFuncs[k] {
   289  			toDelete = append(toDelete, k)
   290  			continue
   291  		}
   292  	}
   293  	for _, k := range toDelete {
   294  		delete(t.StructFields, k)
   295  	}
   296  
   297  	var buf bytes.Buffer
   298  	if err := sourceTmpl.Execute(&buf, t); err != nil {
   299  		return err
   300  	}
   301  	clean, err := format.Source(buf.Bytes())
   302  	if err != nil {
   303  		log.Printf("failed-to-format source:\n%v", buf.String())
   304  		return err
   305  	}
   306  
   307  	logf("Writing %v...", t.filename)
   308  	return ioutil.WriteFile(t.filename, clean, 0644)
   309  }
   310  
   311  func newStructField(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *structField {
   312  	return &structField{
   313  		sortVal:      strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
   314  		ReceiverVar:  strings.ToLower(receiverType[:1]),
   315  		ReceiverType: receiverType,
   316  		FieldName:    fieldName,
   317  		FieldType:    fieldType,
   318  		ZeroValue:    zeroValue,
   319  		NamedStruct:  namedStruct,
   320  	}
   321  }
   322  
   323  func logf(fmt string, args ...interface{}) {
   324  	if *verbose {
   325  		log.Printf(fmt, args...)
   326  	}
   327  }
   328  
// source is the text/template (parsed into sourceTmpl together with funcMap)
// that dump executes with a *templateData to produce the generated test file.
//
// It emits a Float64 helper followed by one Test<Struct>_String function per
// StructFields entry. Each test builds a struct literal using &FieldType{} for
// NamedStruct fields and ZeroValue for all other fields. It then compares
// v.String() against an expected string assembled from the package name, the
// field names, and processZeroValue, with isNotLast inserting ", " separators.
//
// NOTE(review): the generated header names the tool "gen-stringify-tests"
// while this file is gen-stringify-test.go; changing it would alter the header
// of every generated file.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-stringify-tests; DO NOT EDIT.

package {{ $package := .Package}}{{$package}}
{{with .Imports}}
import (
  {{- range . -}}
  "{{.}}"
  {{end -}}
)
{{end}}
func Float64(v float64) *float64 { return &v }
{{range $key, $value := .StructFields}}
func Test{{ $key }}_String(t *testing.T) {
  v := {{ $key }}{ {{range .}}{{if .NamedStruct}}
    {{ .FieldName }}: &{{ .FieldType }}{},{{else}}
    {{ .FieldName }}: {{.ZeroValue}},{{end}}{{end}}
  }
 	want := ` + "`" + `{{ $package }}.{{ $key }}{{ $slice := . }}{
{{- range $ind, $val := .}}{{if .NamedStruct}}{{ .FieldName }}:{{ $package }}.{{ .FieldType }}{}{{else}}{{ .FieldName }}:{{ processZeroValue .ZeroValue }}{{end}}{{ isNotLast $ind $slice }}{{end}}}` + "`" + `
	if got := v.String(); got != want {
		t.Errorf("{{ $key }}.String = %v, want %v", got, want)
	}
}
{{end}}
`