github.com/google/go-github/v68@v68.0.0/github/gen-accessors.go (about)

     1  // Copyright 2017 The go-github AUTHORS. All rights reserved.
     2  //
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  //go:build ignore
     7  
     8  // gen-accessors generates accessor methods for structs with pointer fields.
     9  //
    10  // It is meant to be used by go-github contributors in conjunction with the
    11  // go generate tool before sending a PR to GitHub.
    12  // Please see the CONTRIBUTING.md file for more information.
    13  package main
    14  
    15  import (
    16  	"bytes"
    17  	"flag"
    18  	"fmt"
    19  	"go/ast"
    20  	"go/format"
    21  	"go/parser"
    22  	"go/token"
    23  	"log"
    24  	"os"
    25  	"sort"
    26  	"strings"
    27  	"text/template"
    28  )
    29  
    30  const (
    31  	fileSuffix = "-accessors.go"
    32  )
    33  
    34  var (
    35  	verbose = flag.Bool("v", false, "Print verbose log messages")
    36  
    37  	sourceTmpl = template.Must(template.New("source").Parse(source))
    38  	testTmpl   = template.Must(template.New("test").Parse(test))
    39  
    40  	// skipStructMethods lists "struct.method" combos to skip.
    41  	skipStructMethods = map[string]bool{
    42  		"RepositoryContent.GetContent":    true,
    43  		"Client.GetBaseURL":               true,
    44  		"Client.GetUploadURL":             true,
    45  		"ErrorResponse.GetResponse":       true,
    46  		"RateLimitError.GetResponse":      true,
    47  		"AbuseRateLimitError.GetResponse": true,
    48  	}
    49  	// skipStructs lists structs to skip.
    50  	skipStructs = map[string]bool{
    51  		"Client": true,
    52  	}
    53  
    54  	// whitelistSliceGetters lists "struct.field" to add getter method
    55  	whitelistSliceGetters = map[string]bool{
    56  		"PushEvent.Commits": true,
    57  	}
    58  )
    59  
    60  func logf(fmt string, args ...interface{}) {
    61  	if *verbose {
    62  		log.Printf(fmt, args...)
    63  	}
    64  }
    65  
    66  func main() {
    67  	flag.Parse()
    68  	fset := token.NewFileSet()
    69  
    70  	pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
    71  	if err != nil {
    72  		log.Fatal(err)
    73  		return
    74  	}
    75  
    76  	for pkgName, pkg := range pkgs {
    77  		t := &templateData{
    78  			filename: pkgName + fileSuffix,
    79  			Year:     2017,
    80  			Package:  pkgName,
    81  			Imports:  map[string]string{},
    82  		}
    83  		for filename, f := range pkg.Files {
    84  			logf("Processing %v...", filename)
    85  			if err := t.processAST(f); err != nil {
    86  				log.Fatal(err)
    87  			}
    88  		}
    89  		if err := t.dump(); err != nil {
    90  			log.Fatal(err)
    91  		}
    92  	}
    93  	logf("Done.")
    94  }
    95  
    96  func (t *templateData) processAST(f *ast.File) error {
    97  	for _, decl := range f.Decls {
    98  		gd, ok := decl.(*ast.GenDecl)
    99  		if !ok {
   100  			continue
   101  		}
   102  		for _, spec := range gd.Specs {
   103  			ts, ok := spec.(*ast.TypeSpec)
   104  			if !ok {
   105  				continue
   106  			}
   107  			// Skip unexported identifiers.
   108  			if !ts.Name.IsExported() {
   109  				logf("Struct %v is unexported; skipping.", ts.Name)
   110  				continue
   111  			}
   112  			// Check if the struct should be skipped.
   113  			if skipStructs[ts.Name.Name] {
   114  				logf("Struct %v is in skip list; skipping.", ts.Name)
   115  				continue
   116  			}
   117  			st, ok := ts.Type.(*ast.StructType)
   118  			if !ok {
   119  				continue
   120  			}
   121  			for _, field := range st.Fields.List {
   122  				if len(field.Names) == 0 {
   123  					continue
   124  				}
   125  
   126  				fieldName := field.Names[0]
   127  				// Skip unexported identifiers.
   128  				if !fieldName.IsExported() {
   129  					logf("Field %v is unexported; skipping.", fieldName)
   130  					continue
   131  				}
   132  				// Check if "struct.method" should be skipped.
   133  				if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] {
   134  					logf("Method %v is skip list; skipping.", key)
   135  					continue
   136  				}
   137  
   138  				se, ok := field.Type.(*ast.StarExpr)
   139  				if !ok {
   140  					switch x := field.Type.(type) {
   141  					case *ast.MapType:
   142  						t.addMapType(x, ts.Name.String(), fieldName.String(), false)
   143  						continue
   144  					case *ast.ArrayType:
   145  						if key := fmt.Sprintf("%v.%v", ts.Name, fieldName); whitelistSliceGetters[key] {
   146  							logf("Method %v is whitelist; adding getter method.", key)
   147  							t.addArrayType(x, ts.Name.String(), fieldName.String(), false)
   148  							continue
   149  						}
   150  					}
   151  
   152  					logf("Skipping field type %T, fieldName=%v", field.Type, fieldName)
   153  					continue
   154  				}
   155  
   156  				switch x := se.X.(type) {
   157  				case *ast.ArrayType:
   158  					t.addArrayType(x, ts.Name.String(), fieldName.String(), true)
   159  				case *ast.Ident:
   160  					t.addIdent(x, ts.Name.String(), fieldName.String())
   161  				case *ast.MapType:
   162  					t.addMapType(x, ts.Name.String(), fieldName.String(), true)
   163  				case *ast.SelectorExpr:
   164  					t.addSelectorExpr(x, ts.Name.String(), fieldName.String())
   165  				default:
   166  					logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
   167  				}
   168  			}
   169  		}
   170  	}
   171  	return nil
   172  }
   173  
   174  func sourceFilter(fi os.FileInfo) bool {
   175  	return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix)
   176  }
   177  
   178  func (t *templateData) dump() error {
   179  	if len(t.Getters) == 0 {
   180  		logf("No getters for %v; skipping.", t.filename)
   181  		return nil
   182  	}
   183  
   184  	// Sort getters by ReceiverType.FieldName.
   185  	sort.Sort(byName(t.Getters))
   186  
   187  	processTemplate := func(tmpl *template.Template, filename string) error {
   188  		var buf bytes.Buffer
   189  		if err := tmpl.Execute(&buf, t); err != nil {
   190  			return err
   191  		}
   192  		clean, err := format.Source(buf.Bytes())
   193  		if err != nil {
   194  			return fmt.Errorf("format.Source:\n%v\n%v", buf.String(), err)
   195  		}
   196  
   197  		logf("Writing %v...", filename)
   198  		if err := os.Chmod(filename, 0644); err != nil {
   199  			return fmt.Errorf("os.Chmod(%q, 0644): %v", filename, err)
   200  		}
   201  
   202  		if err := os.WriteFile(filename, clean, 0444); err != nil {
   203  			return err
   204  		}
   205  
   206  		if err := os.Chmod(filename, 0444); err != nil {
   207  			return fmt.Errorf("os.Chmod(%q, 0444): %v", filename, err)
   208  		}
   209  
   210  		return nil
   211  	}
   212  
   213  	if err := processTemplate(sourceTmpl, t.filename); err != nil {
   214  		return err
   215  	}
   216  	return processTemplate(testTmpl, strings.ReplaceAll(t.filename, ".go", "_test.go"))
   217  }
   218  
   219  func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter {
   220  	return &getter{
   221  		sortVal:      strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
   222  		ReceiverVar:  strings.ToLower(receiverType[:1]),
   223  		ReceiverType: receiverType,
   224  		FieldName:    fieldName,
   225  		FieldType:    fieldType,
   226  		ZeroValue:    zeroValue,
   227  		NamedStruct:  namedStruct,
   228  	}
   229  }
   230  
   231  func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string, isAPointer bool) {
   232  	var eltType string
   233  	var ng *getter
   234  	switch elt := x.Elt.(type) {
   235  	case *ast.Ident:
   236  		eltType = elt.String()
   237  		ng = newGetter(receiverType, fieldName, "[]"+eltType, "nil", false)
   238  	case *ast.StarExpr:
   239  		ident, ok := elt.X.(*ast.Ident)
   240  		if !ok {
   241  			return
   242  		}
   243  		ng = newGetter(receiverType, fieldName, "[]*"+ident.String(), "nil", false)
   244  	default:
   245  		logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt)
   246  		return
   247  	}
   248  
   249  	ng.ArrayType = !isAPointer
   250  	t.Getters = append(t.Getters, ng)
   251  }
   252  
   253  func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
   254  	var zeroValue string
   255  	var namedStruct = false
   256  	switch x.String() {
   257  	case "int", "int64":
   258  		zeroValue = "0"
   259  	case "string":
   260  		zeroValue = `""`
   261  	case "bool":
   262  		zeroValue = "false"
   263  	case "Timestamp":
   264  		zeroValue = "Timestamp{}"
   265  	default:
   266  		zeroValue = "nil"
   267  		namedStruct = true
   268  	}
   269  
   270  	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   271  }
   272  
   273  func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) {
   274  	var keyType string
   275  	switch key := x.Key.(type) {
   276  	case *ast.Ident:
   277  		keyType = key.String()
   278  	default:
   279  		logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key)
   280  		return
   281  	}
   282  
   283  	var valueType string
   284  	switch value := x.Value.(type) {
   285  	case *ast.Ident:
   286  		valueType = value.String()
   287  	default:
   288  		logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value)
   289  		return
   290  	}
   291  
   292  	fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType)
   293  	zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType)
   294  	ng := newGetter(receiverType, fieldName, fieldType, zeroValue, false)
   295  	ng.MapType = !isAPointer
   296  	t.Getters = append(t.Getters, ng)
   297  }
   298  
   299  func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
   300  	if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
   301  		return
   302  	}
   303  
   304  	var xX string
   305  	if xx, ok := x.X.(*ast.Ident); ok {
   306  		xX = xx.String()
   307  	}
   308  
   309  	switch xX {
   310  	case "time", "json":
   311  		if xX == "json" {
   312  			t.Imports["encoding/json"] = "encoding/json"
   313  		} else {
   314  			t.Imports[xX] = xX
   315  		}
   316  		fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name)
   317  		zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name)
   318  		if xX == "time" && x.Sel.Name == "Duration" {
   319  			zeroValue = "0"
   320  		}
   321  		t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
   322  	default:
   323  		logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
   324  	}
   325  }
   326  
   327  type templateData struct {
   328  	filename string
   329  	Year     int
   330  	Package  string
   331  	Imports  map[string]string
   332  	Getters  []*getter
   333  }
   334  
   335  type getter struct {
   336  	sortVal      string // Lower-case version of "ReceiverType.FieldName".
   337  	ReceiverVar  string // The one-letter variable name to match the ReceiverType.
   338  	ReceiverType string
   339  	FieldName    string
   340  	FieldType    string
   341  	ZeroValue    string
   342  	NamedStruct  bool // Getter for named struct.
   343  	MapType      bool
   344  	ArrayType    bool
   345  }
   346  
   347  type byName []*getter
   348  
   349  func (b byName) Len() int           { return len(b) }
   350  func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal }
   351  func (b byName) Swap(i, j int)      { b[i], b[j] = b[j], b[i] }
   352  
   353  const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
   354  //
   355  // Use of this source code is governed by a BSD-style
   356  // license that can be found in the LICENSE file.
   357  
   358  // Code generated by gen-accessors; DO NOT EDIT.
   359  // Instead, please run "go generate ./..." as described here:
   360  // https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch
   361  
   362  package {{.Package}}
   363  {{with .Imports}}
   364  import (
   365    {{- range . -}}
   366    "{{.}}"
   367    {{end -}}
   368  )
   369  {{end}}
   370  {{range .Getters}}
   371  {{if .NamedStruct}}
   372  // Get{{.FieldName}} returns the {{.FieldName}} field.
   373  func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
   374    if {{.ReceiverVar}} == nil {
   375      return {{.ZeroValue}}
   376    }
   377    return {{.ReceiverVar}}.{{.FieldName}}
   378  }
   379  {{else if or .MapType .ArrayType }}
   380  // Get{{.FieldName}} returns the {{.FieldName}} {{if .MapType}}map{{else if .ArrayType }}slice{{end}} if it's non-nil, {{if .MapType}}an empty map{{else if .ArrayType }}nil{{end}} otherwise.
   381  func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
   382    if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
   383      return {{.ZeroValue}}
   384    }
   385    return {{.ReceiverVar}}.{{.FieldName}}
   386  }
   387  {{else}}
   388  // Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
   389  func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
   390    if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
   391      return {{.ZeroValue}}
   392    }
   393    return *{{.ReceiverVar}}.{{.FieldName}}
   394  }
   395  {{end}}
   396  {{end}}
   397  `
   398  
   399  const test = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
   400  //
   401  // Use of this source code is governed by a BSD-style
   402  // license that can be found in the LICENSE file.
   403  
   404  // Code generated by gen-accessors; DO NOT EDIT.
   405  // Instead, please run "go generate ./..." as described here:
   406  // https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch
   407  
   408  package {{.Package}}
   409  {{with .Imports}}
   410  import (
   411    "testing"
   412    {{range . -}}
   413    "{{.}}"
   414    {{end -}}
   415  )
   416  {{end}}
   417  {{range .Getters}}
   418  {{if .NamedStruct}}
   419  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   420    tt.Parallel()
   421    {{.ReceiverVar}} := &{{.ReceiverType}}{}
   422    {{.ReceiverVar}}.Get{{.FieldName}}()
   423    {{.ReceiverVar}} = nil
   424    {{.ReceiverVar}}.Get{{.FieldName}}()
   425  }
   426  {{else if or .MapType .ArrayType}}
   427  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   428    tt.Parallel()
   429    zeroValue := {{.FieldType}}{}
   430    {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue }
   431    {{.ReceiverVar}}.Get{{.FieldName}}()
   432    {{.ReceiverVar}} = &{{.ReceiverType}}{}
   433    {{.ReceiverVar}}.Get{{.FieldName}}()
   434    {{.ReceiverVar}} = nil
   435    {{.ReceiverVar}}.Get{{.FieldName}}()
   436  }
   437  {{else}}
   438  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   439    tt.Parallel()
   440    var zeroValue {{.FieldType}}
   441    {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: &zeroValue }
   442    {{.ReceiverVar}}.Get{{.FieldName}}()
   443    {{.ReceiverVar}} = &{{.ReceiverType}}{}
   444    {{.ReceiverVar}}.Get{{.FieldName}}()
   445    {{.ReceiverVar}} = nil
   446    {{.ReceiverVar}}.Get{{.FieldName}}()
   447  }
   448  {{end}}
   449  {{end}}
   450  `