github.com/google/go-github/v68@v68.0.0/github/gen-accessors.go (about) 1 // Copyright 2017 The go-github AUTHORS. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 //go:build ignore 7 8 // gen-accessors generates accessor methods for structs with pointer fields. 9 // 10 // It is meant to be used by go-github contributors in conjunction with the 11 // go generate tool before sending a PR to GitHub. 12 // Please see the CONTRIBUTING.md file for more information. 13 package main 14 15 import ( 16 "bytes" 17 "flag" 18 "fmt" 19 "go/ast" 20 "go/format" 21 "go/parser" 22 "go/token" 23 "log" 24 "os" 25 "sort" 26 "strings" 27 "text/template" 28 ) 29 30 const ( 31 fileSuffix = "-accessors.go" 32 ) 33 34 var ( 35 verbose = flag.Bool("v", false, "Print verbose log messages") 36 37 sourceTmpl = template.Must(template.New("source").Parse(source)) 38 testTmpl = template.Must(template.New("test").Parse(test)) 39 40 // skipStructMethods lists "struct.method" combos to skip. 41 skipStructMethods = map[string]bool{ 42 "RepositoryContent.GetContent": true, 43 "Client.GetBaseURL": true, 44 "Client.GetUploadURL": true, 45 "ErrorResponse.GetResponse": true, 46 "RateLimitError.GetResponse": true, 47 "AbuseRateLimitError.GetResponse": true, 48 } 49 // skipStructs lists structs to skip. 50 skipStructs = map[string]bool{ 51 "Client": true, 52 } 53 54 // whitelistSliceGetters lists "struct.field" to add getter method 55 whitelistSliceGetters = map[string]bool{ 56 "PushEvent.Commits": true, 57 } 58 ) 59 60 func logf(fmt string, args ...interface{}) { 61 if *verbose { 62 log.Printf(fmt, args...) 63 } 64 } 65 66 func main() { 67 flag.Parse() 68 fset := token.NewFileSet() 69 70 pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0) 71 if err != nil { 72 log.Fatal(err) 73 return 74 } 75 76 for pkgName, pkg := range pkgs { 77 t := &templateData{ 78 filename: pkgName + fileSuffix, 79 Year: 2017, 80 Package: pkgName, 81 Imports: map[string]string{}, 82 } 83 for filename, f := range pkg.Files { 84 logf("Processing %v...", filename) 85 if err := t.processAST(f); err != nil { 86 log.Fatal(err) 87 } 88 } 89 if err := t.dump(); err != nil { 90 log.Fatal(err) 91 } 92 } 93 logf("Done.") 94 } 95 96 func (t *templateData) processAST(f *ast.File) error { 97 for _, decl := range f.Decls { 98 gd, ok := decl.(*ast.GenDecl) 99 if !ok { 100 continue 101 } 102 for _, spec := range gd.Specs { 103 ts, ok := spec.(*ast.TypeSpec) 104 if !ok { 105 continue 106 } 107 // Skip unexported identifiers. 108 if !ts.Name.IsExported() { 109 logf("Struct %v is unexported; skipping.", ts.Name) 110 continue 111 } 112 // Check if the struct should be skipped. 113 if skipStructs[ts.Name.Name] { 114 logf("Struct %v is in skip list; skipping.", ts.Name) 115 continue 116 } 117 st, ok := ts.Type.(*ast.StructType) 118 if !ok { 119 continue 120 } 121 for _, field := range st.Fields.List { 122 if len(field.Names) == 0 { 123 continue 124 } 125 126 fieldName := field.Names[0] 127 // Skip unexported identifiers. 128 if !fieldName.IsExported() { 129 logf("Field %v is unexported; skipping.", fieldName) 130 continue 131 } 132 // Check if "struct.method" should be skipped. 133 if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] { 134 logf("Method %v is skip list; skipping.", key) 135 continue 136 } 137 138 se, ok := field.Type.(*ast.StarExpr) 139 if !ok { 140 switch x := field.Type.(type) { 141 case *ast.MapType: 142 t.addMapType(x, ts.Name.String(), fieldName.String(), false) 143 continue 144 case *ast.ArrayType: 145 if key := fmt.Sprintf("%v.%v", ts.Name, fieldName); whitelistSliceGetters[key] { 146 logf("Method %v is whitelist; adding getter method.", key) 147 t.addArrayType(x, ts.Name.String(), fieldName.String(), false) 148 continue 149 } 150 } 151 152 logf("Skipping field type %T, fieldName=%v", field.Type, fieldName) 153 continue 154 } 155 156 switch x := se.X.(type) { 157 case *ast.ArrayType: 158 t.addArrayType(x, ts.Name.String(), fieldName.String(), true) 159 case *ast.Ident: 160 t.addIdent(x, ts.Name.String(), fieldName.String()) 161 case *ast.MapType: 162 t.addMapType(x, ts.Name.String(), fieldName.String(), true) 163 case *ast.SelectorExpr: 164 t.addSelectorExpr(x, ts.Name.String(), fieldName.String()) 165 default: 166 logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x) 167 } 168 } 169 } 170 } 171 return nil 172 } 173 174 func sourceFilter(fi os.FileInfo) bool { 175 return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix) 176 } 177 178 func (t *templateData) dump() error { 179 if len(t.Getters) == 0 { 180 logf("No getters for %v; skipping.", t.filename) 181 return nil 182 } 183 184 // Sort getters by ReceiverType.FieldName. 185 sort.Sort(byName(t.Getters)) 186 187 processTemplate := func(tmpl *template.Template, filename string) error { 188 var buf bytes.Buffer 189 if err := tmpl.Execute(&buf, t); err != nil { 190 return err 191 } 192 clean, err := format.Source(buf.Bytes()) 193 if err != nil { 194 return fmt.Errorf("format.Source:\n%v\n%v", buf.String(), err) 195 } 196 197 logf("Writing %v...", filename) 198 if err := os.Chmod(filename, 0644); err != nil { 199 return fmt.Errorf("os.Chmod(%q, 0644): %v", filename, err) 200 } 201 202 if err := os.WriteFile(filename, clean, 0444); err != nil { 203 return err 204 } 205 206 if err := os.Chmod(filename, 0444); err != nil { 207 return fmt.Errorf("os.Chmod(%q, 0444): %v", filename, err) 208 } 209 210 return nil 211 } 212 213 if err := processTemplate(sourceTmpl, t.filename); err != nil { 214 return err 215 } 216 return processTemplate(testTmpl, strings.ReplaceAll(t.filename, ".go", "_test.go")) 217 } 218 219 func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter { 220 return &getter{ 221 sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName), 222 ReceiverVar: strings.ToLower(receiverType[:1]), 223 ReceiverType: receiverType, 224 FieldName: fieldName, 225 FieldType: fieldType, 226 ZeroValue: zeroValue, 227 NamedStruct: namedStruct, 228 } 229 } 230 231 func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string, isAPointer bool) { 232 var eltType string 233 var ng *getter 234 switch elt := x.Elt.(type) { 235 case *ast.Ident: 236 eltType = elt.String() 237 ng = newGetter(receiverType, fieldName, "[]"+eltType, "nil", false) 238 case *ast.StarExpr: 239 ident, ok := elt.X.(*ast.Ident) 240 if !ok { 241 return 242 } 243 ng = newGetter(receiverType, fieldName, "[]*"+ident.String(), "nil", false) 244 default: 245 logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt) 246 return 247 } 248 249 ng.ArrayType = !isAPointer 250 t.Getters = append(t.Getters, ng) 251 } 252 253 func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) { 254 var zeroValue string 255 var namedStruct = false 256 switch x.String() { 257 case "int", "int64": 258 zeroValue = "0" 259 case "string": 260 zeroValue = `""` 261 case "bool": 262 zeroValue = "false" 263 case "Timestamp": 264 zeroValue = "Timestamp{}" 265 default: 266 zeroValue = "nil" 267 namedStruct = true 268 } 269 270 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct)) 271 } 272 273 func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) { 274 var keyType string 275 switch key := x.Key.(type) { 276 case *ast.Ident: 277 keyType = key.String() 278 default: 279 logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key) 280 return 281 } 282 283 var valueType string 284 switch value := x.Value.(type) { 285 case *ast.Ident: 286 valueType = value.String() 287 default: 288 logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value) 289 return 290 } 291 292 fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType) 293 zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType) 294 ng := newGetter(receiverType, fieldName, fieldType, zeroValue, false) 295 ng.MapType = !isAPointer 296 t.Getters = append(t.Getters, ng) 297 } 298 299 func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) { 300 if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field. 301 return 302 } 303 304 var xX string 305 if xx, ok := x.X.(*ast.Ident); ok { 306 xX = xx.String() 307 } 308 309 switch xX { 310 case "time", "json": 311 if xX == "json" { 312 t.Imports["encoding/json"] = "encoding/json" 313 } else { 314 t.Imports[xX] = xX 315 } 316 fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name) 317 zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name) 318 if xX == "time" && x.Sel.Name == "Duration" { 319 zeroValue = "0" 320 } 321 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false)) 322 default: 323 logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x) 324 } 325 } 326 327 type templateData struct { 328 filename string 329 Year int 330 Package string 331 Imports map[string]string 332 Getters []*getter 333 } 334 335 type getter struct { 336 sortVal string // Lower-case version of "ReceiverType.FieldName". 337 ReceiverVar string // The one-letter variable name to match the ReceiverType. 338 ReceiverType string 339 FieldName string 340 FieldType string 341 ZeroValue string 342 NamedStruct bool // Getter for named struct. 343 MapType bool 344 ArrayType bool 345 } 346 347 type byName []*getter 348 349 func (b byName) Len() int { return len(b) } 350 func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal } 351 func (b byName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } 352 353 const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved. 354 // 355 // Use of this source code is governed by a BSD-style 356 // license that can be found in the LICENSE file. 357 358 // Code generated by gen-accessors; DO NOT EDIT. 359 // Instead, please run "go generate ./..." as described here: 360 // https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch 361 362 package {{.Package}} 363 {{with .Imports}} 364 import ( 365 {{- range . -}} 366 "{{.}}" 367 {{end -}} 368 ) 369 {{end}} 370 {{range .Getters}} 371 {{if .NamedStruct}} 372 // Get{{.FieldName}} returns the {{.FieldName}} field. 373 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} { 374 if {{.ReceiverVar}} == nil { 375 return {{.ZeroValue}} 376 } 377 return {{.ReceiverVar}}.{{.FieldName}} 378 } 379 {{else if or .MapType .ArrayType }} 380 // Get{{.FieldName}} returns the {{.FieldName}} {{if .MapType}}map{{else if .ArrayType }}slice{{end}} if it's non-nil, {{if .MapType}}an empty map{{else if .ArrayType }}nil{{end}} otherwise. 381 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { 382 if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil { 383 return {{.ZeroValue}} 384 } 385 return {{.ReceiverVar}}.{{.FieldName}} 386 } 387 {{else}} 388 // Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise. 389 func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { 390 if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil { 391 return {{.ZeroValue}} 392 } 393 return *{{.ReceiverVar}}.{{.FieldName}} 394 } 395 {{end}} 396 {{end}} 397 ` 398 399 const test = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved. 400 // 401 // Use of this source code is governed by a BSD-style 402 // license that can be found in the LICENSE file. 403 404 // Code generated by gen-accessors; DO NOT EDIT. 405 // Instead, please run "go generate ./..." as described here: 406 // https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch 407 408 package {{.Package}} 409 {{with .Imports}} 410 import ( 411 "testing" 412 {{range . -}} 413 "{{.}}" 414 {{end -}} 415 ) 416 {{end}} 417 {{range .Getters}} 418 {{if .NamedStruct}} 419 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 420 tt.Parallel() 421 {{.ReceiverVar}} := &{{.ReceiverType}}{} 422 {{.ReceiverVar}}.Get{{.FieldName}}() 423 {{.ReceiverVar}} = nil 424 {{.ReceiverVar}}.Get{{.FieldName}}() 425 } 426 {{else if or .MapType .ArrayType}} 427 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 428 tt.Parallel() 429 zeroValue := {{.FieldType}}{} 430 {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue } 431 {{.ReceiverVar}}.Get{{.FieldName}}() 432 {{.ReceiverVar}} = &{{.ReceiverType}}{} 433 {{.ReceiverVar}}.Get{{.FieldName}}() 434 {{.ReceiverVar}} = nil 435 {{.ReceiverVar}}.Get{{.FieldName}}() 436 } 437 {{else}} 438 func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { 439 tt.Parallel() 440 var zeroValue {{.FieldType}} 441 {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: &zeroValue } 442 {{.ReceiverVar}}.Get{{.FieldName}}() 443 {{.ReceiverVar}} = &{{.ReceiverType}}{} 444 {{.ReceiverVar}}.Get{{.FieldName}}() 445 {{.ReceiverVar}} = nil 446 {{.ReceiverVar}}.Get{{.FieldName}}() 447 } 448 {{end}} 449 {{end}} 450 `