vitess.io/vitess@v0.16.2/go/tools/sizegen/sizegen.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package main 18 19 import ( 20 "bytes" 21 "fmt" 22 "go/types" 23 "log" 24 "os" 25 "path" 26 "sort" 27 "strings" 28 29 "github.com/dave/jennifer/jen" 30 "github.com/spf13/pflag" 31 "golang.org/x/tools/go/packages" 32 33 "vitess.io/vitess/go/hack" 34 "vitess.io/vitess/go/tools/common" 35 "vitess.io/vitess/go/tools/goimports" 36 ) 37 38 const licenseFileHeader = `Copyright 2021 The Vitess Authors. 39 40 Licensed under the Apache License, Version 2.0 (the "License"); 41 you may not use this file except in compliance with the License. 42 You may obtain a copy of the License at 43 44 http://www.apache.org/licenses/LICENSE-2.0 45 46 Unless required by applicable law or agreed to in writing, software 47 distributed under the License is distributed on an "AS IS" BASIS, 48 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 49 See the License for the specific language governing permissions and 50 limitations under the License.` 51 52 type sizegen struct { 53 DebugTypes bool 54 mod *packages.Module 55 sizes types.Sizes 56 codegen map[string]*codeFile 57 known map[*types.Named]*typeState 58 } 59 60 type codeFlag uint32 61 62 const ( 63 codeWithInterface = 1 << 0 64 codeWithUnsafe = 1 << 1 65 ) 66 67 type codeImpl struct { 68 name string 69 flags codeFlag 70 code jen.Code 71 } 72 73 type codeFile struct { 74 pkg string 75 impls []codeImpl 76 } 77 78 type typeState struct { 79 generated bool 80 local bool 81 pod bool // struct with only primitives 82 } 83 84 func newSizegen(mod *packages.Module, sizes types.Sizes) *sizegen { 85 return &sizegen{ 86 DebugTypes: true, 87 mod: mod, 88 sizes: sizes, 89 known: make(map[*types.Named]*typeState), 90 codegen: make(map[string]*codeFile), 91 } 92 } 93 94 func isPod(tt types.Type) bool { 95 switch tt := tt.(type) { 96 case *types.Struct: 97 for i := 0; i < tt.NumFields(); i++ { 98 if !isPod(tt.Field(i).Type()) { 99 return false 100 } 101 } 102 return true 103 case *types.Named: 104 return isPod(tt.Underlying()) 105 case *types.Basic: 106 switch tt.Kind() { 107 case types.String, types.UnsafePointer: 108 return false 109 } 110 return true 111 default: 112 return false 113 } 114 } 115 116 func (sizegen *sizegen) getKnownType(named *types.Named) *typeState { 117 ts := sizegen.known[named] 118 if ts == nil { 119 local := strings.HasPrefix(named.Obj().Pkg().Path(), sizegen.mod.Path) 120 ts = &typeState{ 121 local: local, 122 pod: isPod(named.Underlying()), 123 } 124 sizegen.known[named] = ts 125 } 126 return ts 127 } 128 129 func (sizegen *sizegen) generateType(pkg *types.Package, file *codeFile, named *types.Named) { 130 ts := sizegen.getKnownType(named) 131 if ts.generated { 132 return 133 } 134 ts.generated = true 135 136 switch tt := named.Underlying().(type) { 137 case *types.Struct: 138 if impl, flag := sizegen.sizeImplForStruct(named.Obj(), tt); impl != nil { 139 file.impls = append(file.impls, codeImpl{ 140 code: impl, 141 name: named.String(), 142 flags: flag, 143 }) 144 } 145 case *types.Interface: 146 findImplementations(pkg.Scope(), tt, func(tt types.Type) { 147 if _, isStruct := tt.Underlying().(*types.Struct); isStruct { 148 sizegen.generateType(pkg, file, tt.(*types.Named)) 149 } 150 }) 151 default: 152 // no-op 153 } 154 } 155 156 func (sizegen *sizegen) generateKnownType(named *types.Named) { 157 pkgInfo := named.Obj().Pkg() 158 file := sizegen.codegen[pkgInfo.Path()] 159 if file == nil { 160 file = &codeFile{pkg: pkgInfo.Name()} 161 sizegen.codegen[pkgInfo.Path()] = file 162 } 163 164 sizegen.generateType(pkgInfo, file, named) 165 } 166 167 func findImplementations(scope *types.Scope, iff *types.Interface, impl func(types.Type)) { 168 for _, name := range scope.Names() { 169 obj := scope.Lookup(name) 170 baseType := obj.Type() 171 if types.Implements(baseType, iff) || types.Implements(types.NewPointer(baseType), iff) { 172 impl(baseType) 173 } 174 } 175 } 176 177 func (sizegen *sizegen) finalize() map[string]*jen.File { 178 var complete bool 179 180 for !complete { 181 complete = true 182 for tt, ts := range sizegen.known { 183 isComplex := !ts.pod 184 notYetGenerated := !ts.generated 185 if ts.local && isComplex && notYetGenerated { 186 sizegen.generateKnownType(tt) 187 complete = false 188 } 189 } 190 } 191 192 outputFiles := make(map[string]*jen.File) 193 194 for pkg, file := range sizegen.codegen { 195 if len(file.impls) == 0 { 196 continue 197 } 198 if !strings.HasPrefix(pkg, sizegen.mod.Path) { 199 log.Printf("failed to generate code for foreign package '%s'", pkg) 200 log.Printf("DEBUG:\n%#v", file) 201 continue 202 } 203 204 sort.Slice(file.impls, func(i, j int) bool { 205 return strings.Compare(file.impls[i].name, file.impls[j].name) < 0 206 }) 207 208 out := jen.NewFile(file.pkg) 209 out.HeaderComment(licenseFileHeader) 210 out.HeaderComment("Code generated by Sizegen. DO NOT EDIT.") 211 212 for _, impl := range file.impls { 213 if impl.flags&codeWithInterface != 0 { 214 out.Add(jen.Type().Id("cachedObject").InterfaceFunc(func(i *jen.Group) { 215 i.Id("CachedSize").Params(jen.Id("alloc").Id("bool")).Int64() 216 })) 217 break 218 } 219 } 220 221 for _, impl := range file.impls { 222 if impl.flags&codeWithUnsafe != 0 { 223 out.Commentf("//go:nocheckptr") 224 } 225 out.Add(impl.code) 226 } 227 228 fullPath := path.Join(sizegen.mod.Dir, strings.TrimPrefix(pkg, sizegen.mod.Path), "cached_size.go") 229 outputFiles[fullPath] = out 230 } 231 232 return outputFiles 233 } 234 235 func (sizegen *sizegen) sizeImplForStruct(name *types.TypeName, st *types.Struct) (jen.Code, codeFlag) { 236 if sizegen.sizes.Sizeof(st) == 0 { 237 return nil, 0 238 } 239 240 var stmt []jen.Code 241 var funcFlags codeFlag 242 for i := 0; i < st.NumFields(); i++ { 243 field := st.Field(i) 244 fieldType := field.Type() 245 fieldName := jen.Id("cached").Dot(field.Name()) 246 247 fieldStmt, flag := sizegen.sizeStmtForType(fieldName, fieldType, false) 248 if fieldStmt != nil { 249 if sizegen.DebugTypes { 250 stmt = append(stmt, jen.Commentf("%s", field.String())) 251 } 252 stmt = append(stmt, fieldStmt) 253 } 254 funcFlags |= flag 255 } 256 257 f := jen.Func() 258 f.Params(jen.Id("cached").Op("*").Id(name.Name())) 259 f.Id("CachedSize").Params(jen.Id("alloc").Id("bool")).Int64() 260 f.BlockFunc(func(b *jen.Group) { 261 b.Add(jen.If(jen.Id("cached").Op("==").Nil()).Block(jen.Return(jen.Lit(int64(0))))) 262 b.Add(jen.Id("size").Op(":=").Lit(int64(0))) 263 b.Add(jen.If(jen.Id("alloc")).Block( 264 jen.Id("size").Op("+=").Lit(hack.RuntimeAllocSize(sizegen.sizes.Sizeof(st))), 265 )) 266 for _, s := range stmt { 267 b.Add(s) 268 } 269 b.Add(jen.Return(jen.Id("size"))) 270 }) 271 return f, funcFlags 272 } 273 274 func (sizegen *sizegen) sizeStmtForMap(fieldName *jen.Statement, m *types.Map) []jen.Code { 275 const bucketCnt = 8 276 const sizeofHmap = int64(6 * 8) 277 278 /* 279 type bmap struct { 280 // tophash generally contains the top byte of the hash value 281 // for each key in this bucket. If tophash[0] < minTopHash, 282 // tophash[0] is a bucket evacuation state instead. 283 tophash [bucketCnt]uint8 284 // Followed by bucketCnt keys and then bucketCnt elems. 285 // NOTE: packing all the keys together and then all the elems together makes the 286 // code a bit more complicated than alternating key/elem/key/elem/... but it allows 287 // us to eliminate padding which would be needed for, e.g., map[int64]int8. 288 // Followed by an overflow pointer. 289 } 290 */ 291 sizeOfBucket := int( 292 bucketCnt + // tophash 293 bucketCnt*sizegen.sizes.Sizeof(m.Key()) + 294 bucketCnt*sizegen.sizes.Sizeof(m.Elem()) + 295 8, // overflow pointer 296 ) 297 298 return []jen.Code{ 299 jen.Id("size").Op("+=").Lit(hack.RuntimeAllocSize(sizeofHmap)), 300 301 jen.Id("hmap").Op(":=").Qual("reflect", "ValueOf").Call(fieldName), 302 303 jen.Id("numBuckets").Op(":=").Id("int").Call( 304 jen.Qual("math", "Pow").Call(jen.Lit(2), jen.Id("float64").Call( 305 jen.Parens(jen.Op("*").Parens(jen.Op("*").Id("uint8")).Call( 306 jen.Qual("unsafe", "Pointer").Call(jen.Id("hmap").Dot("Pointer").Call(). 307 Op("+").Id("uintptr").Call(jen.Lit(9)))))))), 308 309 jen.Id("numOldBuckets").Op(":=").Parens(jen.Op("*").Parens(jen.Op("*").Id("uint16")).Call( 310 jen.Qual("unsafe", "Pointer").Call( 311 jen.Id("hmap").Dot("Pointer").Call().Op("+").Id("uintptr").Call(jen.Lit(10))))), 312 313 jen.Id("size").Op("+=").Do(mallocsize(jen.Int64().Call(jen.Id("numOldBuckets").Op("*").Lit(sizeOfBucket)))), 314 315 jen.If(jen.Id("len").Call(fieldName).Op(">").Lit(0).Op("||").Id("numBuckets").Op(">").Lit(1)).Block( 316 jen.Id("size").Op("+=").Do(mallocsize(jen.Int64().Call(jen.Id("numBuckets").Op("*").Lit(sizeOfBucket))))), 317 } 318 } 319 320 func mallocsize(sizeStmt *jen.Statement) func(*jen.Statement) { 321 return func(parent *jen.Statement) { 322 parent.Qual("vitess.io/vitess/go/hack", "RuntimeAllocSize").Call(sizeStmt) 323 } 324 } 325 326 func (sizegen *sizegen) sizeStmtForArray(stmt []jen.Code, fieldName *jen.Statement, elemT types.Type) ([]jen.Code, codeFlag) { 327 var flag codeFlag 328 329 switch sizegen.sizes.Sizeof(elemT) { 330 case 0: 331 return nil, 0 332 333 case 1: 334 stmt = append(stmt, jen.Id("size").Op("+=").Do(mallocsize(jen.Int64().Call(jen.Cap(fieldName))))) 335 336 default: 337 var nested jen.Code 338 nested, flag = sizegen.sizeStmtForType(jen.Id("elem"), elemT, false) 339 340 stmt = append(stmt, 341 jen.Id("size"). 342 Op("+="). 343 Do(mallocsize(jen.Int64().Call(jen.Cap(fieldName)). 344 Op("*"). 345 Lit(sizegen.sizes.Sizeof(elemT))), 346 )) 347 348 if nested != nil { 349 stmt = append(stmt, jen.For(jen.List(jen.Id("_"), jen.Id("elem")).Op(":=").Range().Add(fieldName)).Block(nested)) 350 } 351 } 352 353 return stmt, flag 354 } 355 356 func (sizegen *sizegen) sizeStmtForType(fieldName *jen.Statement, field types.Type, alloc bool) (jen.Code, codeFlag) { 357 if sizegen.sizes.Sizeof(field) == 0 { 358 return nil, 0 359 } 360 361 switch node := field.(type) { 362 case *types.Slice: 363 var cond *jen.Statement 364 var stmt []jen.Code 365 var flag codeFlag 366 367 if alloc { 368 cond = jen.If(fieldName.Clone().Op("!=").Nil()) 369 fieldName = jen.Op("*").Add(fieldName) 370 stmt = append(stmt, jen.Id("size").Op("+=").Lit(hack.RuntimeAllocSize(8*3))) 371 } 372 373 stmt, flag = sizegen.sizeStmtForArray(stmt, fieldName, node.Elem()) 374 if cond != nil { 375 return cond.Block(stmt...), flag 376 } 377 return jen.Block(stmt...), flag 378 379 case *types.Array: 380 if alloc { 381 cond := jen.If(fieldName.Clone().Op("!=").Nil()) 382 fieldName = jen.Op("*").Add(fieldName) 383 384 stmt, flag := sizegen.sizeStmtForArray(nil, fieldName, node.Elem()) 385 return cond.Block(stmt...), flag 386 } 387 388 elemT := node.Elem() 389 if sizegen.sizes.Sizeof(elemT) > 1 { 390 nested, flag := sizegen.sizeStmtForType(jen.Id("elem"), elemT, false) 391 if nested != nil { 392 return jen.For(jen.List(jen.Id("_"), jen.Id("elem")).Op(":=").Range().Add(fieldName)).Block(nested), flag 393 } 394 } 395 return nil, 0 396 397 case *types.Map: 398 keySize, keyFlag := sizegen.sizeStmtForType(jen.Id("k"), node.Key(), false) 399 valSize, valFlag := sizegen.sizeStmtForType(jen.Id("v"), node.Elem(), false) 400 401 return jen.If(fieldName.Clone().Op("!=").Nil()).BlockFunc(func(block *jen.Group) { 402 for _, stmt := range sizegen.sizeStmtForMap(fieldName, node) { 403 block.Add(stmt) 404 } 405 406 var forLoopVars []jen.Code 407 switch { 408 case keySize != nil && valSize != nil: 409 forLoopVars = []jen.Code{jen.Id("k"), jen.Id("v")} 410 case keySize == nil && valSize != nil: 411 forLoopVars = []jen.Code{jen.Id("_"), jen.Id("v")} 412 case keySize != nil && valSize == nil: 413 forLoopVars = []jen.Code{jen.Id("k")} 414 case keySize == nil && valSize == nil: 415 return 416 } 417 418 block.Add(jen.For(jen.List(forLoopVars...).Op(":=").Range().Add(fieldName))).BlockFunc(func(b *jen.Group) { 419 if keySize != nil { 420 b.Add(keySize) 421 } 422 if valSize != nil { 423 b.Add(valSize) 424 } 425 }) 426 }), codeWithUnsafe | keyFlag | valFlag 427 428 case *types.Pointer: 429 return sizegen.sizeStmtForType(fieldName, node.Elem(), true) 430 431 case *types.Named: 432 ts := sizegen.getKnownType(node) 433 if ts.pod || !ts.local { 434 if alloc { 435 if !ts.local { 436 log.Printf("WARNING: size of external type %s cannot be fully calculated", node) 437 } 438 return jen.If(fieldName.Clone().Op("!=").Nil()).Block( 439 jen.Id("size").Op("+=").Do(mallocsize(jen.Lit(sizegen.sizes.Sizeof(node.Underlying())))), 440 ), 0 441 } 442 return nil, 0 443 } 444 return sizegen.sizeStmtForType(fieldName, node.Underlying(), alloc) 445 446 case *types.Interface: 447 if node.Empty() { 448 return nil, 0 449 } 450 return jen.If( 451 jen.List( 452 jen.Id("cc"), jen.Id("ok")). 453 Op(":="). 454 Add(fieldName.Clone().Assert(jen.Id("cachedObject"))), 455 jen.Id("ok"), 456 ).Block( 457 jen.Id("size"). 458 Op("+="). 459 Id("cc"). 460 Dot("CachedSize"). 461 Call(jen.True()), 462 ), codeWithInterface 463 464 case *types.Struct: 465 return jen.Id("size").Op("+=").Add(fieldName.Clone().Dot("CachedSize").Call(jen.Lit(alloc))), 0 466 467 case *types.Basic: 468 if !alloc { 469 if node.Info()&types.IsString != 0 { 470 return jen.Id("size").Op("+=").Do(mallocsize(jen.Int64().Call(jen.Len(fieldName)))), 0 471 } 472 return nil, 0 473 } 474 return jen.Id("size").Op("+=").Do(mallocsize(jen.Lit(sizegen.sizes.Sizeof(node)))), 0 475 476 case *types.Signature: 477 // assume that function pointers do not allocate (although they might, if they're closures) 478 return nil, 0 479 480 default: 481 log.Printf("unhandled type: %T", node) 482 return nil, 0 483 } 484 } 485 486 type typePaths []string 487 488 func (t *typePaths) String() string { 489 return fmt.Sprintf("%v", *t) 490 } 491 492 func (t *typePaths) Set(path string) error { 493 *t = append(*t, path) 494 return nil 495 } 496 497 func main() { 498 var ( 499 patterns, generate []string 500 verify bool 501 ) 502 503 pflag.StringSliceVar(&patterns, "in", nil, "Go packages to load the generator") 504 pflag.StringSliceVar(&generate, "gen", nil, "Typename of the Go struct to generate size info for") 505 pflag.BoolVar(&verify, "verify", false, "ensure that the generated files are correct") 506 pflag.Parse() 507 508 result, err := GenerateSizeHelpers(patterns, generate) 509 if err != nil { 510 log.Fatal(err) 511 } 512 513 if verify { 514 for _, err := range VerifyFilesOnDisk(result) { 515 log.Fatal(err) 516 } 517 log.Printf("%d files OK", len(result)) 518 } else { 519 for fullPath, file := range result { 520 content, err := goimports.FormatJenFile(file) 521 if err != nil { 522 log.Fatalf("failed to apply goimport to '%s': %v", fullPath, err) 523 } 524 err = os.WriteFile(fullPath, content, 0664) 525 if err != nil { 526 log.Fatalf("failed to save file to '%s': %v", fullPath, err) 527 } 528 } 529 } 530 } 531 532 // VerifyFilesOnDisk compares the generated results from the codegen against the files that 533 // currently exist on disk and returns any mismatches. All the files generated by jennifer 534 // are formatted using the goimports command. Any difference in the imports will also make 535 // this test fail. 536 func VerifyFilesOnDisk(result map[string]*jen.File) (errors []error) { 537 for fullPath, file := range result { 538 existing, err := os.ReadFile(fullPath) 539 if err != nil { 540 errors = append(errors, fmt.Errorf("missing file on disk: %s (%w)", fullPath, err)) 541 continue 542 } 543 544 genFile, err := goimports.FormatJenFile(file) 545 if err != nil { 546 errors = append(errors, fmt.Errorf("goimport error: %w", err)) 547 continue 548 } 549 550 if !bytes.Equal(existing, genFile) { 551 errors = append(errors, fmt.Errorf("'%s' has changed", fullPath)) 552 continue 553 } 554 } 555 return errors 556 } 557 558 // GenerateSizeHelpers generates the auxiliary code that implements CachedSize helper methods 559 // for all the types listed in typePatterns 560 func GenerateSizeHelpers(packagePatterns []string, typePatterns []string) (map[string]*jen.File, error) { 561 loaded, err := packages.Load(&packages.Config{ 562 Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesSizes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedModule, 563 }, packagePatterns...) 564 565 if err != nil { 566 return nil, err 567 } 568 569 if common.PkgFailed(loaded) { 570 return nil, fmt.Errorf("failed to load packages") 571 } 572 573 sizegen := newSizegen(loaded[0].Module, loaded[0].TypesSizes) 574 575 scopes := make(map[string]*types.Scope) 576 for _, pkg := range loaded { 577 scopes[pkg.PkgPath] = pkg.Types.Scope() 578 } 579 580 for _, gen := range typePatterns { 581 pos := strings.LastIndexByte(gen, '.') 582 if pos < 0 { 583 return nil, fmt.Errorf("unexpected input type: %s", gen) 584 } 585 586 pkgname := gen[:pos] 587 typename := gen[pos+1:] 588 589 scope := scopes[pkgname] 590 if scope == nil { 591 return nil, fmt.Errorf("no scope found for type '%s'", gen) 592 } 593 594 if typename == "*" { 595 for _, name := range scope.Names() { 596 sizegen.generateKnownType(scope.Lookup(name).Type().(*types.Named)) 597 } 598 } else { 599 tt := scope.Lookup(typename) 600 if tt == nil { 601 return nil, fmt.Errorf("no type called '%s' found in '%s'", typename, pkgname) 602 } 603 604 sizegen.generateKnownType(tt.Type().(*types.Named)) 605 } 606 } 607 608 return sizegen.finalize(), nil 609 }