github.com/google/go-github/v42@v42.0.0/github/gen-accessors.go (about) 1 // Copyright 2017 The go-github AUTHORS. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 //go:build ignore 7 // +build ignore 8 9 // gen-accessors generates accessor methods for structs with pointer fields. 10 // 11 // It is meant to be used by go-github contributors in conjunction with the 12 // go generate tool before sending a PR to GitHub. 13 // Please see the CONTRIBUTING.md file for more information. 14 package main 15 16 import ( 17 "bytes" 18 "flag" 19 "fmt" 20 "go/ast" 21 "go/format" 22 "go/parser" 23 "go/token" 24 "io/ioutil" 25 "log" 26 "os" 27 "sort" 28 "strings" 29 "text/template" 30 ) 31 32 const ( 33 fileSuffix = "-accessors.go" 34 ) 35 36 var ( 37 verbose = flag.Bool("v", false, "Print verbose log messages") 38 39 sourceTmpl = template.Must(template.New("source").Parse(source)) 40 testTmpl = template.Must(template.New("test").Parse(test)) 41 42 // skipStructMethods lists "struct.method" combos to skip. 43 skipStructMethods = map[string]bool{ 44 "RepositoryContent.GetContent": true, 45 "Client.GetBaseURL": true, 46 "Client.GetUploadURL": true, 47 "ErrorResponse.GetResponse": true, 48 "RateLimitError.GetResponse": true, 49 "AbuseRateLimitError.GetResponse": true, 50 } 51 // skipStructs lists structs to skip. 52 skipStructs = map[string]bool{ 53 "Client": true, 54 } 55 ) 56 57 func logf(fmt string, args ...interface{}) { 58 if *verbose { 59 log.Printf(fmt, args...) 60 } 61 } 62 63 func main() { 64 flag.Parse() 65 fset := token.NewFileSet() 66 67 pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0) 68 if err != nil { 69 log.Fatal(err) 70 return 71 } 72 73 for pkgName, pkg := range pkgs { 74 t := &templateData{ 75 filename: pkgName + fileSuffix, 76 Year: 2017, 77 Package: pkgName, 78 Imports: map[string]string{}, 79 } 80 for filename, f := range pkg.Files { 81 logf("Processing %v...", filename) 82 if err := t.processAST(f); err != nil { 83 log.Fatal(err) 84 } 85 } 86 if err := t.dump(); err != nil { 87 log.Fatal(err) 88 } 89 } 90 logf("Done.") 91 } 92 93 func (t *templateData) processAST(f *ast.File) error { 94 for _, decl := range f.Decls { 95 gd, ok := decl.(*ast.GenDecl) 96 if !ok { 97 continue 98 } 99 for _, spec := range gd.Specs { 100 ts, ok := spec.(*ast.TypeSpec) 101 if !ok { 102 continue 103 } 104 // Skip unexported identifiers. 105 if !ts.Name.IsExported() { 106 logf("Struct %v is unexported; skipping.", ts.Name) 107 continue 108 } 109 // Check if the struct should be skipped. 110 if skipStructs[ts.Name.Name] { 111 logf("Struct %v is in skip list; skipping.", ts.Name) 112 continue 113 } 114 st, ok := ts.Type.(*ast.StructType) 115 if !ok { 116 continue 117 } 118 for _, field := range st.Fields.List { 119 if len(field.Names) == 0 { 120 continue 121 } 122 123 fieldName := field.Names[0] 124 // Skip unexported identifiers. 125 if !fieldName.IsExported() { 126 logf("Field %v is unexported; skipping.", fieldName) 127 continue 128 } 129 // Check if "struct.method" should be skipped. 130 if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] { 131 logf("Method %v is skip list; skipping.", key) 132 continue 133 } 134 135 se, ok := field.Type.(*ast.StarExpr) 136 if !ok { 137 switch x := field.Type.(type) { 138 case *ast.MapType: 139 t.addMapType(x, ts.Name.String(), fieldName.String(), false) 140 continue 141 } 142 143 logf("Skipping field type %T, fieldName=%v", field.Type, fieldName) 144 continue 145 } 146 147 switch x := se.X.(type) { 148 case *ast.ArrayType: 149 t.addArrayType(x, ts.Name.String(), fieldName.String()) 150 case *ast.Ident: 151 t.addIdent(x, ts.Name.String(), fieldName.String()) 152 case *ast.MapType: 153 t.addMapType(x, ts.Name.String(), fieldName.String(), true) 154 case *ast.SelectorExpr: 155 t.addSelectorExpr(x, ts.Name.String(), fieldName.String()) 156 default: 157 logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x) 158 } 159 } 160 } 161 } 162 return nil 163 } 164 165 func sourceFilter(fi os.FileInfo) bool { 166 return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix) 167 } 168 169 func (t *templateData) dump() error { 170 if len(t.Getters) == 0 { 171 logf("No getters for %v; skipping.", t.filename) 172 return nil 173 } 174 175 // Sort getters by ReceiverType.FieldName. 176 sort.Sort(byName(t.Getters)) 177 178 processTemplate := func(tmpl *template.Template, filename string) error { 179 var buf bytes.Buffer 180 if err := tmpl.Execute(&buf, t); err != nil { 181 return err 182 } 183 clean, err := format.Source(buf.Bytes()) 184 if err != nil { 185 return fmt.Errorf("format.Source:\n%v\n%v", buf.String(), err) 186 } 187 188 logf("Writing %v...", filename) 189 if err := ioutil.WriteFile(filename, clean, 0644); err != nil { 190 return err 191 } 192 193 return nil 194 } 195 196 if err := processTemplate(sourceTmpl, t.filename); err != nil { 197 return err 198 } 199 return processTemplate(testTmpl, strings.ReplaceAll(t.filename, ".go", "_test.go")) 200 } 201 202 func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter { 203 return &getter{ 204 sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName), 205 ReceiverVar: strings.ToLower(receiverType[:1]), 206 ReceiverType: receiverType, 207 FieldName: fieldName, 208 FieldType: fieldType, 209 ZeroValue: zeroValue, 210 NamedStruct: namedStruct, 211 } 212 } 213 214 func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string) { 215 var eltType string 216 switch elt := x.Elt.(type) { 217 case *ast.Ident: 218 eltType = elt.String() 219 default: 220 logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt) 221 return 222 } 223 224 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, "[]"+eltType, "nil", false)) 225 } 226 227 func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) { 228 var zeroValue string 229 var namedStruct = false 230 switch x.String() { 231 case "int", "int64": 232 zeroValue = "0" 233 case "string": 234 zeroValue = `""` 235 case "bool": 236 zeroValue = "false" 237 case "Timestamp": 238 zeroValue = "Timestamp{}" 239 default: 240 zeroValue = "nil" 241 namedStruct = true 242 } 243 244 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct)) 245 } 246 247 func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) { 248 var keyType string 249 switch key := x.Key.(type) { 250 case *ast.Ident: 251 keyType = key.String() 252 default: 253 logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key) 254 return 255 } 256 257 var valueType string 258 switch value := x.Value.(type) { 259 case *ast.Ident: 260 valueType = value.String() 261 default: 262 logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value) 263 return 264 } 265 266 fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType) 267 zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType) 268 ng := newGetter(receiverType, fieldName, fieldType, zeroValue, false) 269 ng.MapType = !isAPointer 270 t.Getters = append(t.Getters, ng) 271 } 272 273 func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) { 274 if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field. 275 return 276 } 277 278 var xX string 279 if xx, ok := x.X.(*ast.Ident); ok { 280 xX = xx.String() 281 } 282 283 switch xX { 284 case "time", "json": 285 if xX == "json" { 286 t.Imports["encoding/json"] = "encoding/json" 287 } else { 288 t.Imports[xX] = xX 289 } 290 fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name) 291 zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name) 292 if xX == "time" && x.Sel.Name == "Duration" { 293 zeroValue = "0" 294 } 295 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false)) 296 default: 297 logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x) 298 } 299 } 300 301 type templateData struct { 302 filename string 303 Year int 304 Package string 305 Imports map[string]string 306 Getters []*getter 307 } 308 309 type getter struct { 310 sortVal string // Lower-case version of "ReceiverType.FieldName". 311 ReceiverVar string // The one-letter variable name to match the ReceiverType. 312 ReceiverType string 313 FieldName string 314 FieldType string 315 ZeroValue string 316 NamedStruct bool // Getter for named struct. 317 MapType bool 318 } 319 320 type byName []*getter 321 322 func (b byName) Len() int { return len(b) } 323 func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal } 324 func (b byName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } 325 326 const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved. 327 // 328 // Use of this source code is governed by a BSD-style 329 // license that can be found in the LICENSE file. 330 331 // Code generated by gen-accessors; DO NOT EDIT. 332 333 package {{.Package}} 334 {{with .Imports}} 335 import ( 336 {{- range . -}} 337 "{{.}}" 338 {{end -}} 339 ) 340 {{end}} 341 {{range .Getters}} 342 {{if .NamedStruct}} 343 // Get{{.FieldName}} returns the {{.FieldName}} field. 344 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} { 345 if {{.ReceiverVar}} == nil { 346 return {{.ZeroValue}} 347 } 348 return {{.ReceiverVar}}.{{.FieldName}} 349 } 350 {{else if .MapType}} 351 // Get{{.FieldName}} returns the {{.FieldName}} map if it's non-nil, an empty map otherwise. 352 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { 353 if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil { 354 return {{.ZeroValue}} 355 } 356 return {{.ReceiverVar}}.{{.FieldName}} 357 } 358 {{else}} 359 // Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise. 360 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { 361 if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil { 362 return {{.ZeroValue}} 363 } 364 return *{{.ReceiverVar}}.{{.FieldName}} 365 } 366 {{end}} 367 {{end}} 368 ` 369 370 const test = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved. 371 // 372 // Use of this source code is governed by a BSD-style 373 // license that can be found in the LICENSE file. 374 375 // Code generated by gen-accessors; DO NOT EDIT. 376 377 package {{.Package}} 378 {{with .Imports}} 379 import ( 380 "testing" 381 {{range . -}} 382 "{{.}}" 383 {{end -}} 384 ) 385 {{end}} 386 {{range .Getters}} 387 {{if .NamedStruct}} 388 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 389 {{.ReceiverVar}} := &{{.ReceiverType}}{} 390 {{.ReceiverVar}}.Get{{.FieldName}}() 391 {{.ReceiverVar}} = nil 392 {{.ReceiverVar}}.Get{{.FieldName}}() 393 } 394 {{else if .MapType}} 395 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 396 zeroValue := {{.FieldType}}{} 397 {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue } 398 {{.ReceiverVar}}.Get{{.FieldName}}() 399 {{.ReceiverVar}} = &{{.ReceiverType}}{} 400 {{.ReceiverVar}}.Get{{.FieldName}}() 401 {{.ReceiverVar}} = nil 402 {{.ReceiverVar}}.Get{{.FieldName}}() 403 } 404 {{else}} 405 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 406 var zeroValue {{.FieldType}} 407 {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: &zeroValue } 408 {{.ReceiverVar}}.Get{{.FieldName}}() 409 {{.ReceiverVar}} = &{{.ReceiverType}}{} 410 {{.ReceiverVar}}.Get{{.FieldName}}() 411 {{.ReceiverVar}} = nil 412 {{.ReceiverVar}}.Get{{.FieldName}}() 413 } 414 {{end}} 415 {{end}} 416 `