github.com/google/go-github/v49@v49.1.0/github/gen-accessors.go (about)

     1  // Copyright 2017 The go-github AUTHORS. All rights reserved.
     2  //
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  //go:build ignore
     7  // +build ignore
     8  
     9  // gen-accessors generates accessor methods for structs with pointer fields.
    10  //
    11  // It is meant to be used by go-github contributors in conjunction with the
    12  // go generate tool before sending a PR to GitHub.
    13  // Please see the CONTRIBUTING.md file for more information.
    14  package main
    15  
    16  import (
    17  	"bytes"
    18  	"flag"
    19  	"fmt"
    20  	"go/ast"
    21  	"go/format"
    22  	"go/parser"
    23  	"go/token"
    24  	"log"
    25  	"os"
    26  	"sort"
    27  	"strings"
    28  	"text/template"
    29  )
    30  
const (
	// fileSuffix is appended to the package name to form the name of the
	// generated accessor file (pkgName + fileSuffix). Files ending in this
	// suffix are also excluded from parsing by sourceFilter, so previously
	// generated output is never fed back into the generator.
	fileSuffix = "-accessors.go"
)
    34  
    35  var (
    36  	verbose = flag.Bool("v", false, "Print verbose log messages")
    37  
    38  	sourceTmpl = template.Must(template.New("source").Parse(source))
    39  	testTmpl   = template.Must(template.New("test").Parse(test))
    40  
    41  	// skipStructMethods lists "struct.method" combos to skip.
    42  	skipStructMethods = map[string]bool{
    43  		"RepositoryContent.GetContent":    true,
    44  		"Client.GetBaseURL":               true,
    45  		"Client.GetUploadURL":             true,
    46  		"ErrorResponse.GetResponse":       true,
    47  		"RateLimitError.GetResponse":      true,
    48  		"AbuseRateLimitError.GetResponse": true,
    49  	}
    50  	// skipStructs lists structs to skip.
    51  	skipStructs = map[string]bool{
    52  		"Client": true,
    53  	}
    54  )
    55  
    56  func logf(fmt string, args ...interface{}) {
    57  	if *verbose {
    58  		log.Printf(fmt, args...)
    59  	}
    60  }
    61  
    62  func main() {
    63  	flag.Parse()
    64  	fset := token.NewFileSet()
    65  
    66  	pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
    67  	if err != nil {
    68  		log.Fatal(err)
    69  		return
    70  	}
    71  
    72  	for pkgName, pkg := range pkgs {
    73  		t := &templateData{
    74  			filename: pkgName + fileSuffix,
    75  			Year:     2017,
    76  			Package:  pkgName,
    77  			Imports:  map[string]string{},
    78  		}
    79  		for filename, f := range pkg.Files {
    80  			logf("Processing %v...", filename)
    81  			if err := t.processAST(f); err != nil {
    82  				log.Fatal(err)
    83  			}
    84  		}
    85  		if err := t.dump(); err != nil {
    86  			log.Fatal(err)
    87  		}
    88  	}
    89  	logf("Done.")
    90  }
    91  
    92  func (t *templateData) processAST(f *ast.File) error {
    93  	for _, decl := range f.Decls {
    94  		gd, ok := decl.(*ast.GenDecl)
    95  		if !ok {
    96  			continue
    97  		}
    98  		for _, spec := range gd.Specs {
    99  			ts, ok := spec.(*ast.TypeSpec)
   100  			if !ok {
   101  				continue
   102  			}
   103  			// Skip unexported identifiers.
   104  			if !ts.Name.IsExported() {
   105  				logf("Struct %v is unexported; skipping.", ts.Name)
   106  				continue
   107  			}
   108  			// Check if the struct should be skipped.
   109  			if skipStructs[ts.Name.Name] {
   110  				logf("Struct %v is in skip list; skipping.", ts.Name)
   111  				continue
   112  			}
   113  			st, ok := ts.Type.(*ast.StructType)
   114  			if !ok {
   115  				continue
   116  			}
   117  			for _, field := range st.Fields.List {
   118  				if len(field.Names) == 0 {
   119  					continue
   120  				}
   121  
   122  				fieldName := field.Names[0]
   123  				// Skip unexported identifiers.
   124  				if !fieldName.IsExported() {
   125  					logf("Field %v is unexported; skipping.", fieldName)
   126  					continue
   127  				}
   128  				// Check if "struct.method" should be skipped.
   129  				if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] {
   130  					logf("Method %v is skip list; skipping.", key)
   131  					continue
   132  				}
   133  
   134  				se, ok := field.Type.(*ast.StarExpr)
   135  				if !ok {
   136  					switch x := field.Type.(type) {
   137  					case *ast.MapType:
   138  						t.addMapType(x, ts.Name.String(), fieldName.String(), false)
   139  						continue
   140  					}
   141  
   142  					logf("Skipping field type %T, fieldName=%v", field.Type, fieldName)
   143  					continue
   144  				}
   145  
   146  				switch x := se.X.(type) {
   147  				case *ast.ArrayType:
   148  					t.addArrayType(x, ts.Name.String(), fieldName.String())
   149  				case *ast.Ident:
   150  					t.addIdent(x, ts.Name.String(), fieldName.String())
   151  				case *ast.MapType:
   152  					t.addMapType(x, ts.Name.String(), fieldName.String(), true)
   153  				case *ast.SelectorExpr:
   154  					t.addSelectorExpr(x, ts.Name.String(), fieldName.String())
   155  				default:
   156  					logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
   157  				}
   158  			}
   159  		}
   160  	}
   161  	return nil
   162  }
   163  
   164  func sourceFilter(fi os.FileInfo) bool {
   165  	return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix)
   166  }
   167  
   168  func (t *templateData) dump() error {
   169  	if len(t.Getters) == 0 {
   170  		logf("No getters for %v; skipping.", t.filename)
   171  		return nil
   172  	}
   173  
   174  	// Sort getters by ReceiverType.FieldName.
   175  	sort.Sort(byName(t.Getters))
   176  
   177  	processTemplate := func(tmpl *template.Template, filename string) error {
   178  		var buf bytes.Buffer
   179  		if err := tmpl.Execute(&buf, t); err != nil {
   180  			return err
   181  		}
   182  		clean, err := format.Source(buf.Bytes())
   183  		if err != nil {
   184  			return fmt.Errorf("format.Source:\n%v\n%v", buf.String(), err)
   185  		}
   186  
   187  		logf("Writing %v...", filename)
   188  		if err := os.Chmod(filename, 0644); err != nil {
   189  			return fmt.Errorf("os.Chmod(%q, 0644): %v", filename, err)
   190  		}
   191  
   192  		if err := os.WriteFile(filename, clean, 0444); err != nil {
   193  			return err
   194  		}
   195  
   196  		if err := os.Chmod(filename, 0444); err != nil {
   197  			return fmt.Errorf("os.Chmod(%q, 0444): %v", filename, err)
   198  		}
   199  
   200  		return nil
   201  	}
   202  
   203  	if err := processTemplate(sourceTmpl, t.filename); err != nil {
   204  		return err
   205  	}
   206  	return processTemplate(testTmpl, strings.ReplaceAll(t.filename, ".go", "_test.go"))
   207  }
   208  
   209  func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter {
   210  	return &getter{
   211  		sortVal:      strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
   212  		ReceiverVar:  strings.ToLower(receiverType[:1]),
   213  		ReceiverType: receiverType,
   214  		FieldName:    fieldName,
   215  		FieldType:    fieldType,
   216  		ZeroValue:    zeroValue,
   217  		NamedStruct:  namedStruct,
   218  	}
   219  }
   220  
   221  func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string) {
   222  	var eltType string
   223  	switch elt := x.Elt.(type) {
   224  	case *ast.Ident:
   225  		eltType = elt.String()
   226  	default:
   227  		logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt)
   228  		return
   229  	}
   230  
   231  	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, "[]"+eltType, "nil", false))
   232  }
   233  
   234  func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
   235  	var zeroValue string
   236  	var namedStruct = false
   237  	switch x.String() {
   238  	case "int", "int64":
   239  		zeroValue = "0"
   240  	case "string":
   241  		zeroValue = `""`
   242  	case "bool":
   243  		zeroValue = "false"
   244  	case "Timestamp":
   245  		zeroValue = "Timestamp{}"
   246  	default:
   247  		zeroValue = "nil"
   248  		namedStruct = true
   249  	}
   250  
   251  	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   252  }
   253  
   254  func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) {
   255  	var keyType string
   256  	switch key := x.Key.(type) {
   257  	case *ast.Ident:
   258  		keyType = key.String()
   259  	default:
   260  		logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key)
   261  		return
   262  	}
   263  
   264  	var valueType string
   265  	switch value := x.Value.(type) {
   266  	case *ast.Ident:
   267  		valueType = value.String()
   268  	default:
   269  		logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value)
   270  		return
   271  	}
   272  
   273  	fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType)
   274  	zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType)
   275  	ng := newGetter(receiverType, fieldName, fieldType, zeroValue, false)
   276  	ng.MapType = !isAPointer
   277  	t.Getters = append(t.Getters, ng)
   278  }
   279  
   280  func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
   281  	if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
   282  		return
   283  	}
   284  
   285  	var xX string
   286  	if xx, ok := x.X.(*ast.Ident); ok {
   287  		xX = xx.String()
   288  	}
   289  
   290  	switch xX {
   291  	case "time", "json":
   292  		if xX == "json" {
   293  			t.Imports["encoding/json"] = "encoding/json"
   294  		} else {
   295  			t.Imports[xX] = xX
   296  		}
   297  		fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name)
   298  		zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name)
   299  		if xX == "time" && x.Sel.Name == "Duration" {
   300  			zeroValue = "0"
   301  		}
   302  		t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
   303  	default:
   304  		logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
   305  	}
   306  }
   307  
