github.com/google/go-github/v42@v42.0.0/github/gen-stringify-test.go (about)

     1  // Copyright 2019 The go-github AUTHORS. All rights reserved.
     2  //
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  //go:build ignore
     7  // +build ignore
     8  
     9  // gen-stringify-test generates test methods to test the String methods.
    10  //
    11  // These tests eliminate most of the code coverage problems so that real
    12  // code coverage issues can be more readily identified.
    13  //
    14  // It is meant to be used by go-github contributors in conjunction with the
    15  // go generate tool before sending a PR to GitHub.
    16  // Please see the CONTRIBUTING.md file for more information.
    17  package main
    18  
    19  import (
    20  	"bytes"
    21  	"flag"
    22  	"fmt"
    23  	"go/ast"
    24  	"go/format"
    25  	"go/parser"
    26  	"go/token"
    27  	"io/ioutil"
    28  	"log"
    29  	"os"
    30  	"strings"
    31  	"text/template"
    32  )
    33  
    34  const (
    35  	ignoreFilePrefix1 = "gen-"
    36  	ignoreFilePrefix2 = "github-"
    37  	outputFileSuffix  = "-stringify_test.go"
    38  )
    39  
    40  var (
    41  	verbose = flag.Bool("v", false, "Print verbose log messages")
    42  
    43  	// skipStructMethods lists "struct.method" combos to skip.
    44  	skipStructMethods = map[string]bool{}
    45  	// skipStructs lists structs to skip.
    46  	skipStructs = map[string]bool{
    47  		"RateLimits": true,
    48  	}
    49  
    50  	funcMap = template.FuncMap{
    51  		"isNotLast": func(index int, slice []*structField) string {
    52  			if index+1 < len(slice) {
    53  				return ", "
    54  			}
    55  			return ""
    56  		},
    57  		"processZeroValue": func(v string) string {
    58  			switch v {
    59  			case "Bool(false)":
    60  				return "false"
    61  			case "Float64(0.0)":
    62  				return "0"
    63  			case "0", "Int(0)", "Int64(0)":
    64  				return "0"
    65  			case `""`, `String("")`:
    66  				return `""`
    67  			case "Timestamp{}", "&Timestamp{}":
    68  				return "github.Timestamp{0001-01-01 00:00:00 +0000 UTC}"
    69  			case "nil":
    70  				return "map[]"
    71  			}
    72  			log.Fatalf("Unhandled zero value: %q", v)
    73  			return ""
    74  		},
    75  	}
    76  
    77  	sourceTmpl = template.Must(template.New("source").Funcs(funcMap).Parse(source))
    78  )
    79  
    80  func main() {
    81  	flag.Parse()
    82  	fset := token.NewFileSet()
    83  
    84  	pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
    85  	if err != nil {
    86  		log.Fatal(err)
    87  		return
    88  	}
    89  
    90  	for pkgName, pkg := range pkgs {
    91  		t := &templateData{
    92  			filename:     pkgName + outputFileSuffix,
    93  			Year:         2019, // No need to change this once set (even in following years).
    94  			Package:      pkgName,
    95  			Imports:      map[string]string{"testing": "testing"},
    96  			StringFuncs:  map[string]bool{},
    97  			StructFields: map[string][]*structField{},
    98  		}
    99  		for filename, f := range pkg.Files {
   100  			logf("Processing %v...", filename)
   101  			if err := t.processAST(f); err != nil {
   102  				log.Fatal(err)
   103  			}
   104  		}
   105  		if err := t.dump(); err != nil {
   106  			log.Fatal(err)
   107  		}
   108  	}
   109  	logf("Done.")
   110  }
   111  
   112  func sourceFilter(fi os.FileInfo) bool {
   113  	return !strings.HasSuffix(fi.Name(), "_test.go") &&
   114  		!strings.HasPrefix(fi.Name(), ignoreFilePrefix1) &&
   115  		!strings.HasPrefix(fi.Name(), ignoreFilePrefix2)
   116  }
   117  
// templateData is the data value passed to sourceTmpl for one parsed
// package. Only its exported fields are visible to the template.
type templateData struct {
	// filename is the file dump writes (in main: package name plus
	// outputFileSuffix).
	filename string
	// Year is rendered into the copyright line of the generated file.
	Year int
	// Package is used in the generated package clause and as the type-name
	// prefix of the expected String output.
	Package string
	// Imports holds the import paths emitted in the generated import block;
	// the template renders the map values.
	Imports map[string]string
	// StringFuncs records the names of types found (by processAST) to declare
	// a String method; dump drops StructFields entries not present here.
	StringFuncs map[string]bool
	// StructFields maps a struct type name to the fields its generated test
	// initializes, in the order they were recorded.
	StructFields map[string][]*structField
}
   126  
