vitess.io/vitess@v0.16.2/go/tools/asthelpergen/asthelpergen.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package asthelpergen
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"go/types"
    23  	"log"
    24  	"os"
    25  	"path"
    26  	"strings"
    27  
    28  	"vitess.io/vitess/go/tools/goimports"
    29  
    30  	"github.com/dave/jennifer/jen"
    31  	"golang.org/x/tools/go/packages"
    32  )
    33  
// licenseFileHeader is the Apache 2.0 license text, stored without any Go
// comment markers.
// NOTE(review): it is not referenced in this file; presumably it is written as
// the header of the generated helper files elsewhere in the package — confirm.
const licenseFileHeader = `Copyright 2023 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.`
    47  
    48  type (
    49  	generatorSPI interface {
    50  		addType(t types.Type)
    51  		scope() *types.Scope
    52  		findImplementations(iff *types.Interface, impl func(types.Type) error) error
    53  		iface() *types.Interface
    54  	}
    55  	generator interface {
    56  		genFile() (string, *jen.File)
    57  		interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error
    58  		structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error
    59  		ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error
    60  		ptrToBasicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error
    61  		sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error
    62  		basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error
    63  	}
    64  	// astHelperGen finds implementations of the given interface,
    65  	// and uses the supplied `generator`s to produce the output code
    66  	astHelperGen struct {
    67  		DebugTypes bool
    68  		mod        *packages.Module
    69  		sizes      types.Sizes
    70  		namedIface *types.Named
    71  		_iface     *types.Interface
    72  		gens       []generator
    73  
    74  		_scope *types.Scope
    75  		todo   []types.Type
    76  	}
    77  )
    78  
    79  func (gen *astHelperGen) iface() *types.Interface {
    80  	return gen._iface
    81  }
    82  
    83  func newGenerator(mod *packages.Module, sizes types.Sizes, named *types.Named, generators ...generator) *astHelperGen {
    84  	return &astHelperGen{
    85  		DebugTypes: true,
    86  		mod:        mod,
    87  		sizes:      sizes,
    88  		namedIface: named,
    89  		_iface:     named.Underlying().(*types.Interface),
    90  		gens:       generators,
    91  	}
    92  }
    93  
    94  func findImplementations(scope *types.Scope, iff *types.Interface, impl func(types.Type) error) error {
    95  	for _, name := range scope.Names() {
    96  		obj := scope.Lookup(name)
    97  		if _, ok := obj.(*types.TypeName); !ok {
    98  			continue
    99  		}
   100  		baseType := obj.Type()
   101  		if types.Implements(baseType, iff) {
   102  			err := impl(baseType)
   103  			if err != nil {
   104  				return err
   105  			}
   106  			continue
   107  		}
   108  		pointerT := types.NewPointer(baseType)
   109  		if types.Implements(pointerT, iff) {
   110  			err := impl(pointerT)
   111  			if err != nil {
   112  				return err
   113  			}
   114  			continue
   115  		}
   116  	}
   117  	return nil
   118  }
   119  func (gen *astHelperGen) findImplementations(iff *types.Interface, impl func(types.Type) error) error {
   120  	for _, name := range gen._scope.Names() {
   121  		obj := gen._scope.Lookup(name)
   122  		if _, ok := obj.(*types.TypeName); !ok {
   123  			continue
   124  		}
   125  		baseType := obj.Type()
   126  		if types.Implements(baseType, iff) {
   127  			err := impl(baseType)
   128  			if err != nil {
   129  				return err
   130  			}
   131  			continue
   132  		}
   133  		pointerT := types.NewPointer(baseType)
   134  		if types.Implements(pointerT, iff) {
   135  			err := impl(pointerT)
   136  			if err != nil {
   137  				return err
   138  			}
   139  			continue
   140  		}
   141  	}
   142  	return nil
   143  }
   144  
   145  // GenerateCode is the main loop where we build up the code per file.
   146  func (gen *astHelperGen) GenerateCode() (map[string]*jen.File, error) {
   147  	pkg := gen.namedIface.Obj().Pkg()
   148  
   149  	gen._scope = pkg.Scope()
   150  	gen.todo = append(gen.todo, gen.namedIface)
   151  	jenFiles := gen.createFiles()
   152  
   153  	result := map[string]*jen.File{}
   154  	for fName, genFile := range jenFiles {
   155  		fullPath := path.Join(gen.mod.Dir, strings.TrimPrefix(pkg.Path(), gen.mod.Path), fName)
   156  		result[fullPath] = genFile
   157  	}
   158  
   159  	return result, nil
   160  }
   161  
   162  // VerifyFilesOnDisk compares the generated results from the codegen against the files that
   163  // currently exist on disk and returns any mismatches
   164  func VerifyFilesOnDisk(result map[string]*jen.File) (errors []error) {
   165  	for fullPath, file := range result {
   166  		existing, err := os.ReadFile(fullPath)
   167  		if err != nil {
   168  			errors = append(errors, fmt.Errorf("missing file on disk: %s (%w)", fullPath, err))
   169  			continue
   170  		}
   171  
   172  		genFile, err := goimports.FormatJenFile(file)
   173  		if err != nil {
   174  			errors = append(errors, fmt.Errorf("goimport error: %w", err))
   175  			continue
   176  		}
   177  
   178  		if !bytes.Equal(existing, genFile) {
   179  			errors = append(errors, fmt.Errorf("'%s' has changed", fullPath))
   180  			continue
   181  		}
   182  	}
   183  	return errors
   184  }
   185  
   186  var acceptableBuildErrorsOn = map[string]any{
   187  	"ast_equals.go":  nil,
   188  	"ast_clone.go":   nil,
   189  	"ast_rewrite.go": nil,
   190  	"ast_visit.go":   nil,
   191  }
   192  