// templateData holds everything collected for one package. It is the data
// value passed to the source and test templates when rendering that package's
// accessor files.
type templateData struct {
	filename string            // Output file name: package name + fileSuffix.
	Year     int               // Copyright year rendered into the file header.
	Package  string            // Package clause of the generated files.
	Imports  map[string]string // Import paths needed by the getters; key and value are both the path.
	Getters  []*getter         // Accessors to generate; sorted by dump before rendering.
}
   315  
// getter describes one generated GetX accessor method and its test.
type getter struct {
	sortVal      string // Lower-case version of "ReceiverType.FieldName".
	ReceiverVar  string // The one-letter variable name to match the ReceiverType.
	ReceiverType string // Name of the struct type the method is defined on.
	FieldName    string // Name of the struct field; the method is Get<FieldName>.
	FieldType    string // Go type returned by the getter (for NamedStruct, the pointed-to type; the template adds the "*").
	ZeroValue    string // Go expression returned when the receiver (or, for non-NamedStruct getters, the field) is nil.
	NamedStruct  bool   // Getter for named struct.
	MapType      bool   // Set for non-pointer map fields: the getter returns the map itself, or ZeroValue when it is nil.
}
   326  
   327  type byName []*getter
   328  
   329  func (b byName) Len() int           { return len(b) }
   330  func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal }
   331  func (b byName) Swap(i, j int)      { b[i], b[j] = b[j], b[i] }
   332  
