github.com/google/go-github/v70@v70.0.0/github/gen-accessors.go (about) 1 // Copyright 2017 The go-github AUTHORS. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 //go:build ignore 7 8 // gen-accessors generates accessor methods for structs with pointer fields. 9 // 10 // It is meant to be used by go-github contributors in conjunction with the 11 // go generate tool before sending a PR to GitHub. 12 // Please see the CONTRIBUTING.md file for more information. 13 package main 14 15 import ( 16 "bytes" 17 "flag" 18 "fmt" 19 "go/ast" 20 "go/format" 21 "go/parser" 22 "go/token" 23 "log" 24 "os" 25 "slices" 26 "strings" 27 "text/template" 28 ) 29 30 const ( 31 fileSuffix = "-accessors.go" 32 ) 33 34 var ( 35 verbose = flag.Bool("v", false, "Print verbose log messages") 36 37 sourceTmpl = template.Must(template.New("source").Parse(source)) 38 testTmpl = template.Must(template.New("test").Parse(test)) 39 40 // skipStructMethods lists "struct.method" combos to skip. 41 skipStructMethods = map[string]bool{ 42 "RepositoryContent.GetContent": true, 43 "Client.GetBaseURL": true, 44 "Client.GetUploadURL": true, 45 "ErrorResponse.GetResponse": true, 46 "RateLimitError.GetResponse": true, 47 "AbuseRateLimitError.GetResponse": true, 48 "PackageVersion.GetBody": true, 49 "PackageVersion.GetMetadata": true, 50 } 51 // skipStructs lists structs to skip. 52 skipStructs = map[string]bool{ 53 "Client": true, 54 } 55 56 // whitelistSliceGetters lists "struct.field" to add getter method 57 whitelistSliceGetters = map[string]bool{ 58 "PushEvent.Commits": true, 59 } 60 ) 61 62 func logf(fmt string, args ...interface{}) { 63 if *verbose { 64 log.Printf(fmt, args...) 65 } 66 } 67 68 func main() { 69 flag.Parse() 70 fset := token.NewFileSet() 71 72 pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0) 73 if err != nil { 74 log.Fatal(err) 75 return 76 } 77 78 for pkgName, pkg := range pkgs { 79 t := &templateData{ 80 filename: pkgName + fileSuffix, 81 Year: 2017, 82 Package: pkgName, 83 Imports: map[string]string{}, 84 } 85 for filename, f := range pkg.Files { 86 logf("Processing %v...", filename) 87 if err := t.processAST(f); err != nil { 88 log.Fatal(err) 89 } 90 } 91 if err := t.dump(); err != nil { 92 log.Fatal(err) 93 } 94 } 95 logf("Done.") 96 } 97 98 func (t *templateData) processAST(f *ast.File) error { 99 for _, decl := range f.Decls { 100 gd, ok := decl.(*ast.GenDecl) 101 if !ok { 102 continue 103 } 104 for _, spec := range gd.Specs { 105 ts, ok := spec.(*ast.TypeSpec) 106 if !ok { 107 continue 108 } 109 // Skip unexported identifiers. 110 if !ts.Name.IsExported() { 111 logf("Struct %v is unexported; skipping.", ts.Name) 112 continue 113 } 114 // Check if the struct should be skipped. 115 if skipStructs[ts.Name.Name] { 116 logf("Struct %v is in skip list; skipping.", ts.Name) 117 continue 118 } 119 st, ok := ts.Type.(*ast.StructType) 120 if !ok { 121 continue 122 } 123 for _, field := range st.Fields.List { 124 if len(field.Names) == 0 { 125 continue 126 } 127 128 fieldName := field.Names[0] 129 // Skip unexported identifiers. 130 if !fieldName.IsExported() { 131 logf("Field %v is unexported; skipping.", fieldName) 132 continue 133 } 134 // Check if "struct.method" should be skipped. 135 if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] { 136 logf("Method %v is skip list; skipping.", key) 137 continue 138 } 139 140 se, ok := field.Type.(*ast.StarExpr) 141 if !ok { 142 switch x := field.Type.(type) { 143 case *ast.MapType: 144 t.addMapType(x, ts.Name.String(), fieldName.String(), false) 145 continue 146 case *ast.ArrayType: 147 if key := fmt.Sprintf("%v.%v", ts.Name, fieldName); whitelistSliceGetters[key] { 148 logf("Method %v is whitelist; adding getter method.", key) 149 t.addArrayType(x, ts.Name.String(), fieldName.String(), false) 150 continue 151 } 152 } 153 154 logf("Skipping field type %T, fieldName=%v", field.Type, fieldName) 155 continue 156 } 157 158 switch x := se.X.(type) { 159 case *ast.ArrayType: 160 t.addArrayType(x, ts.Name.String(), fieldName.String(), true) 161 case *ast.Ident: 162 t.addIdent(x, ts.Name.String(), fieldName.String()) 163 case *ast.MapType: 164 t.addMapType(x, ts.Name.String(), fieldName.String(), true) 165 case *ast.SelectorExpr: 166 t.addSelectorExpr(x, ts.Name.String(), fieldName.String()) 167 default: 168 logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x) 169 } 170 } 171 } 172 } 173 return nil 174 } 175 176 func sourceFilter(fi os.FileInfo) bool { 177 return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix) 178 } 179 180 func (t *templateData) dump() error { 181 if len(t.Getters) == 0 { 182 logf("No getters for %v; skipping.", t.filename) 183 return nil 184 } 185 186 // Sort getters by ReceiverType.FieldName. 187 slices.SortStableFunc(t.Getters, func(a, b *getter) int { 188 return strings.Compare(a.sortVal, b.sortVal) 189 }) 190 191 processTemplate := func(tmpl *template.Template, filename string) error { 192 var buf bytes.Buffer 193 if err := tmpl.Execute(&buf, t); err != nil { 194 return err 195 } 196 clean, err := format.Source(buf.Bytes()) 197 if err != nil { 198 return fmt.Errorf("format.Source:\n%v\n%v", buf.String(), err) 199 } 200 201 logf("Writing %v...", filename) 202 if err := os.Chmod(filename, 0644); err != nil { 203 return fmt.Errorf("os.Chmod(%q, 0644): %v", filename, err) 204 } 205 206 if err := os.WriteFile(filename, clean, 0444); err != nil { 207 return err 208 } 209 210 if err := os.Chmod(filename, 0444); err != nil { 211 return fmt.Errorf("os.Chmod(%q, 0444): %v", filename, err) 212 } 213 214 return nil 215 } 216 217 if err := processTemplate(sourceTmpl, t.filename); err != nil { 218 return err 219 } 220 return processTemplate(testTmpl, strings.ReplaceAll(t.filename, ".go", "_test.go")) 221 } 222 223 func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter { 224 return &getter{ 225 sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName), 226 ReceiverVar: strings.ToLower(receiverType[:1]), 227 ReceiverType: receiverType, 228 FieldName: fieldName, 229 FieldType: fieldType, 230 ZeroValue: zeroValue, 231 NamedStruct: namedStruct, 232 } 233 } 234 235 func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string, isAPointer bool) { 236 var eltType string 237 var ng *getter 238 switch elt := x.Elt.(type) { 239 case *ast.Ident: 240 eltType = elt.String() 241 ng = newGetter(receiverType, fieldName, "[]"+eltType, "nil", false) 242 case *ast.StarExpr: 243 ident, ok := elt.X.(*ast.Ident) 244 if !ok { 245 return 246 } 247 ng = newGetter(receiverType, fieldName, "[]*"+ident.String(), "nil", false) 248 default: 249 logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt) 250 return 251 } 252 253 ng.ArrayType = !isAPointer 254 t.Getters = append(t.Getters, ng) 255 } 256 257 func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) { 258 var zeroValue string 259 var namedStruct = false 260 switch x.String() { 261 case "int", "int64": 262 zeroValue = "0" 263 case "string": 264 zeroValue = `""` 265 case "bool": 266 zeroValue = "false" 267 case "Timestamp": 268 zeroValue = "Timestamp{}" 269 default: 270 zeroValue = "nil" 271 namedStruct = true 272 } 273 274 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct)) 275 } 276 277 func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) { 278 var keyType string 279 switch key := x.Key.(type) { 280 case *ast.Ident: 281 keyType = key.String() 282 default: 283 logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key) 284 return 285 } 286 287 var valueType string 288 switch value := x.Value.(type) { 289 case *ast.Ident: 290 valueType = value.String() 291 default: 292 logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value) 293 return 294 } 295 296 fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType) 297 zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType) 298 ng := newGetter(receiverType, fieldName, fieldType, zeroValue, false) 299 ng.MapType = !isAPointer 300 t.Getters = append(t.Getters, ng) 301 } 302 303 func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) { 304 if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field. 305 return 306 } 307 308 var xX string 309 if xx, ok := x.X.(*ast.Ident); ok { 310 xX = xx.String() 311 } 312 313 switch xX { 314 case "time", "json": 315 if xX == "json" { 316 t.Imports["encoding/json"] = "encoding/json" 317 } else { 318 t.Imports[xX] = xX 319 } 320 fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name) 321 zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name) 322 if xX == "time" && x.Sel.Name == "Duration" { 323 zeroValue = "0" 324 } 325 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false)) 326 default: 327 logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x) 328 } 329 } 330 331 type templateData struct { 332 filename string 333 Year int 334 Package string 335 Imports map[string]string 336 Getters []*getter 337 } 338 339 type getter struct { 340 sortVal string // Lower-case version of "ReceiverType.FieldName". 341 ReceiverVar string // The one-letter variable name to match the ReceiverType. 342 ReceiverType string 343 FieldName string 344 FieldType string 345 ZeroValue string 346 NamedStruct bool // Getter for named struct. 347 MapType bool 348 ArrayType bool 349 } 350 351 const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved. 352 // 353 // Use of this source code is governed by a BSD-style 354 // license that can be found in the LICENSE file. 355 356 // Code generated by gen-accessors; DO NOT EDIT. 357 // Instead, please run "go generate ./..." as described here: 358 // https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch 359 360 package {{.Package}} 361 {{with .Imports}} 362 import ( 363 {{- range . -}} 364 "{{.}}" 365 {{end -}} 366 ) 367 {{end}} 368 {{range .Getters}} 369 {{if .NamedStruct}} 370 // Get{{.FieldName}} returns the {{.FieldName}} field. 371 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} { 372 if {{.ReceiverVar}} == nil { 373 return {{.ZeroValue}} 374 } 375 return {{.ReceiverVar}}.{{.FieldName}} 376 } 377 {{else if or .MapType .ArrayType }} 378 // Get{{.FieldName}} returns the {{.FieldName}} {{if .MapType}}map{{else if .ArrayType }}slice{{end}} if it's non-nil, {{if .MapType}}an empty map{{else if .ArrayType }}nil{{end}} otherwise. 379 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { 380 if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil { 381 return {{.ZeroValue}} 382 } 383 return {{.ReceiverVar}}.{{.FieldName}} 384 } 385 {{else}} 386 // Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise. 387 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { 388 if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil { 389 return {{.ZeroValue}} 390 } 391 return *{{.ReceiverVar}}.{{.FieldName}} 392 } 393 {{end}} 394 {{end}} 395 ` 396 397 const test = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved. 398 // 399 // Use of this source code is governed by a BSD-style 400 // license that can be found in the LICENSE file. 401 402 // Code generated by gen-accessors; DO NOT EDIT. 403 // Instead, please run "go generate ./..." as described here: 404 // https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch 405 406 package {{.Package}} 407 {{with .Imports}} 408 import ( 409 "testing" 410 {{range . -}} 411 "{{.}}" 412 {{end -}} 413 ) 414 {{end}} 415 {{range .Getters}} 416 {{if .NamedStruct}} 417 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 418 tt.Parallel() 419 {{.ReceiverVar}} := &{{.ReceiverType}}{} 420 {{.ReceiverVar}}.Get{{.FieldName}}() 421 {{.ReceiverVar}} = nil 422 {{.ReceiverVar}}.Get{{.FieldName}}() 423 } 424 {{else if or .MapType .ArrayType}} 425 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 426 tt.Parallel() 427 zeroValue := {{.FieldType}}{} 428 {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue } 429 {{.ReceiverVar}}.Get{{.FieldName}}() 430 {{.ReceiverVar}} = &{{.ReceiverType}}{} 431 {{.ReceiverVar}}.Get{{.FieldName}}() 432 {{.ReceiverVar}} = nil 433 {{.ReceiverVar}}.Get{{.FieldName}}() 434 } 435 {{else}} 436 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 437 tt.Parallel() 438 var zeroValue {{.FieldType}} 439 {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: &zeroValue } 440 {{.ReceiverVar}}.Get{{.FieldName}}() 441 {{.ReceiverVar}} = &{{.ReceiverType}}{} 442 {{.ReceiverVar}}.Get{{.FieldName}}() 443 {{.ReceiverVar}} = nil 444 {{.ReceiverVar}}.Get{{.FieldName}}() 445 } 446 {{end}} 447 {{end}} 448 `