// structField describes one struct field that a generated String test
// initializes.
type structField struct {
	// sortVal is the lower-case version of "ReceiverType.FieldName".
	// NOTE(review): not read anywhere in this file.
	sortVal string
	// ReceiverVar is the one-letter variable name to match the ReceiverType.
	// NOTE(review): not referenced by the current source template.
	ReceiverVar string
	// ReceiverType is the name of the struct type that owns the field.
	ReceiverType string
	// FieldName is the field's name as declared.
	FieldName string
	// FieldType is the field's type identifier (without any pointer star).
	FieldType string
	// ZeroValue is the Go expression assigned to the field in the generated
	// test; for non-NamedStruct fields it is also passed to processZeroValue
	// to build the expected output.
	ZeroValue string
	// NamedStruct is true when FieldType is not one of the recognized
	// scalar/Timestamp types; the template then initializes the field as
	// &FieldType{} and expects "<package>.FieldType{}" in the output.
	NamedStruct bool
}
   136  
   137  func (t *templateData) processAST(f *ast.File) error {
   138  	for _, decl := range f.Decls {
   139  		fn, ok := decl.(*ast.FuncDecl)
   140  		if ok {
   141  			if fn.Recv != nil && len(fn.Recv.List) > 0 {
   142  				id, ok := fn.Recv.List[0].Type.(*ast.Ident)
   143  				if ok && fn.Name.Name == "String" {
   144  					logf("Got FuncDecl: Name=%q, id.Name=%#v", fn.Name.Name, id.Name)
   145  					t.StringFuncs[id.Name] = true
   146  				} else {
   147  					logf("Ignoring FuncDecl: Name=%q, Type=%T", fn.Name.Name, fn.Recv.List[0].Type)
   148  				}
   149  			} else {
   150  				logf("Ignoring FuncDecl: Name=%q, fn=%#v", fn.Name.Name, fn)
   151  			}
   152  			continue
   153  		}
   154  
   155  		gd, ok := decl.(*ast.GenDecl)
   156  		if !ok {
   157  			logf("Ignoring AST decl type %T", decl)
   158  			continue
   159  		}
   160  		for _, spec := range gd.Specs {
   161  			ts, ok := spec.(*ast.TypeSpec)
   162  			if !ok {
   163  				continue
   164  			}
   165  			// Skip unexported identifiers.
   166  			if !ts.Name.IsExported() {
   167  				logf("Struct %v is unexported; skipping.", ts.Name)
   168  				continue
   169  			}
   170  			// Check if the struct should be skipped.
   171  			if skipStructs[ts.Name.Name] {
   172  				logf("Struct %v is in skip list; skipping.", ts.Name)
   173  				continue
   174  			}
   175  			st, ok := ts.Type.(*ast.StructType)
   176  			if !ok {
   177  				logf("Ignoring AST type %T, Name=%q", ts.Type, ts.Name.String())
   178  				continue
   179  			}
   180  			for _, field := range st.Fields.List {
   181  				if len(field.Names) == 0 {
   182  					continue
   183  				}
   184  
   185  				fieldName := field.Names[0]
   186  				if id, ok := field.Type.(*ast.Ident); ok {
   187  					t.addIdent(id, ts.Name.String(), fieldName.String())
   188  					continue
   189  				}
   190  
   191  				se, ok := field.Type.(*ast.StarExpr)
   192  				if !ok {
   193  					logf("Ignoring type %T for Name=%q, FieldName=%q", field.Type, ts.Name.String(), fieldName.String())
   194  					continue
   195  				}
   196  
   197  				// Skip unexported identifiers.
   198  				if !fieldName.IsExported() {
   199  					logf("Field %v is unexported; skipping.", fieldName)
   200  					continue
   201  				}
   202  				// Check if "struct.method" should be skipped.
   203  				if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] {
   204  					logf("Method %v is in skip list; skipping.", key)
   205  					continue
   206  				}
   207  
   208  				switch x := se.X.(type) {
   209  				case *ast.ArrayType:
   210  				case *ast.Ident:
   211  					t.addIdentPtr(x, ts.Name.String(), fieldName.String())
   212  				case *ast.MapType:
   213  				case *ast.SelectorExpr:
   214  				default:
   215  					logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
   216  				}
   217  			}
   218  		}
   219  	}
   220  	return nil
   221  }
   222  
   223  func (t *templateData) addMapType(receiverType, fieldName string) {
   224  	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, "map[]", "nil", false))
   225  }
   226  
   227  func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
   228  	var zeroValue string
   229  	var namedStruct = false
   230  	switch x.String() {
   231  	case "int":
   232  		zeroValue = "0"
   233  	case "int64":
   234  		zeroValue = "0"
   235  	case "float64":
   236  		zeroValue = "0.0"
   237  	case "string":
   238  		zeroValue = `""`
   239  	case "bool":
   240  		zeroValue = "false"
   241  	case "Timestamp":
   242  		zeroValue = "Timestamp{}"
   243  	default:
   244  		zeroValue = "nil"
   245  		namedStruct = true
   246  	}
   247  
   248  	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   249  }
   250  
   251  func (t *templateData) addIdentPtr(x *ast.Ident, receiverType, fieldName string) {
   252  	var zeroValue string
   253  	var namedStruct = false
   254  	switch x.String() {
   255  	case "int":
   256  		zeroValue = "Int(0)"
   257  	case "int64":
   258  		zeroValue = "Int64(0)"
   259  	case "float64":
   260  		zeroValue = "Float64(0.0)"
   261  	case "string":
   262  		zeroValue = `String("")`
   263  	case "bool":
   264  		zeroValue = "Bool(false)"
   265  	case "Timestamp":
   266  		zeroValue = "&Timestamp{}"
   267  	default:
   268  		zeroValue = "nil"
   269  		namedStruct = true
   270  	}
   271  
   272  	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   273  }
   274  
   275  func (t *templateData) dump() error {
   276  	if len(t.StructFields) == 0 {
   277  		logf("No StructFields for %v; skipping.", t.filename)
   278  		return nil
   279  	}
   280  
   281  	// Remove unused structs.
   282  	var toDelete []string
   283  	for k := range t.StructFields {
   284  		if !t.StringFuncs[k] {
   285  			toDelete = append(toDelete, k)
   286  			continue
   287  		}
   288  	}
   289  	for _, k := range toDelete {
   290  		delete(t.StructFields, k)
   291  	}
   292  
   293  	var buf bytes.Buffer
   294  	if err := sourceTmpl.Execute(&buf, t); err != nil {
   295  		return err
   296  	}
   297  	clean, err := format.Source(buf.Bytes())
   298  	if err != nil {
   299  		log.Printf("failed-to-format source:\n%v", buf.String())
   300  		return err
   301  	}
   302  
   303  	logf("Writing %v...", t.filename)
   304  	return ioutil.WriteFile(t.filename, clean, 0644)
   305  }
   306  
   307  func newStructField(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *structField {
   308  	return &structField{
   309  		sortVal:      strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
   310  		ReceiverVar:  strings.ToLower(receiverType[:1]),
   311  		ReceiverType: receiverType,
   312  		FieldName:    fieldName,
   313  		FieldType:    fieldType,
   314  		ZeroValue:    zeroValue,
   315  		NamedStruct:  namedStruct,
   316  	}
   317  }
   318  
   319  func logf(fmt string, args ...interface{}) {
   320  	if *verbose {
   321  		log.Printf(fmt, args...)
   322  	}
   323  }
   324  
// source is the text/template for a generated "<package>-stringify_test.go"
// file. It is parsed into sourceTmpl and executed with a *templateData by
// dump. It emits, in order:
//   - a license header and the package clause;
//   - an import block built from .Imports;
//   - a Float64 pointer helper;
//   - one Test<Type>_String function per .StructFields entry (text/template
//     visits map keys in sorted order).
//
// Each test builds a value with the recorded fields set (NamedStruct fields
// to &FieldType{}, others to ZeroValue). It then compares v.String() with the
// expected text assembled via processZeroValue and isNotLast. A raw string
// literal cannot contain a backquote, so the backquotes around the expected
// value are spliced in with "`".
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-stringify-tests; DO NOT EDIT.

package {{ $package := .Package}}{{$package}}
{{with .Imports}}
import (
  {{- range . -}}
  "{{.}}"
  {{end -}}
)
{{end}}
func Float64(v float64) *float64 { return &v }
{{range $key, $value := .StructFields}}
func Test{{ $key }}_String(t *testing.T) {
  v := {{ $key }}{ {{range .}}{{if .NamedStruct}}
    {{ .FieldName }}: &{{ .FieldType }}{},{{else}}
    {{ .FieldName }}: {{.ZeroValue}},{{end}}{{end}}
  }
 	want := ` + "`" + `{{ $package }}.{{ $key }}{{ $slice := . }}{
{{- range $ind, $val := .}}{{if .NamedStruct}}{{ .FieldName }}:{{ $package }}.{{ .FieldType }}{}{{else}}{{ .FieldName }}:{{ processZeroValue .ZeroValue }}{{end}}{{ isNotLast $ind $slice }}{{end}}}` + "`" + `
	if got := v.String(); got != want {
		t.Errorf("{{ $key }}.String = %v, want %v", got, want)
	}
}
{{end}}
`