golang.org/x/sys@v0.20.1-0.20240517151509-673e0f94c16d/unix/internal/mkmerge/mkmerge.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // The mkmerge command parses generated source files and merges common
     6  // consts, funcs, and types into a common source file, per GOOS.
     7  //
     8  // Usage:
     9  //
    10  //	$ mkmerge -out MERGED FILE [FILE ...]
    11  //
    12  // Example:
    13  //
    14  //	# Remove all common consts, funcs, and types from zerrors_linux_*.go
    15  //	# and write the common code into zerrors_linux.go
    16  //	$ mkmerge -out zerrors_linux.go zerrors_linux_*.go
    17  //
    18  // mkmerge performs the merge in the following steps:
    19  //  1. Construct the set of common code that is identical in all
    20  //     architecture-specific files.
    21  //  2. Write this common code to the merged file.
    22  //  3. Remove the common code from all architecture-specific files.
    23  package main
    24  
    25  import (
    26  	"bufio"
    27  	"bytes"
    28  	"flag"
    29  	"fmt"
    30  	"go/ast"
    31  	"go/format"
    32  	"go/parser"
    33  	"go/token"
    34  	"io"
    35  	"log"
    36  	"os"
    37  	"path"
    38  	"path/filepath"
    39  	"regexp"
    40  	"strconv"
    41  	"strings"
    42  )
    43  
    44  const validGOOS = "aix|darwin|dragonfly|freebsd|linux|netbsd|openbsd|solaris"
    45  
    46  // getValidGOOS returns GOOS, true if filename ends with a valid "_GOOS.go"
    47  func getValidGOOS(filename string) (string, bool) {
    48  	matches := regexp.MustCompile(`_(` + validGOOS + `)\.go$`).FindStringSubmatch(filename)
    49  	if len(matches) != 2 {
    50  		return "", false
    51  	}
    52  	return matches[1], true
    53  }
    54  
    55  // codeElem represents an ast.Decl in a comparable way.
    56  type codeElem struct {
    57  	tok token.Token // e.g. token.CONST, token.TYPE, or token.FUNC
    58  	src string      // the declaration formatted as source code
    59  }
    60  
    61  // newCodeElem returns a codeElem based on tok and node, or an error is returned.
    62  func newCodeElem(tok token.Token, node ast.Node) (codeElem, error) {
    63  	var b strings.Builder
    64  	err := format.Node(&b, token.NewFileSet(), node)
    65  	if err != nil {
    66  		return codeElem{}, err
    67  	}
    68  	return codeElem{tok, b.String()}, nil
    69  }
    70  
    71  // codeSet is a set of codeElems
    72  type codeSet struct {
    73  	set map[codeElem]bool // true for all codeElems in the set
    74  }
    75  
    76  // newCodeSet returns a new codeSet
    77  func newCodeSet() *codeSet { return &codeSet{make(map[codeElem]bool)} }
    78  
    79  // add adds elem to c
    80  func (c *codeSet) add(elem codeElem) { c.set[elem] = true }
    81  
    82  // has returns true if elem is in c
    83  func (c *codeSet) has(elem codeElem) bool { return c.set[elem] }
    84  
    85  // isEmpty returns true if the set is empty
    86  func (c *codeSet) isEmpty() bool { return len(c.set) == 0 }
    87  
    88  // intersection returns a new set which is the intersection of c and a
    89  func (c *codeSet) intersection(a *codeSet) *codeSet {
    90  	res := newCodeSet()
    91  
    92  	for elem := range c.set {
    93  		if a.has(elem) {
    94  			res.add(elem)
    95  		}
    96  	}
    97  	return res
    98  }
    99  
   100  // keepCommon is a filterFn for filtering the merged file with common declarations.
   101  func (c *codeSet) keepCommon(elem codeElem) bool {
   102  	switch elem.tok {
   103  	case token.VAR:
   104  		// Remove all vars from the merged file
   105  		return false
   106  	case token.CONST, token.TYPE, token.FUNC, token.COMMENT:
   107  		// Remove arch-specific consts, types, functions, and file-level comments from the merged file
   108  		return c.has(elem)
   109  	case token.IMPORT:
   110  		// Keep imports, they are handled by filterImports
   111  		return true
   112  	}
   113  
   114  	log.Fatalf("keepCommon: invalid elem %v", elem)
   115  	return true
   116  }
   117  
   118  // keepArchSpecific is a filterFn for filtering the GOARC-specific files.
   119  func (c *codeSet) keepArchSpecific(elem codeElem) bool {
   120  	switch elem.tok {
   121  	case token.CONST, token.TYPE, token.FUNC:
   122  		// Remove common consts, types, or functions from the arch-specific file
   123  		return !c.has(elem)
   124  	}
   125  	return true
   126  }
   127  
   128  // srcFile represents a source file
   129  type srcFile struct {
   130  	name string
   131  	src  []byte
   132  }
   133  
   134  // filterFn is a helper for filter
   135  type filterFn func(codeElem) bool
   136  
   137  // filter parses and filters Go source code from src, removing top
   138  // level declarations using keep as predicate.
   139  // For src parameter, please see docs for parser.ParseFile.
   140  func filter(src interface{}, keep filterFn) ([]byte, error) {
   141  	// Parse the src into an ast
   142  	fset := token.NewFileSet()
   143  	f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	cmap := ast.NewCommentMap(fset, f, f.Comments)
   148  
   149  	// Group const/type specs on adjacent lines
   150  	var groups specGroups = make(map[string]int)
   151  	var groupID int
   152  
   153  	decls := f.Decls
   154  	f.Decls = f.Decls[:0]
   155  	for _, decl := range decls {
   156  		switch decl := decl.(type) {
   157  		case *ast.GenDecl:
   158  			// Filter imports, consts, types, vars
   159  			specs := decl.Specs
   160  			decl.Specs = decl.Specs[:0]
   161  			for i, spec := range specs {
   162  				elem, err := newCodeElem(decl.Tok, spec)
   163  				if err != nil {
   164  					return nil, err
   165  				}
   166  
   167  				// Create new group if there are empty lines between this and the previous spec
   168  				if i > 0 && fset.Position(specs[i-1].End()).Line < fset.Position(spec.Pos()).Line-1 {
   169  					groupID++
   170  				}
   171  
   172  				// Check if we should keep this spec
   173  				if keep(elem) {
   174  					decl.Specs = append(decl.Specs, spec)
   175  					groups.add(elem.src, groupID)
   176  				}
   177  			}
   178  			// Check if we should keep this decl
   179  			if len(decl.Specs) > 0 {
   180  				f.Decls = append(f.Decls, decl)
   181  			}
   182  		case *ast.FuncDecl:
   183  			// Filter funcs
   184  			elem, err := newCodeElem(token.FUNC, decl)
   185  			if err != nil {
   186  				return nil, err
   187  			}
   188  			if keep(elem) {
   189  				f.Decls = append(f.Decls, decl)
   190  			}
   191  		}
   192  	}
   193  
   194  	// Filter file level comments
   195  	if cmap[f] != nil {
   196  		commentGroups := cmap[f]
   197  		cmap[f] = cmap[f][:0]
   198  		for _, cGrp := range commentGroups {
   199  			if keep(codeElem{token.COMMENT, cGrp.Text()}) {
   200  				cmap[f] = append(cmap[f], cGrp)
   201  			}
   202  		}
   203  	}
   204  	f.Comments = cmap.Filter(f).Comments()
   205  
   206  	// Generate code for the filtered ast
   207  	var buf bytes.Buffer
   208  	if err = format.Node(&buf, fset, f); err != nil {
   209  		return nil, err
   210  	}
   211  
   212  	groupedSrc, err := groups.filterEmptyLines(&buf)
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  
   217  	return filterImports(groupedSrc)
   218  }
   219  
   220  // getCommonSet returns the set of consts, types, and funcs that are present in every file.
   221  func getCommonSet(files []srcFile) (*codeSet, error) {
   222  	if len(files) == 0 {
   223  		return nil, fmt.Errorf("no files provided")
   224  	}
   225  	// Use the first architecture file as the baseline
   226  	baseSet, err := getCodeSet(files[0].src)
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  
   231  	// Compare baseline set with other architecture files: discard any element,
   232  	// that doesn't exist in other architecture files.
   233  	for _, f := range files[1:] {
   234  		set, err := getCodeSet(f.src)
   235  		if err != nil {
   236  			return nil, err
   237  		}
   238  
   239  		baseSet = baseSet.intersection(set)
   240  	}
   241  	return baseSet, nil
   242  }
   243  
   244  // getCodeSet returns the set of all top-level consts, types, and funcs from src.
   245  // src must be string, []byte, or io.Reader (see go/parser.ParseFile docs)
   246  func getCodeSet(src interface{}) (*codeSet, error) {
   247  	set := newCodeSet()
   248  
   249  	fset := token.NewFileSet()
   250  	f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  
   255  	for _, decl := range f.Decls {
   256  		switch decl := decl.(type) {
   257  		case *ast.GenDecl:
   258  			// Add const, and type declarations
   259  			if !(decl.Tok == token.CONST || decl.Tok == token.TYPE) {
   260  				break
   261  			}
   262  
   263  			for _, spec := range decl.Specs {
   264  				elem, err := newCodeElem(decl.Tok, spec)
   265  				if err != nil {
   266  					return nil, err
   267  				}
   268  
   269  				set.add(elem)
   270  			}
   271  		case *ast.FuncDecl:
   272  			// Add func declarations
   273  			elem, err := newCodeElem(token.FUNC, decl)
   274  			if err != nil {
   275  				return nil, err
   276  			}
   277  
   278  			set.add(elem)
   279  		}
   280  	}
   281  
   282  	// Add file level comments
   283  	cmap := ast.NewCommentMap(fset, f, f.Comments)
   284  	for _, cGrp := range cmap[f] {
   285  		text := cGrp.Text()
   286  		if text == "" && len(cGrp.List) == 1 && strings.HasPrefix(cGrp.List[0].Text, "//go:build ") {
   287  			// ast.CommentGroup.Text doesn't include comment directives like "//go:build"
   288  			// in the text. So if a comment group has empty text and a single //go:build
   289  			// constraint line, make a custom codeElem. This is enough for mkmerge needs.
   290  			set.add(codeElem{token.COMMENT, cGrp.List[0].Text[len("//"):] + "\n"})
   291  			continue
   292  		}
   293  		set.add(codeElem{token.COMMENT, text})
   294  	}
   295  
   296  	return set, nil
   297  }
   298  
   299  // importName returns the identifier (PackageName) for an imported package
   300  func importName(iSpec *ast.ImportSpec) (string, error) {
   301  	if iSpec.Name == nil {
   302  		name, err := strconv.Unquote(iSpec.Path.Value)
   303  		if err != nil {
   304  			return "", err
   305  		}
   306  		return path.Base(name), nil
   307  	}
   308  	return iSpec.Name.Name, nil
   309  }
   310  
   311  // specGroups tracks grouped const/type specs with a map of line: groupID pairs
   312  type specGroups map[string]int
   313  
   314  // add spec source to group
   315  func (s specGroups) add(src string, groupID int) error {
   316  	srcBytes, err := format.Source(bytes.TrimSpace([]byte(src)))
   317  	if err != nil {
   318  		return err
   319  	}
   320  	s[string(srcBytes)] = groupID
   321  	return nil
   322  }
   323  
   324  // filterEmptyLines removes empty lines within groups of const/type specs.
   325  // Returns the filtered source.
   326  func (s specGroups) filterEmptyLines(src io.Reader) ([]byte, error) {
   327  	scanner := bufio.NewScanner(src)
   328  	var out bytes.Buffer
   329  
   330  	var emptyLines bytes.Buffer
   331  	prevGroupID := -1 // Initialize to invalid group
   332  	for scanner.Scan() {
   333  		line := bytes.TrimSpace(scanner.Bytes())
   334  
   335  		if len(line) == 0 {
   336  			fmt.Fprintf(&emptyLines, "%s\n", scanner.Bytes())
   337  			continue
   338  		}
   339  
   340  		// Discard emptyLines if previous non-empty line belonged to the same
   341  		// group as this line
   342  		if src, err := format.Source(line); err == nil {
   343  			groupID, ok := s[string(src)]
   344  			if ok && groupID == prevGroupID {
   345  				emptyLines.Reset()
   346  			}
   347  			prevGroupID = groupID
   348  		}
   349  
   350  		emptyLines.WriteTo(&out)
   351  		fmt.Fprintf(&out, "%s\n", scanner.Bytes())
   352  	}
   353  	if err := scanner.Err(); err != nil {
   354  		return nil, err
   355  	}
   356  	return out.Bytes(), nil
   357  }
   358  
   359  // filterImports removes unused imports from fileSrc, and returns a formatted src.
   360  func filterImports(fileSrc []byte) ([]byte, error) {
   361  	fset := token.NewFileSet()
   362  	file, err := parser.ParseFile(fset, "", fileSrc, parser.ParseComments)
   363  	if err != nil {
   364  		return nil, err
   365  	}
   366  	cmap := ast.NewCommentMap(fset, file, file.Comments)
   367  
   368  	// create set of references to imported identifiers
   369  	keepImport := make(map[string]bool)
   370  	for _, u := range file.Unresolved {
   371  		keepImport[u.Name] = true
   372  	}
   373  
   374  	// filter import declarations
   375  	decls := file.Decls
   376  	file.Decls = file.Decls[:0]
   377  	for _, decl := range decls {
   378  		importDecl, ok := decl.(*ast.GenDecl)
   379  
   380  		// Keep non-import declarations
   381  		if !ok || importDecl.Tok != token.IMPORT {
   382  			file.Decls = append(file.Decls, decl)
   383  			continue
   384  		}
   385  
   386  		// Filter the import specs
   387  		specs := importDecl.Specs
   388  		importDecl.Specs = importDecl.Specs[:0]
   389  		for _, spec := range specs {
   390  			iSpec := spec.(*ast.ImportSpec)
   391  			name, err := importName(iSpec)
   392  			if err != nil {
   393  				return nil, err
   394  			}
   395  
   396  			if keepImport[name] {
   397  				importDecl.Specs = append(importDecl.Specs, iSpec)
   398  			}
   399  		}
   400  		if len(importDecl.Specs) > 0 {
   401  			file.Decls = append(file.Decls, importDecl)
   402  		}
   403  	}
   404  
   405  	// filter file.Imports
   406  	imports := file.Imports
   407  	file.Imports = file.Imports[:0]
   408  	for _, spec := range imports {
   409  		name, err := importName(spec)
   410  		if err != nil {
   411  			return nil, err
   412  		}
   413  
   414  		if keepImport[name] {
   415  			file.Imports = append(file.Imports, spec)
   416  		}
   417  	}
   418  	file.Comments = cmap.Filter(file).Comments()
   419  
   420  	var buf bytes.Buffer
   421  	err = format.Node(&buf, fset, file)
   422  	if err != nil {
   423  		return nil, err
   424  	}
   425  
   426  	return buf.Bytes(), nil
   427  }
   428  
   429  // merge extracts duplicate code from archFiles and merges it to mergeFile.
   430  // 1. Construct commonSet: the set of code that is idential in all archFiles.
   431  // 2. Write the code in commonSet to mergedFile.
   432  // 3. Remove the commonSet code from all archFiles.
   433  func merge(mergedFile string, archFiles ...string) error {
   434  	// extract and validate the GOOS part of the merged filename
   435  	goos, ok := getValidGOOS(mergedFile)
   436  	if !ok {
   437  		return fmt.Errorf("invalid GOOS in merged file name %s", mergedFile)
   438  	}
   439  
   440  	// Read architecture files
   441  	var inSrc []srcFile
   442  	for _, file := range archFiles {
   443  		src, err := os.ReadFile(file)
   444  		if err != nil {
   445  			return fmt.Errorf("cannot read archfile %s: %w", file, err)
   446  		}
   447  
   448  		inSrc = append(inSrc, srcFile{file, src})
   449  	}
   450  
   451  	// 1. Construct the set of top-level declarations common for all files
   452  	commonSet, err := getCommonSet(inSrc)
   453  	if err != nil {
   454  		return err
   455  	}
   456  	if commonSet.isEmpty() {
   457  		// No common code => do not modify any files
   458  		return nil
   459  	}
   460  
   461  	// 2. Write the merged file
   462  	mergedSrc, err := filter(inSrc[0].src, commonSet.keepCommon)
   463  	if err != nil {
   464  		return err
   465  	}
   466  
   467  	f, err := os.Create(mergedFile)
   468  	if err != nil {
   469  		return err
   470  	}
   471  
   472  	buf := bufio.NewWriter(f)
   473  	fmt.Fprintln(buf, "// Code generated by mkmerge; DO NOT EDIT.")
   474  	fmt.Fprintln(buf)
   475  	fmt.Fprintf(buf, "//go:build %s\n", goos)
   476  	fmt.Fprintln(buf)
   477  	buf.Write(mergedSrc)
   478  
   479  	err = buf.Flush()
   480  	if err != nil {
   481  		return err
   482  	}
   483  	err = f.Close()
   484  	if err != nil {
   485  		return err
   486  	}
   487  
   488  	// 3. Remove duplicate declarations from the architecture files
   489  	for _, inFile := range inSrc {
   490  		src, err := filter(inFile.src, commonSet.keepArchSpecific)
   491  		if err != nil {
   492  			return err
   493  		}
   494  		err = os.WriteFile(inFile.name, src, 0644)
   495  		if err != nil {
   496  			return err
   497  		}
   498  	}
   499  	return nil
   500  }
   501  
   502  func main() {
   503  	var mergedFile string
   504  	flag.StringVar(&mergedFile, "out", "", "Write merged code to `FILE`")
   505  	flag.Parse()
   506  
   507  	// Expand wildcards
   508  	var filenames []string
   509  	for _, arg := range flag.Args() {
   510  		matches, err := filepath.Glob(arg)
   511  		if err != nil {
   512  			fmt.Fprintf(os.Stderr, "Invalid command line argument %q: %v\n", arg, err)
   513  			os.Exit(1)
   514  		}
   515  		filenames = append(filenames, matches...)
   516  	}
   517  
   518  	if len(filenames) < 2 {
   519  		// No need to merge
   520  		return
   521  	}
   522  
   523  	err := merge(mergedFile, filenames...)
   524  	if err != nil {
   525  		fmt.Fprintf(os.Stderr, "Merge failed with error: %v\n", err)
   526  		os.Exit(1)
   527  	}
   528  }