vitess.io/vitess@v0.16.2/go/tools/sizegen/sizegen.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package main
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"go/types"
    23  	"log"
    24  	"os"
    25  	"path"
    26  	"sort"
    27  	"strings"
    28  
    29  	"github.com/dave/jennifer/jen"
    30  	"github.com/spf13/pflag"
    31  	"golang.org/x/tools/go/packages"
    32  
    33  	"vitess.io/vitess/go/hack"
    34  	"vitess.io/vitess/go/tools/common"
    35  	"vitess.io/vitess/go/tools/goimports"
    36  )
    37  
// licenseFileHeader is the Apache 2.0 license notice that finalize attaches as
// the header comment of every generated file.
const licenseFileHeader = `Copyright 2021 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.`
    51  
// sizegen holds the state of one code generation run.
type sizegen struct {
	// DebugTypes makes sizeImplForStruct emit a comment describing each field
	// in front of the statement generated for it.
	DebugTypes bool
	// mod is the module being generated for; its Path decides whether a type is
	// "local" and its Dir is the root of the output file paths.
	mod *packages.Module
	// sizes is used for every Sizeof computation.
	sizes types.Sizes
	// codegen collects the pending output per package, keyed by package path.
	codegen map[string]*codeFile
	// known caches the bookkeeping state of every named type seen so far.
	known map[*types.Named]*typeState
}
    59  
    60  type codeFlag uint32
    61  
    62  const (
    63  	codeWithInterface = 1 << 0
    64  	codeWithUnsafe    = 1 << 1
    65  )
    66  
// codeImpl is one generated CachedSize method.
type codeImpl struct {
	// name is the String() form of the named type; finalize sorts by it.
	name string
	// flags records which extras (interface declaration, unsafe) the code needs.
	flags codeFlag
	// code is the generated function itself.
	code jen.Code
}
    72  
// codeFile accumulates the generated methods that belong to a single package.
type codeFile struct {
	// pkg is the package name used for the generated file's package clause.
	pkg string
	// impls are the methods generated so far for this package.
	impls []codeImpl
}
    77  