// Options configures GenerateASTHelpers.
type Options struct {
	// Packages are the package patterns passed to packages.Load.
	Packages      []string
	// RootInterface names the root AST interface as "<package path>.<TypeName>".
	// It is split at the last '.', and the package part must be one of the
	// loaded packages.
	RootInterface string

	// Clone is passed to the clone generator (newCloneGen).
	Clone  CloneOptions
	// Equals is passed to the equals generator (newEqualsGen).
	Equals EqualsOptions
}
   200  
   201  // GenerateASTHelpers loads the input code, constructs the necessary generators,
   202  // and generates the rewriter and clone methods for the AST
   203  func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) {
   204  	loaded, err := packages.Load(&packages.Config{
   205  		Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesSizes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedModule,
   206  	}, options.Packages...)
   207  
   208  	if err != nil {
   209  		return nil, fmt.Errorf("failed to load packages: %w", err)
   210  	}
   211  
   212  	checkErrors(loaded, func(fileName string) bool {
   213  		_, ok := acceptableBuildErrorsOn[fileName]
   214  		return ok
   215  	})
   216  
   217  	scopes := make(map[string]*types.Scope)
   218  	for _, pkg := range loaded {
   219  		scopes[pkg.PkgPath] = pkg.Types.Scope()
   220  	}
   221  
   222  	pos := strings.LastIndexByte(options.RootInterface, '.')
   223  	if pos < 0 {
   224  		return nil, fmt.Errorf("unexpected input type: %s", options.RootInterface)
   225  	}
   226  
   227  	pkgname := options.RootInterface[:pos]
   228  	typename := options.RootInterface[pos+1:]
   229  
   230  	scope := scopes[pkgname]
   231  	if scope == nil {
   232  		return nil, fmt.Errorf("no scope found for type '%s'", options.RootInterface)
   233  	}
   234  
   235  	tt := scope.Lookup(typename)
   236  	if tt == nil {
   237  		return nil, fmt.Errorf("no type called '%s' found in '%s'", typename, pkgname)
   238  	}
   239  
   240  	nt := tt.Type().(*types.Named)
   241  	pName := nt.Obj().Pkg().Name()
   242  	generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt,
   243  		newEqualsGen(pName, &options.Equals),
   244  		newCloneGen(pName, &options.Clone),
   245  		newVisitGen(pName),
   246  		newRewriterGen(pName, types.TypeString(nt, noQualifier)),
   247  		newCOWGen(pName, nt),
   248  	)
   249  
   250  	it, err := generator.GenerateCode()
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  
   255  	return it, nil
   256  }
   257  