// source is the text/template for the generated accessor file. The import
// block is rendered only when Imports is non-empty. Each getter produces one
// of three method shapes:
//   - NamedStruct: return the field pointer itself (nil for a nil receiver);
//   - MapType: return the map, or ZeroValue when the receiver or map is nil;
//   - otherwise: dereference the pointer field, or return ZeroValue when the
//     receiver or field is nil.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-accessors; DO NOT EDIT.
// Instead, please run "go generate ./..." as described here:
// https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch

package {{.Package}}
{{with .Imports}}
import (
  {{- range . -}}
  "{{.}}"
  {{end -}}
)
{{end}}
{{range .Getters}}
{{if .NamedStruct}}
// Get{{.FieldName}} returns the {{.FieldName}} field.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
  if {{.ReceiverVar}} == nil {
    return {{.ZeroValue}}
  }
  return {{.ReceiverVar}}.{{.FieldName}}
}
{{else if .MapType}}
// Get{{.FieldName}} returns the {{.FieldName}} map if it's non-nil, an empty map otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
  if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
    return {{.ZeroValue}}
  }
  return {{.ReceiverVar}}.{{.FieldName}}
}
{{else}}
// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
  if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
    return {{.ZeroValue}}
  }
  return *{{.ReceiverVar}}.{{.FieldName}}
}
{{end}}
{{end}}
`
   378  
   379  const test = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
   380  //
   381  // Use of this source code is governed by a BSD-style
   382  // license that can be found in the LICENSE file.
   383  
   384  // Code generated by gen-accessors; DO NOT EDIT.
   385  // Instead, please run "go generate ./..." as described here:
   386  // https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch
   387  
   388  package {{.Package}}
   389  {{with .Imports}}
   390  import (
   391    "testing"
   392    {{range . -}}
   393    "{{.}}"
   394    {{end -}}
   395  )
   396  {{end}}
   397  {{range .Getters}}
   398  {{if .NamedStruct}}
   399  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   400    {{.ReceiverVar}} := &{{.ReceiverType}}{}
   401    {{.ReceiverVar}}.Get{{.FieldName}}()
   402    {{.ReceiverVar}} = nil
   403    {{.ReceiverVar}}.Get{{.FieldName}}()
   404  }
   405  {{else if .MapType}}
   406  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   407    zeroValue := {{.FieldType}}{}
   408    {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue }
   409    {{.ReceiverVar}}.Get{{.FieldName}}()
   410    {{.ReceiverVar}} = &{{.ReceiverType}}{}
   411    {{.ReceiverVar}}.Get{{.FieldName}}()
   412    {{.ReceiverVar}} = nil
   413    {{.ReceiverVar}}.Get{{.FieldName}}()
   414  }
   415  {{else}}
   416  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   417    var zeroValue {{.FieldType}}
   418    {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: &zeroValue }
   419    {{.ReceiverVar}}.Get{{.FieldName}}()
   420    {{.ReceiverVar}} = &{{.ReceiverType}}{}
   421    {{.ReceiverVar}}.Get{{.FieldName}}()
   422    {{.ReceiverVar}} = nil
   423    {{.ReceiverVar}}.Get{{.FieldName}}()
   424  }
   425  {{end}}
   426  {{end}}
   427  `