gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/go_fieldenum/main.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Binary fieldenum emits field bitmasks for all structs in a package marked
    16  // "+fieldenum".
    17  package main
    18  
    19  import (
    20  	"flag"
    21  	"fmt"
    22  	"go/ast"
    23  	"go/parser"
    24  	"go/token"
    25  	"log"
    26  	"os"
    27  	"strings"
    28  )
    29  
    30  var (
    31  	outputPkg      = flag.String("pkg", "", "output package")
    32  	outputFilename = flag.String("out", "-", "output filename")
    33  )
    34  
    35  func main() {
    36  	// Parse command line arguments.
    37  	flag.Parse()
    38  	if len(*outputPkg) == 0 {
    39  		log.Fatalf("-pkg must be provided")
    40  	}
    41  	if len(flag.Args()) == 0 {
    42  		log.Fatalf("Input files must be provided")
    43  	}
    44  
    45  	// Parse input files.
    46  	inputFiles := make([]*ast.File, 0, len(flag.Args()))
    47  	fset := token.NewFileSet()
    48  	for _, filename := range flag.Args() {
    49  		f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
    50  		if err != nil {
    51  			log.Fatalf("Failed to parse input file %q: %v", filename, err)
    52  		}
    53  		inputFiles = append(inputFiles, f)
    54  	}
    55  
    56  	// Determine which types are marked "+fieldenum" and will consequently have
    57  	// code generated.
    58  	var typeNames []string
    59  	fieldEnumTypes := make(map[string]fieldEnumTypeInfo)
    60  	for _, f := range inputFiles {
    61  		for _, decl := range f.Decls {
    62  			d, ok := decl.(*ast.GenDecl)
    63  			if !ok || d.Tok != token.TYPE || d.Doc == nil || len(d.Specs) == 0 {
    64  				continue
    65  			}
    66  			for _, l := range d.Doc.List {
    67  				const fieldenumPrefixWithSpace = "// +fieldenum "
    68  				if l.Text == "// +fieldenum" || strings.HasPrefix(l.Text, fieldenumPrefixWithSpace) {
    69  					spec := d.Specs[0].(*ast.TypeSpec)
    70  					name := spec.Name.Name
    71  					prefix := name
    72  					if len(l.Text) > len(fieldenumPrefixWithSpace) {
    73  						prefix = strings.TrimSpace(l.Text[len(fieldenumPrefixWithSpace):])
    74  					}
    75  					st, ok := spec.Type.(*ast.StructType)
    76  					if !ok {
    77  						log.Fatalf("Type %s is marked +fieldenum, but is not a struct", name)
    78  					}
    79  					typeNames = append(typeNames, name)
    80  					fieldEnumTypes[name] = fieldEnumTypeInfo{
    81  						prefix:     prefix,
    82  						structType: st,
    83  					}
    84  					break
    85  				}
    86  			}
    87  		}
    88  	}
    89  
    90  	// Collect information for each type for which code is being generated.
    91  	structInfos := make([]structInfo, 0, len(typeNames))
    92  	needAtomic := false
    93  	for _, typeName := range typeNames {
    94  		typeInfo := fieldEnumTypes[typeName]
    95  		var si structInfo
    96  		si.name = typeName
    97  		si.prefix = typeInfo.prefix
    98  		for _, field := range typeInfo.structType.Fields.List {
    99  			name := structFieldName(field)
   100  			// If the field's type is a type that is also marked +fieldenum,
   101  			// include a FieldSet for that type in this one's. The field must
   102  			// be a struct by value, since if it's a pointer then that struct
   103  			// might also point to or include this one (which would make
   104  			// FieldSet inclusion circular). It must also be a type defined in
   105  			// this package, since otherwise we don't know whether it's marked
   106  			// +fieldenum. Thus, field.Type must be an identifier (rather than
   107  			// an ast.StarExpr or SelectorExpr).
   108  			if tident, ok := field.Type.(*ast.Ident); ok {
   109  				if fieldTypeInfo, ok := fieldEnumTypes[tident.Name]; ok {
   110  					fsf := fieldSetField{
   111  						fieldName:  name,
   112  						typePrefix: fieldTypeInfo.prefix,
   113  					}
   114  					si.reprByFieldSet = append(si.reprByFieldSet, fsf)
   115  					si.allFields = append(si.allFields, fsf)
   116  					continue
   117  				}
   118  			}
   119  			si.reprByBit = append(si.reprByBit, name)
   120  			si.allFields = append(si.allFields, fieldSetField{
   121  				fieldName: name,
   122  			})
   123  			// atomicbitops import will be needed for FieldSet.Load().
   124  			needAtomic = true
   125  		}
   126  		structInfos = append(structInfos, si)
   127  	}
   128  
   129  	// Build the output file.
   130  	var b strings.Builder
   131  	fmt.Fprintf(&b, "// Generated by go_fieldenum.\n\n")
   132  	fmt.Fprintf(&b, "package %s\n\n", *outputPkg)
   133  	if needAtomic {
   134  		fmt.Fprintf(&b, `import "gvisor.dev/gvisor/pkg/atomicbitops"`)
   135  		fmt.Fprintf(&b, "\n\n")
   136  	}
   137  	for _, si := range structInfos {
   138  		si.writeTo(&b)
   139  	}
   140  
   141  	if *outputFilename == "-" {
   142  		// Write output to stdout.
   143  		fmt.Printf("%s", b.String())
   144  	} else {
   145  		// Write output to file.
   146  		f, err := os.OpenFile(*outputFilename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
   147  		if err != nil {
   148  			log.Fatalf("Failed to open output file %q: %v", *outputFilename, err)
   149  		}
   150  		if _, err := f.WriteString(b.String()); err != nil {
   151  			log.Fatalf("Failed to write output file %q: %v", *outputFilename, err)
   152  		}
   153  		f.Close()
   154  	}
   155  }
   156  
   157  type fieldEnumTypeInfo struct {
   158  	prefix     string
   159  	structType *ast.StructType
   160  }
   161  
   162  // structInfo contains information about the code generated for a given struct.
   163  type structInfo struct {
   164  	// name is the name of the represented struct.
   165  	name string
   166  
   167  	// prefix is the prefix X applied to the name of each generated type and
   168  	// constant, referred to as X in the comments below for convenience.
   169  	prefix string
   170  
   171  	// reprByBit contains the names of fields in X that should be represented
   172  	// by a bit in the bit mask XFieldSet.fields, and by a bool in XFields.
   173  	reprByBit []string
   174  
   175  	// reprByFieldSet contains fields in X whose type is a named struct (e.g.
   176  	// Y) that has a corresponding FieldSet type YFieldSet, and which should
   177  	// therefore be represented by including a value of type YFieldSet in
   178  	// XFieldSet, and a value of type YFields in XFields.
   179  	reprByFieldSet []fieldSetField
   180  
   181  	// allFields contains all fields in X in order of declaration. Fields in
   182  	// reprByBit have fieldSetField.typePrefix == "".
   183  	allFields []fieldSetField
   184  }
   185  
   186  type fieldSetField struct {
   187  	fieldName  string
   188  	typePrefix string
   189  }
   190  
   191  func structFieldName(f *ast.Field) string {
   192  	if len(f.Names) != 0 {
   193  		return f.Names[0].Name
   194  	}
   195  	// For embedded struct fields, the field name is the unqualified type name.
   196  	texpr := f.Type
   197  	for {
   198  		switch t := texpr.(type) {
   199  		case *ast.StarExpr:
   200  			texpr = t.X
   201  		case *ast.SelectorExpr:
   202  			texpr = t.Sel
   203  		case *ast.Ident:
   204  			return t.Name
   205  		default:
   206  			panic(fmt.Sprintf("unexpected %T", texpr))
   207  		}
   208  	}
   209  }
   210  
   211  func (si *structInfo) writeTo(b *strings.Builder) {
   212  	fmt.Fprintf(b, "// A %sField represents a field in %s.\n", si.prefix, si.name)
   213  	fmt.Fprintf(b, "type %sField uint\n\n", si.prefix)
   214  	if len(si.reprByBit) != 0 {
   215  		fmt.Fprintf(b, "// %sFieldX represents %s field X.\n", si.prefix, si.name)
   216  		fmt.Fprintf(b, "const (\n")
   217  		fmt.Fprintf(b, "\t%sField%s %sField = iota\n", si.prefix, si.reprByBit[0], si.prefix)
   218  		for _, fieldName := range si.reprByBit[1:] {
   219  			fmt.Fprintf(b, "\t%sField%s\n", si.prefix, fieldName)
   220  		}
   221  		fmt.Fprintf(b, ")\n\n")
   222  	}
   223  
   224  	fmt.Fprintf(b, "// %sFields represents a set of fields in %s in a literal-friendly form.\n", si.prefix, si.name)
   225  	fmt.Fprintf(b, "// The zero value of %sFields represents an empty set.\n", si.prefix)
   226  	fmt.Fprintf(b, "type %sFields struct {\n", si.prefix)
   227  	for _, fieldSetField := range si.allFields {
   228  		if fieldSetField.typePrefix == "" {
   229  			fmt.Fprintf(b, "\t%s bool\n", fieldSetField.fieldName)
   230  		} else {
   231  			fmt.Fprintf(b, "\t%s %sFields\n", fieldSetField.fieldName, fieldSetField.typePrefix)
   232  		}
   233  	}
   234  	fmt.Fprintf(b, "}\n\n")
   235  
   236  	fmt.Fprintf(b, "// %sFieldSet represents a set of fields in %s in a compact form.\n", si.prefix, si.name)
   237  	fmt.Fprintf(b, "// The zero value of %sFieldSet represents an empty set.\n", si.prefix)
   238  	fmt.Fprintf(b, "type %sFieldSet struct {\n", si.prefix)
   239  	numBitmaskUint32s := (len(si.reprByBit) + 31) / 32
   240  	for _, fieldSetField := range si.reprByFieldSet {
   241  		fmt.Fprintf(b, "\t%s %sFieldSet\n", fieldSetField.fieldName, fieldSetField.typePrefix)
   242  	}
   243  	if len(si.reprByBit) != 0 {
   244  		fmt.Fprintf(b, "\tfields [%d]atomicbitops.Uint32\n", numBitmaskUint32s)
   245  	}
   246  	fmt.Fprintf(b, "}\n\n")
   247  
   248  	if len(si.reprByBit) != 0 {
   249  		fmt.Fprintf(b, "// Contains returns true if f is present in the %sFieldSet.\n", si.prefix)
   250  		fmt.Fprintf(b, "func (fs *%sFieldSet) Contains(f %sField) bool {\n", si.prefix, si.prefix)
   251  		if numBitmaskUint32s == 1 {
   252  			fmt.Fprintf(b, "\treturn fs.fields[0].RacyLoad() & (uint32(1) << uint(f)) != 0\n")
   253  		} else {
   254  			fmt.Fprintf(b, "\treturn fs.fields[f/32].RacyLoad() & (uint32(1) << (f%%32)) != 0\n")
   255  		}
   256  		fmt.Fprintf(b, "}\n\n")
   257  
   258  		fmt.Fprintf(b, "// Add adds f to the %sFieldSet.\n", si.prefix)
   259  		fmt.Fprintf(b, "func (fs *%sFieldSet) Add(f %sField) {\n", si.prefix, si.prefix)
   260  		if numBitmaskUint32s == 1 {
   261  			fmt.Fprintf(b, "\tfs.fields[0] = atomicbitops.FromUint32(fs.fields[0].RacyLoad() | (uint32(1) << uint(f)))\n")
   262  		} else {
   263  			fmt.Fprintf(b, "\tfs.fields[f/32] = atomicbitops.FromUint32(fs.fields[f/32].RacyLoad() | (uint32(1) << (f%%32))\n")
   264  		}
   265  		fmt.Fprintf(b, "}\n\n")
   266  
   267  		fmt.Fprintf(b, "// Remove removes f from the %sFieldSet.\n", si.prefix)
   268  		fmt.Fprintf(b, "func (fs *%sFieldSet) Remove(f %sField) {\n", si.prefix, si.prefix)
   269  		if numBitmaskUint32s == 1 {
   270  			fmt.Fprintf(b, "\tfs.fields[0] = atomicbitops.FromUint32(fs.fields[0].RacyLoad() &^ (uint32(1) << uint(f)))\n")
   271  		} else {
   272  			fmt.Fprintf(b, "\tfs.fields[f/32] = atomicbitops.FromUint32(fs.fields[f/32].RacyLoad() &^ (uint32(1) << uint(f%%32)))\n")
   273  		}
   274  		fmt.Fprintf(b, "}\n\n")
   275  	}
   276  
   277  	fmt.Fprintf(b, "// Load returns a copy of the %sFieldSet.\n", si.prefix)
   278  	fmt.Fprintf(b, "// Load is safe to call concurrently with AddFieldsLoadable, but not Add or Remove.\n")
   279  	fmt.Fprintf(b, "func (fs *%sFieldSet) Load() (copied %sFieldSet) {\n", si.prefix, si.prefix)
   280  	for _, fieldSetField := range si.reprByFieldSet {
   281  		fmt.Fprintf(b, "\tcopied.%s = fs.%s.Load()\n", fieldSetField.fieldName, fieldSetField.fieldName)
   282  	}
   283  	for i := 0; i < numBitmaskUint32s; i++ {
   284  		fmt.Fprintf(b, "\tcopied.fields[%d] = atomicbitops.FromUint32(fs.fields[%d].Load())\n", i, i)
   285  	}
   286  	fmt.Fprintf(b, "\treturn\n")
   287  	fmt.Fprintf(b, "}\n\n")
   288  
   289  	fmt.Fprintf(b, "// AddFieldsLoadable adds the given fields to the %sFieldSet.\n", si.prefix)
   290  	fmt.Fprintf(b, "// AddFieldsLoadable is safe to call concurrently with Load, but not other methods (including other calls to AddFieldsLoadable).\n")
   291  	fmt.Fprintf(b, "func (fs *%sFieldSet) AddFieldsLoadable(fields %sFields) {\n", si.prefix, si.prefix)
   292  	for _, fieldSetField := range si.reprByFieldSet {
   293  		fmt.Fprintf(b, "\tfs.%s.AddFieldsLoadable(fields.%s)\n", fieldSetField.fieldName, fieldSetField.fieldName)
   294  	}
   295  	for _, fieldName := range si.reprByBit {
   296  		fieldConstName := fmt.Sprintf("%sField%s", si.prefix, fieldName)
   297  		fmt.Fprintf(b, "\tif fields.%s {\n", fieldName)
   298  		if numBitmaskUint32s == 1 {
   299  			fmt.Fprintf(b, "\t\tfs.fields[0].Store(fs.fields[0].RacyLoad() | (uint32(1) << uint(%s)))\n", fieldConstName)
   300  		} else {
   301  			fmt.Fprintf(b, "\t\tword, bit := %s/32, %s%%32\n", fieldConstName, fieldConstName)
   302  			fmt.Fprintf(b, "\t\tfs.fields[word].Store(fs.fields[word].RacyLoad() | (uint32(1) << bit))\n")
   303  		}
   304  		fmt.Fprintf(b, "\t}\n")
   305  	}
   306  	fmt.Fprintf(b, "}\n\n")
   307  }