// Compile-time assertion that *astHelperGen satisfies generatorSPI. createFiles
// passes the generator itself as the spi argument of every generator hook.
var _ generatorSPI = (*astHelperGen)(nil)
   259  
   260  func (gen *astHelperGen) scope() *types.Scope {
   261  	return gen._scope
   262  }
   263  
   264  func (gen *astHelperGen) addType(t types.Type) {
   265  	gen.todo = append(gen.todo, t)
   266  }
   267  
   268  func (gen *astHelperGen) createFiles() map[string]*jen.File {
   269  	alreadyDone := map[string]bool{}
   270  	for len(gen.todo) > 0 {
   271  		t := gen.todo[0]
   272  		underlying := t.Underlying()
   273  		typeName := printableTypeName(t)
   274  		gen.todo = gen.todo[1:]
   275  
   276  		if alreadyDone[typeName] {
   277  			continue
   278  		}
   279  		var err error
   280  		for _, g := range gen.gens {
   281  			switch underlying := underlying.(type) {
   282  			case *types.Interface:
   283  				err = g.interfaceMethod(t, underlying, gen)
   284  			case *types.Slice:
   285  				err = g.sliceMethod(t, underlying, gen)
   286  			case *types.Struct:
   287  				err = g.structMethod(t, underlying, gen)
   288  			case *types.Pointer:
   289  				ptrToType := underlying.Elem().Underlying()
   290  				switch ptrToType := ptrToType.(type) {
   291  				case *types.Struct:
   292  					err = g.ptrToStructMethod(t, ptrToType, gen)
   293  				case *types.Basic:
   294  					err = g.ptrToBasicMethod(t, ptrToType, gen)
   295  				default:
   296  					panic(fmt.Sprintf("%T", ptrToType))
   297  				}
   298  			case *types.Basic:
   299  				err = g.basicMethod(t, underlying, gen)
   300  			default:
   301  				log.Fatalf("don't know how to handle %s %T", typeName, underlying)
   302  			}
   303  			if err != nil {
   304  				log.Fatal(err)
   305  			}
   306  		}
   307  		alreadyDone[typeName] = true
   308  	}
   309  
   310  	result := map[string]*jen.File{}
   311  	for _, g := range gen.gens {
   312  		fName, jenFile := g.genFile()
   313  		result[fName] = jenFile
   314  	}
   315  	return result
   316  }
   317  
   318  // printableTypeName returns a string that can be used as a valid golang identifier
   319  func printableTypeName(t types.Type) string {
   320  	switch t := t.(type) {
   321  	case *types.Pointer:
   322  		return "RefOf" + printableTypeName(t.Elem())
   323  	case *types.Slice:
   324  		return "SliceOf" + printableTypeName(t.Elem())
   325  	case *types.Named:
   326  		return t.Obj().Name()
   327  	case *types.Basic:
   328  		return strings.Title(t.Name()) // nolint
   329  	case *types.Interface:
   330  		return t.String()
   331  	default:
   332  		panic(fmt.Sprintf("unknown type %T %v", t, t))
   333  	}
   334  }
   335  
   336  func checkErrors(loaded []*packages.Package, canSkipErrorOn func(fileName string) bool) {
   337  	for _, l := range loaded {
   338  		for _, e := range l.Errors {
   339  			idx := strings.Index(e.Pos, ":")
   340  			filePath := e.Pos[:idx]
   341  			_, fileName := path.Split(filePath)
   342  			if !canSkipErrorOn(fileName) {
   343  				log.Fatalf("error loading package %s", e.Error())
   344  			}
   345  		}
   346  	}
   347  }