github.com/google/go-github/v33@v33.0.0/github/gen-accessors.go (about)

     1  // Copyright 2017 The go-github AUTHORS. All rights reserved.
     2  //
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  // +build ignore
     7  
     8  // gen-accessors generates accessor methods for structs with pointer fields.
     9  //
    10  // It is meant to be used by go-github contributors in conjunction with the
    11  // go generate tool before sending a PR to GitHub.
    12  // Please see the CONTRIBUTING.md file for more information.
    13  package main
    14  
    15  import (
    16  	"bytes"
    17  	"flag"
    18  	"fmt"
    19  	"go/ast"
    20  	"go/format"
    21  	"go/parser"
    22  	"go/token"
    23  	"io/ioutil"
    24  	"log"
    25  	"os"
    26  	"sort"
    27  	"strings"
    28  	"text/template"
    29  )
    30  
    31  const (
    32  	fileSuffix = "-accessors.go"
    33  )
    34  
    35  var (
    36  	verbose = flag.Bool("v", false, "Print verbose log messages")
    37  
    38  	sourceTmpl = template.Must(template.New("source").Parse(source))
    39  
    40  	// skipStructMethods lists "struct.method" combos to skip.
    41  	skipStructMethods = map[string]bool{
    42  		"RepositoryContent.GetContent":    true,
    43  		"Client.GetBaseURL":               true,
    44  		"Client.GetUploadURL":             true,
    45  		"ErrorResponse.GetResponse":       true,
    46  		"RateLimitError.GetResponse":      true,
    47  		"AbuseRateLimitError.GetResponse": true,
    48  	}
    49  	// skipStructs lists structs to skip.
    50  	skipStructs = map[string]bool{
    51  		"Client": true,
    52  	}
    53  )
    54  
    55  func logf(fmt string, args ...interface{}) {
    56  	if *verbose {
    57  		log.Printf(fmt, args...)
    58  	}
    59  }
    60  
    61  func main() {
    62  	flag.Parse()
    63  	fset := token.NewFileSet()
    64  
    65  	pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
    66  	if err != nil {
    67  		log.Fatal(err)
    68  		return
    69  	}
    70  
    71  	for pkgName, pkg := range pkgs {
    72  		t := &templateData{
    73  			filename: pkgName + fileSuffix,
    74  			Year:     2017,
    75  			Package:  pkgName,
    76  			Imports:  map[string]string{},
    77  		}
    78  		for filename, f := range pkg.Files {
    79  			logf("Processing %v...", filename)
    80  			if err := t.processAST(f); err != nil {
    81  				log.Fatal(err)
    82  			}
    83  		}
    84  		if err := t.dump(); err != nil {
    85  			log.Fatal(err)
    86  		}
    87  	}
    88  	logf("Done.")
    89  }
    90  
    91  func (t *templateData) processAST(f *ast.File) error {
    92  	for _, decl := range f.Decls {
    93  		gd, ok := decl.(*ast.GenDecl)
    94  		if !ok {
    95  			continue
    96  		}
    97  		for _, spec := range gd.Specs {
    98  			ts, ok := spec.(*ast.TypeSpec)
    99  			if !ok {
   100  				continue
   101  			}
   102  			// Skip unexported identifiers.
   103  			if !ts.Name.IsExported() {
   104  				logf("Struct %v is unexported; skipping.", ts.Name)
   105  				continue
   106  			}
   107  			// Check if the struct should be skipped.
   108  			if skipStructs[ts.Name.Name] {
   109  				logf("Struct %v is in skip list; skipping.", ts.Name)
   110  				continue
   111  			}
   112  			st, ok := ts.Type.(*ast.StructType)
   113  			if !ok {
   114  				continue
   115  			}
   116  			for _, field := range st.Fields.List {
   117  				se, ok := field.Type.(*ast.StarExpr)
   118  				if len(field.Names) == 0 || !ok {
   119  					continue
   120  				}
   121  
   122  				fieldName := field.Names[0]
   123  				// Skip unexported identifiers.
   124  				if !fieldName.IsExported() {
   125  					logf("Field %v is unexported; skipping.", fieldName)
   126  					continue
   127  				}
   128  				// Check if "struct.method" should be skipped.
   129  				if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] {
   130  					logf("Method %v is skip list; skipping.", key)
   131  					continue
   132  				}
   133  
   134  				switch x := se.X.(type) {
   135  				case *ast.ArrayType:
   136  					t.addArrayType(x, ts.Name.String(), fieldName.String())
   137  				case *ast.Ident:
   138  					t.addIdent(x, ts.Name.String(), fieldName.String())
   139  				case *ast.MapType:
   140  					t.addMapType(x, ts.Name.String(), fieldName.String())
   141  				case *ast.SelectorExpr:
   142  					t.addSelectorExpr(x, ts.Name.String(), fieldName.String())
   143  				default:
   144  					logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
   145  				}
   146  			}
   147  		}
   148  	}
   149  	return nil
   150  }
   151  
   152  func sourceFilter(fi os.FileInfo) bool {
   153  	return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix)
   154  }
   155  
   156  func (t *templateData) dump() error {
   157  	if len(t.Getters) == 0 {
   158  		logf("No getters for %v; skipping.", t.filename)
   159  		return nil
   160  	}
   161  
   162  	// Sort getters by ReceiverType.FieldName.
   163  	sort.Sort(byName(t.Getters))
   164  
   165  	var buf bytes.Buffer
   166  	if err := sourceTmpl.Execute(&buf, t); err != nil {
   167  		return err
   168  	}
   169  	clean, err := format.Source(buf.Bytes())
   170  	if err != nil {
   171  		return err
   172  	}
   173  
   174  	logf("Writing %v...", t.filename)
   175  	return ioutil.WriteFile(t.filename, clean, 0644)
   176  }
   177  
   178  func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter {
   179  	return &getter{
   180  		sortVal:      strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
   181  		ReceiverVar:  strings.ToLower(receiverType[:1]),
   182  		ReceiverType: receiverType,
   183  		FieldName:    fieldName,
   184  		FieldType:    fieldType,
   185  		ZeroValue:    zeroValue,
   186  		NamedStruct:  namedStruct,
   187  	}
   188  }
   189  
   190  func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string) {
   191  	var eltType string
   192  	switch elt := x.Elt.(type) {
   193  	case *ast.Ident:
   194  		eltType = elt.String()
   195  	default:
   196  		logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt)
   197  		return
   198  	}
   199  
   200  	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, "[]"+eltType, "nil", false))
   201  }
   202  
   203  func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
   204  	var zeroValue string
   205  	var namedStruct = false
   206  	switch x.String() {
   207  	case "int", "int64":
   208  		zeroValue = "0"
   209  	case "string":
   210  		zeroValue = `""`
   211  	case "bool":
   212  		zeroValue = "false"
   213  	case "Timestamp":
   214  		zeroValue = "Timestamp{}"
   215  	default:
   216  		zeroValue = "nil"
   217  		namedStruct = true
   218  	}
   219  
   220  	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   221  }
   222  
   223  func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string) {
   224  	var keyType string
   225  	switch key := x.Key.(type) {
   226  	case *ast.Ident:
   227  		keyType = key.String()
   228  	default:
   229  		logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key)
   230  		return
   231  	}
   232  
   233  	var valueType string
   234  	switch value := x.Value.(type) {
   235  	case *ast.Ident:
   236  		valueType = value.String()
   237  	default:
   238  		logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value)
   239  		return
   240  	}
   241  
   242  	fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType)
   243  	zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType)
   244  	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
   245  }
   246  
   247  func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
   248  	if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
   249  		return
   250  	}
   251  
   252  	var xX string
   253  	if xx, ok := x.X.(*ast.Ident); ok {
   254  		xX = xx.String()
   255  	}
   256  
   257  	switch xX {
   258  	case "time", "json":
   259  		if xX == "json" {
   260  			t.Imports["encoding/json"] = "encoding/json"
   261  		} else {
   262  			t.Imports[xX] = xX
   263  		}
   264  		fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name)
   265  		zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name)
   266  		if xX == "time" && x.Sel.Name == "Duration" {
   267  			zeroValue = "0"
   268  		}
   269  		t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
   270  	default:
   271  		logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
   272  	}
   273  }
   274  
   275  type templateData struct {
   276  	filename string
   277  	Year     int
   278  	Package  string
   279  	Imports  map[string]string
   280  	Getters  []*getter
   281  }
   282  
   283  type getter struct {
   284  	sortVal      string // Lower-case version of "ReceiverType.FieldName".
   285  	ReceiverVar  string // The one-letter variable name to match the ReceiverType.
   286  	ReceiverType string
   287  	FieldName    string
   288  	FieldType    string
   289  	ZeroValue    string
   290  	NamedStruct  bool // Getter for named struct.
   291  }
   292  
   293  type byName []*getter
   294  
   295  func (b byName) Len() int           { return len(b) }
   296  func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal }
   297  func (b byName) Swap(i, j int)      { b[i], b[j] = b[j], b[i] }
   298  
// source is the text/template from which sourceTmpl is parsed. dump executes
// it with a *templateData and passes the result through format.Source, so the
// loose indentation inside the template does not reach the output.
//
// For each getter, .NamedStruct selects between two forms:
//   - return the pointer field as-is (nil when the receiver is nil), or
//   - dereference the field, returning .ZeroValue when the receiver or the
//     field is nil.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-accessors; DO NOT EDIT.

package {{.Package}}
{{with .Imports}}
import (
  {{- range . -}}
  "{{.}}"
  {{end -}}
)
{{end}}
{{range .Getters}}
{{if .NamedStruct}}
// Get{{.FieldName}} returns the {{.FieldName}} field.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
  if {{.ReceiverVar}} == nil {
    return {{.ZeroValue}}
  }
  return {{.ReceiverVar}}.{{.FieldName}}
}
{{else}}
// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
  if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
    return {{.ZeroValue}}
  }
  return *{{.ReceiverVar}}.{{.FieldName}}
}
{{end}}
{{end}}
`