github.com/google/go-github/v69@v69.2.0/github/gen-accessors.go (about) 1 // Copyright 2017 The go-github AUTHORS. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 //go:build ignore 7 8 // gen-accessors generates accessor methods for structs with pointer fields. 9 // 10 // It is meant to be used by go-github contributors in conjunction with the 11 // go generate tool before sending a PR to GitHub. 12 // Please see the CONTRIBUTING.md file for more information. 13 package main 14 15 import ( 16 "bytes" 17 "flag" 18 "fmt" 19 "go/ast" 20 "go/format" 21 "go/parser" 22 "go/token" 23 "log" 24 "os" 25 "slices" 26 "strings" 27 "text/template" 28 ) 29 30 const ( 31 fileSuffix = "-accessors.go" 32 ) 33 34 var ( 35 verbose = flag.Bool("v", false, "Print verbose log messages") 36 37 sourceTmpl = template.Must(template.New("source").Parse(source)) 38 testTmpl = template.Must(template.New("test").Parse(test)) 39 40 // skipStructMethods lists "struct.method" combos to skip. 41 skipStructMethods = map[string]bool{ 42 "RepositoryContent.GetContent": true, 43 "Client.GetBaseURL": true, 44 "Client.GetUploadURL": true, 45 "ErrorResponse.GetResponse": true, 46 "RateLimitError.GetResponse": true, 47 "AbuseRateLimitError.GetResponse": true, 48 } 49 // skipStructs lists structs to skip. 50 skipStructs = map[string]bool{ 51 "Client": true, 52 } 53 54 // whitelistSliceGetters lists "struct.field" to add getter method 55 whitelistSliceGetters = map[string]bool{ 56 "PushEvent.Commits": true, 57 } 58 ) 59 60 func logf(fmt string, args ...interface{}) { 61 if *verbose { 62 log.Printf(fmt, args...) 63 } 64 } 65 66 func main() { 67 flag.Parse() 68 fset := token.NewFileSet() 69 70 pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0) 71 if err != nil { 72 log.Fatal(err) 73 return 74 } 75 76 for pkgName, pkg := range pkgs { 77 t := &templateData{ 78 filename: pkgName + fileSuffix, 79 Year: 2017, 80 Package: pkgName, 81 Imports: map[string]string{}, 82 } 83 for filename, f := range pkg.Files { 84 logf("Processing %v...", filename) 85 if err := t.processAST(f); err != nil { 86 log.Fatal(err) 87 } 88 } 89 if err := t.dump(); err != nil { 90 log.Fatal(err) 91 } 92 } 93 logf("Done.") 94 } 95 96 func (t *templateData) processAST(f *ast.File) error { 97 for _, decl := range f.Decls { 98 gd, ok := decl.(*ast.GenDecl) 99 if !ok { 100 continue 101 } 102 for _, spec := range gd.Specs { 103 ts, ok := spec.(*ast.TypeSpec) 104 if !ok { 105 continue 106 } 107 // Skip unexported identifiers. 108 if !ts.Name.IsExported() { 109 logf("Struct %v is unexported; skipping.", ts.Name) 110 continue 111 } 112 // Check if the struct should be skipped. 113 if skipStructs[ts.Name.Name] { 114 logf("Struct %v is in skip list; skipping.", ts.Name) 115 continue 116 } 117 st, ok := ts.Type.(*ast.StructType) 118 if !ok { 119 continue 120 } 121 for _, field := range st.Fields.List { 122 if len(field.Names) == 0 { 123 continue 124 } 125 126 fieldName := field.Names[0] 127 // Skip unexported identifiers. 128 if !fieldName.IsExported() { 129 logf("Field %v is unexported; skipping.", fieldName) 130 continue 131 } 132 // Check if "struct.method" should be skipped. 133 if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] { 134 logf("Method %v is skip list; skipping.", key) 135 continue 136 } 137 138 se, ok := field.Type.(*ast.StarExpr) 139 if !ok { 140 switch x := field.Type.(type) { 141 case *ast.MapType: 142 t.addMapType(x, ts.Name.String(), fieldName.String(), false) 143 continue 144 case *ast.ArrayType: 145 if key := fmt.Sprintf("%v.%v", ts.Name, fieldName); whitelistSliceGetters[key] { 146 logf("Method %v is whitelist; adding getter method.", key) 147 t.addArrayType(x, ts.Name.String(), fieldName.String(), false) 148 continue 149 } 150 } 151 152 logf("Skipping field type %T, fieldName=%v", field.Type, fieldName) 153 continue 154 } 155 156 switch x := se.X.(type) { 157 case *ast.ArrayType: 158 t.addArrayType(x, ts.Name.String(), fieldName.String(), true) 159 case *ast.Ident: 160 t.addIdent(x, ts.Name.String(), fieldName.String()) 161 case *ast.MapType: 162 t.addMapType(x, ts.Name.String(), fieldName.String(), true) 163 case *ast.SelectorExpr: 164 t.addSelectorExpr(x, ts.Name.String(), fieldName.String()) 165 default: 166 logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x) 167 } 168 } 169 } 170 } 171 return nil 172 } 173 174 func sourceFilter(fi os.FileInfo) bool { 175 return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix) 176 } 177 178 func (t *templateData) dump() error { 179 if len(t.Getters) == 0 { 180 logf("No getters for %v; skipping.", t.filename) 181 return nil 182 } 183 184 // Sort getters by ReceiverType.FieldName. 185 slices.SortStableFunc(t.Getters, func(a, b *getter) int { 186 return strings.Compare(a.sortVal, b.sortVal) 187 }) 188 189 processTemplate := func(tmpl *template.Template, filename string) error { 190 var buf bytes.Buffer 191 if err := tmpl.Execute(&buf, t); err != nil { 192 return err 193 } 194 clean, err := format.Source(buf.Bytes()) 195 if err != nil { 196 return fmt.Errorf("format.Source:\n%v\n%v", buf.String(), err) 197 } 198 199 logf("Writing %v...", filename) 200 if err := os.Chmod(filename, 0644); err != nil { 201 return fmt.Errorf("os.Chmod(%q, 0644): %v", filename, err) 202 } 203 204 if err := os.WriteFile(filename, clean, 0444); err != nil { 205 return err 206 } 207 208 if err := os.Chmod(filename, 0444); err != nil { 209 return fmt.Errorf("os.Chmod(%q, 0444): %v", filename, err) 210 } 211 212 return nil 213 } 214 215 if err := processTemplate(sourceTmpl, t.filename); err != nil { 216 return err 217 } 218 return processTemplate(testTmpl, strings.ReplaceAll(t.filename, ".go", "_test.go")) 219 } 220 221 func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter { 222 return &getter{ 223 sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName), 224 ReceiverVar: strings.ToLower(receiverType[:1]), 225 ReceiverType: receiverType, 226 FieldName: fieldName, 227 FieldType: fieldType, 228 ZeroValue: zeroValue, 229 NamedStruct: namedStruct, 230 } 231 } 232 233 func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string, isAPointer bool) { 234 var eltType string 235 var ng *getter 236 switch elt := x.Elt.(type) { 237 case *ast.Ident: 238 eltType = elt.String() 239 ng = newGetter(receiverType, fieldName, "[]"+eltType, "nil", false) 240 case *ast.StarExpr: 241 ident, ok := elt.X.(*ast.Ident) 242 if !ok { 243 return 244 } 245 ng = newGetter(receiverType, fieldName, "[]*"+ident.String(), "nil", false) 246 default: 247 logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt) 248 return 249 } 250 251 ng.ArrayType = !isAPointer 252 t.Getters = append(t.Getters, ng) 253 } 254 255 func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) { 256 var zeroValue string 257 var namedStruct = false 258 switch x.String() { 259 case "int", "int64": 260 zeroValue = "0" 261 case "string": 262 zeroValue = `""` 263 case "bool": 264 zeroValue = "false" 265 case "Timestamp": 266 zeroValue = "Timestamp{}" 267 default: 268 zeroValue = "nil" 269 namedStruct = true 270 } 271 272 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct)) 273 } 274 275 func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) { 276 var keyType string 277 switch key := x.Key.(type) { 278 case *ast.Ident: 279 keyType = key.String() 280 default: 281 logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key) 282 return 283 } 284 285 var valueType string 286 switch value := x.Value.(type) { 287 case *ast.Ident: 288 valueType = value.String() 289 default: 290 logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value) 291 return 292 } 293 294 fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType) 295 zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType) 296 ng := newGetter(receiverType, fieldName, fieldType, zeroValue, false) 297 ng.MapType = !isAPointer 298 t.Getters = append(t.Getters, ng) 299 } 300 301 func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) { 302 if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field. 303 return 304 } 305 306 var xX string 307 if xx, ok := x.X.(*ast.Ident); ok { 308 xX = xx.String() 309 } 310 311 switch xX { 312 case "time", "json": 313 if xX == "json" { 314 t.Imports["encoding/json"] = "encoding/json" 315 } else { 316 t.Imports[xX] = xX 317 } 318 fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name) 319 zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name) 320 if xX == "time" && x.Sel.Name == "Duration" { 321 zeroValue = "0" 322 } 323 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false)) 324 default: 325 logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x) 326 } 327 } 328 329 type templateData struct { 330 filename string 331 Year int 332 Package string 333 Imports map[string]string 334 Getters []*getter 335 } 336 337 type getter struct { 338 sortVal string // Lower-case version of "ReceiverType.FieldName". 339 ReceiverVar string // The one-letter variable name to match the ReceiverType. 340 ReceiverType string 341 FieldName string 342 FieldType string 343 ZeroValue string 344 NamedStruct bool // Getter for named struct. 345 MapType bool 346 ArrayType bool 347 } 348 349 const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved. 350 // 351 // Use of this source code is governed by a BSD-style 352 // license that can be found in the LICENSE file. 353 354 // Code generated by gen-accessors; DO NOT EDIT. 355 // Instead, please run "go generate ./..." as described here: 356 // https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch 357 358 package {{.Package}} 359 {{with .Imports}} 360 import ( 361 {{- range . -}} 362 "{{.}}" 363 {{end -}} 364 ) 365 {{end}} 366 {{range .Getters}} 367 {{if .NamedStruct}} 368 // Get{{.FieldName}} returns the {{.FieldName}} field. 369 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} { 370 if {{.ReceiverVar}} == nil { 371 return {{.ZeroValue}} 372 } 373 return {{.ReceiverVar}}.{{.FieldName}} 374 } 375 {{else if or .MapType .ArrayType }} 376 // Get{{.FieldName}} returns the {{.FieldName}} {{if .MapType}}map{{else if .ArrayType }}slice{{end}} if it's non-nil, {{if .MapType}}an empty map{{else if .ArrayType }}nil{{end}} otherwise. 377 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { 378 if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil { 379 return {{.ZeroValue}} 380 } 381 return {{.ReceiverVar}}.{{.FieldName}} 382 } 383 {{else}} 384 // Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise. 385 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { 386 if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil { 387 return {{.ZeroValue}} 388 } 389 return *{{.ReceiverVar}}.{{.FieldName}} 390 } 391 {{end}} 392 {{end}} 393 ` 394 395 const test = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved. 396 // 397 // Use of this source code is governed by a BSD-style 398 // license that can be found in the LICENSE file. 399 400 // Code generated by gen-accessors; DO NOT EDIT. 401 // Instead, please run "go generate ./..." as described here: 402 // https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch 403 404 package {{.Package}} 405 {{with .Imports}} 406 import ( 407 "testing" 408 {{range . -}} 409 "{{.}}" 410 {{end -}} 411 ) 412 {{end}} 413 {{range .Getters}} 414 {{if .NamedStruct}} 415 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 416 tt.Parallel() 417 {{.ReceiverVar}} := &{{.ReceiverType}}{} 418 {{.ReceiverVar}}.Get{{.FieldName}}() 419 {{.ReceiverVar}} = nil 420 {{.ReceiverVar}}.Get{{.FieldName}}() 421 } 422 {{else if or .MapType .ArrayType}} 423 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 424 tt.Parallel() 425 zeroValue := {{.FieldType}}{} 426 {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue } 427 {{.ReceiverVar}}.Get{{.FieldName}}() 428 {{.ReceiverVar}} = &{{.ReceiverType}}{} 429 {{.ReceiverVar}}.Get{{.FieldName}}() 430 {{.ReceiverVar}} = nil 431 {{.ReceiverVar}}.Get{{.FieldName}}() 432 } 433 {{else}} 434 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 435 tt.Parallel() 436 var zeroValue {{.FieldType}} 437 {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: &zeroValue } 438 {{.ReceiverVar}}.Get{{.FieldName}}() 439 {{.ReceiverVar}} = &{{.ReceiverType}}{} 440 {{.ReceiverVar}}.Get{{.FieldName}}() 441 {{.ReceiverVar}} = nil 442 {{.ReceiverVar}}.Get{{.FieldName}}() 443 } 444 {{end}} 445 {{end}} 446 `