vitess.io/vitess@v0.16.2/go/tools/asthelpergen/asthelpergen.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package asthelpergen 18 19 import ( 20 "bytes" 21 "fmt" 22 "go/types" 23 "log" 24 "os" 25 "path" 26 "strings" 27 28 "vitess.io/vitess/go/tools/goimports" 29 30 "github.com/dave/jennifer/jen" 31 "golang.org/x/tools/go/packages" 32 ) 33 34 const licenseFileHeader = `Copyright 2023 The Vitess Authors. 35 36 Licensed under the Apache License, Version 2.0 (the "License"); 37 you may not use this file except in compliance with the License. 38 You may obtain a copy of the License at 39 40 http://www.apache.org/licenses/LICENSE-2.0 41 42 Unless required by applicable law or agreed to in writing, software 43 distributed under the License is distributed on an "AS IS" BASIS, 44 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 45 See the License for the specific language governing permissions and 46 limitations under the License.` 47 48 type ( 49 generatorSPI interface { 50 addType(t types.Type) 51 scope() *types.Scope 52 findImplementations(iff *types.Interface, impl func(types.Type) error) error 53 iface() *types.Interface 54 } 55 generator interface { 56 genFile() (string, *jen.File) 57 interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error 58 structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error 59 ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error 60 ptrToBasicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error 61 sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error 62 basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error 63 } 64 // astHelperGen finds implementations of the given interface, 65 // and uses the supplied `generator`s to produce the output code 66 astHelperGen struct { 67 DebugTypes bool 68 mod *packages.Module 69 sizes types.Sizes 70 namedIface *types.Named 71 _iface *types.Interface 72 gens []generator 73 74 _scope *types.Scope 75 todo []types.Type 76 } 77 ) 78 79 func (gen *astHelperGen) iface() *types.Interface { 80 return gen._iface 81 } 82 83 func newGenerator(mod *packages.Module, sizes types.Sizes, named *types.Named, generators ...generator) *astHelperGen { 84 return &astHelperGen{ 85 DebugTypes: true, 86 mod: mod, 87 sizes: sizes, 88 namedIface: named, 89 _iface: named.Underlying().(*types.Interface), 90 gens: generators, 91 } 92 } 93 94 func findImplementations(scope *types.Scope, iff *types.Interface, impl func(types.Type) error) error { 95 for _, name := range scope.Names() { 96 obj := scope.Lookup(name) 97 if _, ok := obj.(*types.TypeName); !ok { 98 continue 99 } 100 baseType := obj.Type() 101 if types.Implements(baseType, iff) { 102 err := impl(baseType) 103 if err != nil { 104 return err 105 } 106 continue 107 } 108 pointerT := types.NewPointer(baseType) 109 if types.Implements(pointerT, iff) { 110 err := impl(pointerT) 111 if err != nil { 112 return err 113 } 114 continue 115 } 116 } 117 return nil 118 } 119 func (gen *astHelperGen) findImplementations(iff *types.Interface, impl func(types.Type) error) error { 120 for _, name := range gen._scope.Names() { 121 obj := gen._scope.Lookup(name) 122 if _, ok := obj.(*types.TypeName); !ok { 123 continue 124 } 125 baseType := obj.Type() 126 if types.Implements(baseType, iff) { 127 err := impl(baseType) 128 if err != nil { 129 return err 130 } 131 continue 132 } 133 pointerT := types.NewPointer(baseType) 134 if types.Implements(pointerT, iff) { 135 err := impl(pointerT) 136 if err != nil { 137 return err 138 } 139 continue 140 } 141 } 142 return nil 143 } 144 145 // GenerateCode is the main loop where we build up the code per file. 146 func (gen *astHelperGen) GenerateCode() (map[string]*jen.File, error) { 147 pkg := gen.namedIface.Obj().Pkg() 148 149 gen._scope = pkg.Scope() 150 gen.todo = append(gen.todo, gen.namedIface) 151 jenFiles := gen.createFiles() 152 153 result := map[string]*jen.File{} 154 for fName, genFile := range jenFiles { 155 fullPath := path.Join(gen.mod.Dir, strings.TrimPrefix(pkg.Path(), gen.mod.Path), fName) 156 result[fullPath] = genFile 157 } 158 159 return result, nil 160 } 161 162 // VerifyFilesOnDisk compares the generated results from the codegen against the files that 163 // currently exist on disk and returns any mismatches 164 func VerifyFilesOnDisk(result map[string]*jen.File) (errors []error) { 165 for fullPath, file := range result { 166 existing, err := os.ReadFile(fullPath) 167 if err != nil { 168 errors = append(errors, fmt.Errorf("missing file on disk: %s (%w)", fullPath, err)) 169 continue 170 } 171 172 genFile, err := goimports.FormatJenFile(file) 173 if err != nil { 174 errors = append(errors, fmt.Errorf("goimport error: %w", err)) 175 continue 176 } 177 178 if !bytes.Equal(existing, genFile) { 179 errors = append(errors, fmt.Errorf("'%s' has changed", fullPath)) 180 continue 181 } 182 } 183 return errors 184 } 185 186 var acceptableBuildErrorsOn = map[string]any{ 187 "ast_equals.go": nil, 188 "ast_clone.go": nil, 189 "ast_rewrite.go": nil, 190 "ast_visit.go": nil, 191 } 192 193 type Options struct { 194 Packages []string 195 RootInterface string 196 197 Clone CloneOptions 198 Equals EqualsOptions 199 } 200 201 // GenerateASTHelpers loads the input code, constructs the necessary generators, 202 // and generates the rewriter and clone methods for the AST 203 func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) { 204 loaded, err := packages.Load(&packages.Config{ 205 Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesSizes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedModule, 206 }, options.Packages...) 207 208 if err != nil { 209 return nil, fmt.Errorf("failed to load packages: %w", err) 210 } 211 212 checkErrors(loaded, func(fileName string) bool { 213 _, ok := acceptableBuildErrorsOn[fileName] 214 return ok 215 }) 216 217 scopes := make(map[string]*types.Scope) 218 for _, pkg := range loaded { 219 scopes[pkg.PkgPath] = pkg.Types.Scope() 220 } 221 222 pos := strings.LastIndexByte(options.RootInterface, '.') 223 if pos < 0 { 224 return nil, fmt.Errorf("unexpected input type: %s", options.RootInterface) 225 } 226 227 pkgname := options.RootInterface[:pos] 228 typename := options.RootInterface[pos+1:] 229 230 scope := scopes[pkgname] 231 if scope == nil { 232 return nil, fmt.Errorf("no scope found for type '%s'", options.RootInterface) 233 } 234 235 tt := scope.Lookup(typename) 236 if tt == nil { 237 return nil, fmt.Errorf("no type called '%s' found in '%s'", typename, pkgname) 238 } 239 240 nt := tt.Type().(*types.Named) 241 pName := nt.Obj().Pkg().Name() 242 generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt, 243 newEqualsGen(pName, &options.Equals), 244 newCloneGen(pName, &options.Clone), 245 newVisitGen(pName), 246 newRewriterGen(pName, types.TypeString(nt, noQualifier)), 247 newCOWGen(pName, nt), 248 ) 249 250 it, err := generator.GenerateCode() 251 if err != nil { 252 return nil, err 253 } 254 255 return it, nil 256 } 257 258 var _ generatorSPI = (*astHelperGen)(nil) 259 260 func (gen *astHelperGen) scope() *types.Scope { 261 return gen._scope 262 } 263 264 func (gen *astHelperGen) addType(t types.Type) { 265 gen.todo = append(gen.todo, t) 266 } 267 268 func (gen *astHelperGen) createFiles() map[string]*jen.File { 269 alreadyDone := map[string]bool{} 270 for len(gen.todo) > 0 { 271 t := gen.todo[0] 272 underlying := t.Underlying() 273 typeName := printableTypeName(t) 274 gen.todo = gen.todo[1:] 275 276 if alreadyDone[typeName] { 277 continue 278 } 279 var err error 280 for _, g := range gen.gens { 281 switch underlying := underlying.(type) { 282 case *types.Interface: 283 err = g.interfaceMethod(t, underlying, gen) 284 case *types.Slice: 285 err = g.sliceMethod(t, underlying, gen) 286 case *types.Struct: 287 err = g.structMethod(t, underlying, gen) 288 case *types.Pointer: 289 ptrToType := underlying.Elem().Underlying() 290 switch ptrToType := ptrToType.(type) { 291 case *types.Struct: 292 err = g.ptrToStructMethod(t, ptrToType, gen) 293 case *types.Basic: 294 err = g.ptrToBasicMethod(t, ptrToType, gen) 295 default: 296 panic(fmt.Sprintf("%T", ptrToType)) 297 } 298 case *types.Basic: 299 err = g.basicMethod(t, underlying, gen) 300 default: 301 log.Fatalf("don't know how to handle %s %T", typeName, underlying) 302 } 303 if err != nil { 304 log.Fatal(err) 305 } 306 } 307 alreadyDone[typeName] = true 308 } 309 310 result := map[string]*jen.File{} 311 for _, g := range gen.gens { 312 fName, jenFile := g.genFile() 313 result[fName] = jenFile 314 } 315 return result 316 } 317 318 // printableTypeName returns a string that can be used as a valid golang identifier 319 func printableTypeName(t types.Type) string { 320 switch t := t.(type) { 321 case *types.Pointer: 322 return "RefOf" + printableTypeName(t.Elem()) 323 case *types.Slice: 324 return "SliceOf" + printableTypeName(t.Elem()) 325 case *types.Named: 326 return t.Obj().Name() 327 case *types.Basic: 328 return strings.Title(t.Name()) // nolint 329 case *types.Interface: 330 return t.String() 331 default: 332 panic(fmt.Sprintf("unknown type %T %v", t, t)) 333 } 334 } 335 336 func checkErrors(loaded []*packages.Package, canSkipErrorOn func(fileName string) bool) { 337 for _, l := range loaded { 338 for _, e := range l.Errors { 339 idx := strings.Index(e.Pos, ":") 340 filePath := e.Pos[:idx] 341 _, fileName := path.Split(filePath) 342 if !canSkipErrorOn(fileName) { 343 log.Fatalf("error loading package %s", e.Error()) 344 } 345 } 346 } 347 }