// typeState is the per-type bookkeeping kept in sizegen.known.
type typeState struct {
	// generated is set once generateType has processed the type.
	generated bool
	// local is true when the type's package path starts with the module path.
	local bool
	pod   bool // struct with only primitives
}
    83  
    84  func newSizegen(mod *packages.Module, sizes types.Sizes) *sizegen {
    85  	return &sizegen{
    86  		DebugTypes: true,
    87  		mod:        mod,
    88  		sizes:      sizes,
    89  		known:      make(map[*types.Named]*typeState),
    90  		codegen:    make(map[string]*codeFile),
    91  	}
    92  }
    93  
    94  func isPod(tt types.Type) bool {
    95  	switch tt := tt.(type) {
    96  	case *types.Struct:
    97  		for i := 0; i < tt.NumFields(); i++ {
    98  			if !isPod(tt.Field(i).Type()) {
    99  				return false
   100  			}
   101  		}
   102  		return true
   103  	case *types.Named:
   104  		return isPod(tt.Underlying())
   105  	case *types.Basic:
   106  		switch tt.Kind() {
   107  		case types.String, types.UnsafePointer:
   108  			return false
   109  		}
   110  		return true
   111  	default:
   112  		return false
   113  	}
   114  }
   115  
   116  func (sizegen *sizegen) getKnownType(named *types.Named) *typeState {
   117  	ts := sizegen.known[named]
   118  	if ts == nil {
   119  		local := strings.HasPrefix(named.Obj().Pkg().Path(), sizegen.mod.Path)
   120  		ts = &typeState{
   121  			local: local,
   122  			pod:   isPod(named.Underlying()),
   123  		}
   124  		sizegen.known[named] = ts
   125  	}
   126  	return ts
   127  }
   128  
   129  func (sizegen *sizegen) generateType(pkg *types.Package, file *codeFile, named *types.Named) {
   130  	ts := sizegen.getKnownType(named)
   131  	if ts.generated {
   132  		return
   133  	}
   134  	ts.generated = true
   135  
   136  	switch tt := named.Underlying().(type) {
   137  	case *types.Struct:
   138  		if impl, flag := sizegen.sizeImplForStruct(named.Obj(), tt); impl != nil {
   139  			file.impls = append(file.impls, codeImpl{
   140  				code:  impl,
   141  				name:  named.String(),
   142  				flags: flag,
   143  			})
   144  		}
   145  	case *types.Interface:
   146  		findImplementations(pkg.Scope(), tt, func(tt types.Type) {
   147  			if _, isStruct := tt.Underlying().(*types.Struct); isStruct {
   148  				sizegen.generateType(pkg, file, tt.(*types.Named))
   149  			}
   150  		})
   151  	default:
   152  		// no-op
   153  	}
   154  }
   155  
   156  func (sizegen *sizegen) generateKnownType(named *types.Named) {
   157  	pkgInfo := named.Obj().Pkg()
   158  	file := sizegen.codegen[pkgInfo.Path()]
   159  	if file == nil {
   160  		file = &codeFile{pkg: pkgInfo.Name()}
   161  		sizegen.codegen[pkgInfo.Path()] = file
   162  	}
   163  
   164  	sizegen.generateType(pkgInfo, file, named)
   165  }
   166  
   167  func findImplementations(scope *types.Scope, iff *types.Interface, impl func(types.Type)) {
   168  	for _, name := range scope.Names() {
   169  		obj := scope.Lookup(name)
   170  		baseType := obj.Type()
   171  		if types.Implements(baseType, iff) || types.Implements(types.NewPointer(baseType), iff) {
   172  			impl(baseType)
   173  		}
   174  	}
   175  }
   176  
   177  func (sizegen *sizegen) finalize() map[string]*jen.File {
   178  	var complete bool
   179  
   180  	for !complete {
   181  		complete = true
   182  		for tt, ts := range sizegen.known {
   183  			isComplex := !ts.pod
   184  			notYetGenerated := !ts.generated
   185  			if ts.local && isComplex && notYetGenerated {
   186  				sizegen.generateKnownType(tt)
   187  				complete = false
   188  			}
   189  		}
   190  	}
   191  
   192  	outputFiles := make(map[string]*jen.File)
   193  
   194  	for pkg, file := range sizegen.codegen {
   195  		if len(file.impls) == 0 {
   196  			continue
   197  		}
   198  		if !strings.HasPrefix(pkg, sizegen.mod.Path) {
   199  			log.Printf("failed to generate code for foreign package '%s'", pkg)
   200  			log.Printf("DEBUG:\n%#v", file)
   201  			continue
   202  		}
   203  
   204  		sort.Slice(file.impls, func(i, j int) bool {
   205  			return strings.Compare(file.impls[i].name, file.impls[j].name) < 0
   206  		})
   207  
   208  		out := jen.NewFile(file.pkg)
   209  		out.HeaderComment(licenseFileHeader)
   210  		out.HeaderComment("Code generated by Sizegen. DO NOT EDIT.")
   211  
   212  		for _, impl := range file.impls {
   213  			if impl.flags&codeWithInterface != 0 {
   214  				out.Add(jen.Type().Id("cachedObject").InterfaceFunc(func(i *jen.Group) {
   215  					i.Id("CachedSize").Params(jen.Id("alloc").Id("bool")).Int64()
   216  				}))
   217  				break
   218  			}
   219  		}
   220  
   221  		for _, impl := range file.impls {
   222  			if impl.flags&codeWithUnsafe != 0 {
   223  				out.Commentf("//go:nocheckptr")
   224  			}
   225  			out.Add(impl.code)
   226  		}
   227  
   228  		fullPath := path.Join(sizegen.mod.Dir, strings.TrimPrefix(pkg, sizegen.mod.Path), "cached_size.go")
   229  		outputFiles[fullPath] = out
   230  	}
   231  
   232  	return outputFiles
   233  }
   234  
   235  func (sizegen *sizegen) sizeImplForStruct(name *types.TypeName, st *types.Struct) (jen.Code, codeFlag) {
   236  	if sizegen.sizes.Sizeof(st) == 0 {
   237  		return nil, 0
   238  	}
   239  
   240  	var stmt []jen.Code
   241  	var funcFlags codeFlag
   242  	for i := 0; i < st.NumFields(); i++ {
   243  		field := st.Field(i)
   244  		fieldType := field.Type()
   245  		fieldName := jen.Id("cached").Dot(field.Name())
   246  
   247  		fieldStmt, flag := sizegen.sizeStmtForType(fieldName, fieldType, false)
   248  		if fieldStmt != nil {
   249  			if sizegen.DebugTypes {
   250  				stmt = append(stmt, jen.Commentf("%s", field.String()))
   251  			}
   252  			stmt = append(stmt, fieldStmt)
   253  		}
   254  		funcFlags |= flag
   255  	}
   256  
   257  	f := jen.Func()
   258  	f.Params(jen.Id("cached").Op("*").Id(name.Name()))
   259  	f.Id("CachedSize").Params(jen.Id("alloc").Id("bool")).Int64()
   260  	f.BlockFunc(func(b *jen.Group) {
   261  		b.Add(jen.If(jen.Id("cached").Op("==").Nil()).Block(jen.Return(jen.Lit(int64(0)))))
   262  		b.Add(jen.Id("size").Op(":=").Lit(int64(0)))
   263  		b.Add(jen.If(jen.Id("alloc")).Block(
   264  			jen.Id("size").Op("+=").Lit(hack.RuntimeAllocSize(sizegen.sizes.Sizeof(st))),
   265  		))
   266  		for _, s := range stmt {
   267  			b.Add(s)
   268  		}
   269  		b.Add(jen.Return(jen.Id("size")))
   270  	})
   271  	return f, funcFlags
   272  }
   273  
   274  func (sizegen *sizegen) sizeStmtForMap(fieldName *jen.Statement, m *types.Map) []jen.Code {
   275  	const bucketCnt = 8
   276  	const sizeofHmap = int64(6 * 8)
   277  
   278  	/*
   279  		type bmap struct {
   280  			// tophash generally contains the top byte of the hash value
   281  			// for each key in this bucket. If tophash[0] < minTopHash,
   282  			// tophash[0] is a bucket evacuation state instead.
   283  			tophash [bucketCnt]uint8
   284  			// Followed by bucketCnt keys and then bucketCnt elems.
   285  			// NOTE: packing all the keys together and then all the elems together makes the
   286  			// code a bit more complicated than alternating key/elem/key/elem/... but it allows
   287  			// us to eliminate padding which would be needed for, e.g., map[int64]int8.
   288  			// Followed by an overflow pointer.
   289  		}
   290  	*/
   291  	sizeOfBucket := int(
   292  		bucketCnt + // tophash
   293  			bucketCnt*sizegen.sizes.Sizeof(m.Key()) +
   294  			bucketCnt*sizegen.sizes.Sizeof(m.Elem()) +
   295  			8, // overflow pointer
   296  	)
   297  
   298  	return []jen.Code{
   299  		jen.Id("size").Op("+=").Lit(hack.RuntimeAllocSize(sizeofHmap)),
   300  
   301  		jen.Id("hmap").Op(":=").Qual("reflect", "ValueOf").Call(fieldName),
   302  
   303  		jen.Id("numBuckets").Op(":=").Id("int").Call(
   304  			jen.Qual("math", "Pow").Call(jen.Lit(2), jen.Id("float64").Call(
   305  				jen.Parens(jen.Op("*").Parens(jen.Op("*").Id("uint8")).Call(
   306  					jen.Qual("unsafe", "Pointer").Call(jen.Id("hmap").Dot("Pointer").Call().
   307  						Op("+").Id("uintptr").Call(jen.Lit(9)))))))),
   308  
   309  		jen.Id("numOldBuckets").Op(":=").Parens(jen.Op("*").Parens(jen.Op("*").Id("uint16")).Call(
   310  			jen.Qual("unsafe", "Pointer").Call(
   311  				jen.Id("hmap").Dot("Pointer").Call().Op("+").Id("uintptr").Call(jen.Lit(10))))),
   312  
   313  		jen.Id("size").Op("+=").Do(mallocsize(jen.Int64().Call(jen.Id("numOldBuckets").Op("*").Lit(sizeOfBucket)))),
   314  
   315  		jen.If(jen.Id("len").Call(fieldName).Op(">").Lit(0).Op("||").Id("numBuckets").Op(">").Lit(1)).Block(
   316  			jen.Id("size").Op("+=").Do(mallocsize(jen.Int64().Call(jen.Id("numBuckets").Op("*").Lit(sizeOfBucket))))),
   317  	}
   318  }
   319  
   320  func mallocsize(sizeStmt *jen.Statement) func(*jen.Statement) {
   321  	return func(parent *jen.Statement) {
   322  		parent.Qual("vitess.io/vitess/go/hack", "RuntimeAllocSize").Call(sizeStmt)
   323  	}
   324  }
   325  
   326  func (sizegen *sizegen) sizeStmtForArray(stmt []jen.Code, fieldName *jen.Statement, elemT types.Type) ([]jen.Code, codeFlag) {
   327  	var flag codeFlag
   328  
   329  	switch sizegen.sizes.Sizeof(elemT) {
   330  	case 0:
   331  		return nil, 0
   332  
   333  	case 1:
   334  		stmt = append(stmt, jen.Id("size").Op("+=").Do(mallocsize(jen.Int64().Call(jen.Cap(fieldName)))))
   335  
   336  	default:
   337  		var nested jen.Code
   338  		nested, flag = sizegen.sizeStmtForType(jen.Id("elem"), elemT, false)
   339  
   340  		stmt = append(stmt,
   341  			jen.Id("size").
   342  				Op("+=").
   343  				Do(mallocsize(jen.Int64().Call(jen.Cap(fieldName)).
   344  					Op("*").
   345  					Lit(sizegen.sizes.Sizeof(elemT))),
   346  				))
   347  
   348  		if nested != nil {
   349  			stmt = append(stmt, jen.For(jen.List(jen.Id("_"), jen.Id("elem")).Op(":=").Range().Add(fieldName)).Block(nested))
   350  		}
   351  	}
   352  
   353  	return stmt, flag
   354  }
   355  
   356  func (sizegen *sizegen) sizeStmtForType(fieldName *jen.Statement, field types.Type, alloc bool) (jen.Code, codeFlag) {
   357  	if sizegen.sizes.Sizeof(field) == 0 {
   358  		return nil, 0
   359  	}
   360  
   361  	switch node := field.(type) {
   362  	case *types.Slice:
   363  		var cond *jen.Statement
   364  		var stmt []jen.Code
   365  		var flag codeFlag
   366  
   367  		if alloc {
   368  			cond = jen.If(fieldName.Clone().Op("!=").Nil())
   369  			fieldName = jen.Op("*").Add(fieldName)
   370  			stmt = append(stmt, jen.Id("size").Op("+=").Lit(hack.RuntimeAllocSize(8*3)))
   371  		}
   372  
   373  		stmt, flag = sizegen.sizeStmtForArray(stmt, fieldName, node.Elem())
   374  		if cond != nil {
   375  			return cond.Block(stmt...), flag
   376  		}
   377  		return jen.Block(stmt...), flag
   378  
   379  	case *types.Array:
   380  		if alloc {
   381  			cond := jen.If(fieldName.Clone().Op("!=").Nil())
   382  			fieldName = jen.Op("*").Add(fieldName)
   383  
   384  			stmt, flag := sizegen.sizeStmtForArray(nil, fieldName, node.Elem())
   385  			return cond.Block(stmt...), flag
   386  		}
   387  
   388  		elemT := node.Elem()
   389  		if sizegen.sizes.Sizeof(elemT) > 1 {
   390  			nested, flag := sizegen.sizeStmtForType(jen.Id("elem"), elemT, false)
   391  			if nested != nil {
   392  				return jen.For(jen.List(jen.Id("_"), jen.Id("elem")).Op(":=").Range().Add(fieldName)).Block(nested), flag
   393  			}
   394  		}
   395  		return nil, 0
   396  
   397  	case *types.Map:
   398  		keySize, keyFlag := sizegen.sizeStmtForType(jen.Id("k"), node.Key(), false)
   399  		valSize, valFlag := sizegen.sizeStmtForType(jen.Id("v"), node.Elem(), false)
   400  
   401  		return jen.If(fieldName.Clone().Op("!=").Nil()).BlockFunc(func(block *jen.Group) {
   402  			for _, stmt := range sizegen.sizeStmtForMap(fieldName, node) {
   403  				block.Add(stmt)
   404  			}
   405  
   406  			var forLoopVars []jen.Code
   407  			switch {
   408  			case keySize != nil && valSize != nil:
   409  				forLoopVars = []jen.Code{jen.Id("k"), jen.Id("v")}
   410  			case keySize == nil && valSize != nil:
   411  				forLoopVars = []jen.Code{jen.Id("_"), jen.Id("v")}
   412  			case keySize != nil && valSize == nil:
   413  				forLoopVars = []jen.Code{jen.Id("k")}
   414  			case keySize == nil && valSize == nil:
   415  				return
   416  			}
   417  
   418  			block.Add(jen.For(jen.List(forLoopVars...).Op(":=").Range().Add(fieldName))).BlockFunc(func(b *jen.Group) {
   419  				if keySize != nil {
   420  					b.Add(keySize)
   421  				}
   422  				if valSize != nil {
   423  					b.Add(valSize)
   424  				}
   425  			})
   426  		}), codeWithUnsafe | keyFlag | valFlag
   427  
   428  	case *types.Pointer:
   429  		return sizegen.sizeStmtForType(fieldName, node.Elem(), true)
   430  
   431  	case *types.Named:
   432  		ts := sizegen.getKnownType(node)
   433  		if ts.pod || !ts.local {
   434  			if alloc {
   435  				if !ts.local {
   436  					log.Printf("WARNING: size of external type %s cannot be fully calculated", node)
   437  				}
   438  				return jen.If(fieldName.Clone().Op("!=").Nil()).Block(
   439  					jen.Id("size").Op("+=").Do(mallocsize(jen.Lit(sizegen.sizes.Sizeof(node.Underlying())))),
   440  				), 0
   441  			}
   442  			return nil, 0
   443  		}
   444  		return sizegen.sizeStmtForType(fieldName, node.Underlying(), alloc)
   445  
   446  	case *types.Interface:
   447  		if node.Empty() {
   448  			return nil, 0
   449  		}
   450  		return jen.If(
   451  			jen.List(
   452  				jen.Id("cc"), jen.Id("ok")).
   453  				Op(":=").
   454  				Add(fieldName.Clone().Assert(jen.Id("cachedObject"))),
   455  			jen.Id("ok"),
   456  		).Block(
   457  			jen.Id("size").
   458  				Op("+=").
   459  				Id("cc").
   460  				Dot("CachedSize").
   461  				Call(jen.True()),
   462  		), codeWithInterface
   463  
   464  	case *types.Struct:
   465  		return jen.Id("size").Op("+=").Add(fieldName.Clone().Dot("CachedSize").Call(jen.Lit(alloc))), 0
   466  
   467  	case *types.Basic:
   468  		if !alloc {
   469  			if node.Info()&types.IsString != 0 {
   470  				return jen.Id("size").Op("+=").Do(mallocsize(jen.Int64().Call(jen.Len(fieldName)))), 0
   471  			}
   472  			return nil, 0
   473  		}
   474  		return jen.Id("size").Op("+=").Do(mallocsize(jen.Lit(sizegen.sizes.Sizeof(node)))), 0
   475  
   476  	case *types.Signature:
   477  		// assume that function pointers do not allocate (although they might, if they're closures)
   478  		return nil, 0
   479  
   480  	default:
   481  		log.Printf("unhandled type: %T", node)
   482  		return nil, 0
   483  	}
   484  }
   485  
   486  type typePaths []string
   487  
   488  func (t *typePaths) String() string {
   489  	return fmt.Sprintf("%v", *t)
   490  }
   491  
   492  func (t *typePaths) Set(path string) error {
   493  	*t = append(*t, path)
   494  	return nil
   495  }
   496  
   497  func main() {
   498  	var (
   499  		patterns, generate []string
   500  		verify             bool
   501  	)
   502  
   503  	pflag.StringSliceVar(&patterns, "in", nil, "Go packages to load the generator")
   504  	pflag.StringSliceVar(&generate, "gen", nil, "Typename of the Go struct to generate size info for")
   505  	pflag.BoolVar(&verify, "verify", false, "ensure that the generated files are correct")
   506  	pflag.Parse()
   507  
   508  	result, err := GenerateSizeHelpers(patterns, generate)
   509  	if err != nil {
   510  		log.Fatal(err)
   511  	}
   512  
   513  	if verify {
   514  		for _, err := range VerifyFilesOnDisk(result) {
   515  			log.Fatal(err)
   516  		}
   517  		log.Printf("%d files OK", len(result))
   518  	} else {
   519  		for fullPath, file := range result {
   520  			content, err := goimports.FormatJenFile(file)
   521  			if err != nil {
   522  				log.Fatalf("failed to apply goimport to '%s': %v", fullPath, err)
   523  			}
   524  			err = os.WriteFile(fullPath, content, 0664)
   525  			if err != nil {
   526  				log.Fatalf("failed to save file to '%s': %v", fullPath, err)
   527  			}
   528  		}
   529  	}
   530  }
   531  
   532  // VerifyFilesOnDisk compares the generated results from the codegen against the files that
   533  // currently exist on disk and returns any mismatches. All the files generated by jennifer
   534  // are formatted using the goimports command. Any difference in the imports will also make
   535  // this test fail.
   536  func VerifyFilesOnDisk(result map[string]*jen.File) (errors []error) {
   537  	for fullPath, file := range result {
   538  		existing, err := os.ReadFile(fullPath)
   539  		if err != nil {
   540  			errors = append(errors, fmt.Errorf("missing file on disk: %s (%w)", fullPath, err))
   541  			continue
   542  		}
   543  
   544  		genFile, err := goimports.FormatJenFile(file)
   545  		if err != nil {
   546  			errors = append(errors, fmt.Errorf("goimport error: %w", err))
   547  			continue
   548  		}
   549  
   550  		if !bytes.Equal(existing, genFile) {
   551  			errors = append(errors, fmt.Errorf("'%s' has changed", fullPath))
   552  			continue
   553  		}
   554  	}
   555  	return errors
   556  }
   557  
   558  // GenerateSizeHelpers generates the auxiliary code that implements CachedSize helper methods
   559  // for all the types listed in typePatterns
   560  func GenerateSizeHelpers(packagePatterns []string, typePatterns []string) (map[string]*jen.File, error) {
   561  	loaded, err := packages.Load(&packages.Config{
   562  		Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesSizes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedModule,
   563  	}, packagePatterns...)
   564  
   565  	if err != nil {
   566  		return nil, err
   567  	}
   568  
   569  	if common.PkgFailed(loaded) {
   570  		return nil, fmt.Errorf("failed to load packages")
   571  	}
   572  
   573  	sizegen := newSizegen(loaded[0].Module, loaded[0].TypesSizes)
   574  
   575  	scopes := make(map[string]*types.Scope)
   576  	for _, pkg := range loaded {
   577  		scopes[pkg.PkgPath] = pkg.Types.Scope()
   578  	}
   579  
   580  	for _, gen := range typePatterns {
   581  		pos := strings.LastIndexByte(gen, '.')
   582  		if pos < 0 {
   583  			return nil, fmt.Errorf("unexpected input type: %s", gen)
   584  		}
   585  
   586  		pkgname := gen[:pos]
   587  		typename := gen[pos+1:]
   588  
   589  		scope := scopes[pkgname]
   590  		if scope == nil {
   591  			return nil, fmt.Errorf("no scope found for type '%s'", gen)
   592  		}
   593  
   594  		if typename == "*" {
   595  			for _, name := range scope.Names() {
   596  				sizegen.generateKnownType(scope.Lookup(name).Type().(*types.Named))
   597  			}
   598  		} else {
   599  			tt := scope.Lookup(typename)
   600  			if tt == nil {
   601  				return nil, fmt.Errorf("no type called '%s' found in '%s'", typename, pkgname)
   602  			}
   603  
   604  			sizegen.generateKnownType(tt.Type().(*types.Named))
   605  		}
   606  	}
   607  
   608  	return sizegen.finalize(), nil
   609  }