github.com/neonyo/sys@v0.0.0-20230720094341-b1ee14be3ce8/unix/internal/mkmerge/mkmerge.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // The mkmerge command parses generated source files and merges common
     6  // consts, funcs, and types into a common source file, per GOOS.
     7  //
     8  // Usage:
     9  //
    10  //	$ mkmerge -out MERGED FILE [FILE ...]
    11  //
    12  // Example:
    13  //
    14  //	# Remove all common consts, funcs, and types from zerrors_linux_*.go
    15  //	# and write the common code into zerrors_linux.go
    16  //	$ mkmerge -out zerrors_linux.go zerrors_linux_*.go
    17  //
    18  // mkmerge performs the merge in the following steps:
    19  //  1. Construct the set of common code that is identical in all
    20  //     architecture-specific files.
    21  //  2. Write this common code to the merged file.
    22  //  3. Remove the common code from all architecture-specific files.
    23  package main
    24  
    25  import (
    26  	"bufio"
    27  	"bytes"
    28  	"flag"
    29  	"fmt"
    30  	"go/ast"
    31  	"go/format"
    32  	"go/parser"
    33  	"go/token"
    34  	"io"
    35  	"io/ioutil"
    36  	"log"
    37  	"os"
    38  	"path"
    39  	"path/filepath"
    40  	"regexp"
    41  	"strconv"
    42  	"strings"
    43  )
    44  
    45  const validGOOS = "aix|darwin|dragonfly|freebsd|linux|netbsd|openbsd|solaris"
    46  
    47  // getValidGOOS returns GOOS, true if filename ends with a valid "_GOOS.go"
    48  func getValidGOOS(filename string) (string, bool) {
    49  	matches := regexp.MustCompile(`_(` + validGOOS + `)\.go$`).FindStringSubmatch(filename)
    50  	if len(matches) != 2 {
    51  		return "", false
    52  	}
    53  	return matches[1], true
    54  }
    55  
    56  // codeElem represents an ast.Decl in a comparable way.
    57  type codeElem struct {
    58  	tok token.Token // e.g. token.CONST, token.TYPE, or token.FUNC
    59  	src string      // the declaration formatted as source code
    60  }
    61  
    62  // newCodeElem returns a codeElem based on tok and node, or an error is returned.
    63  func newCodeElem(tok token.Token, node ast.Node) (codeElem, error) {
    64  	var b strings.Builder
    65  	err := format.Node(&b, token.NewFileSet(), node)
    66  	if err != nil {
    67  		return codeElem{}, err
    68  	}
    69  	return codeElem{tok, b.String()}, nil
    70  }
    71  
    72  // codeSet is a set of codeElems
    73  type codeSet struct {
    74  	set map[codeElem]bool // true for all codeElems in the set
    75  }
    76  
    77  // newCodeSet returns a new codeSet
    78  func newCodeSet() *codeSet { return &codeSet{make(map[codeElem]bool)} }
    79  
    80  // add adds elem to c
    81  func (c *codeSet) add(elem codeElem) { c.set[elem] = true }
    82  
    83  // has returns true if elem is in c
    84  func (c *codeSet) has(elem codeElem) bool { return c.set[elem] }
    85  
    86  // isEmpty returns true if the set is empty
    87  func (c *codeSet) isEmpty() bool { return len(c.set) == 0 }
    88  
    89  // intersection returns a new set which is the intersection of c and a
    90  func (c *codeSet) intersection(a *codeSet) *codeSet {
    91  	res := newCodeSet()
    92  
    93  	for elem := range c.set {
    94  		if a.has(elem) {
    95  			res.add(elem)
    96  		}
    97  	}
    98  	return res
    99  }
   100  
   101  // keepCommon is a filterFn for filtering the merged file with common declarations.
   102  func (c *codeSet) keepCommon(elem codeElem) bool {
   103  	switch elem.tok {
   104  	case token.VAR:
   105  		// Remove all vars from the merged file
   106  		return false
   107  	case token.CONST, token.TYPE, token.FUNC, token.COMMENT:
   108  		// Remove arch-specific consts, types, functions, and file-level comments from the merged file
   109  		return c.has(elem)
   110  	case token.IMPORT:
   111  		// Keep imports, they are handled by filterImports
   112  		return true
   113  	}
   114  
   115  	log.Fatalf("keepCommon: invalid elem %v", elem)
   116  	return true
   117  }
   118  
   119  // keepArchSpecific is a filterFn for filtering the GOARC-specific files.
   120  func (c *codeSet) keepArchSpecific(elem codeElem) bool {
   121  	switch elem.tok {
   122  	case token.CONST, token.TYPE, token.FUNC:
   123  		// Remove common consts, types, or functions from the arch-specific file
   124  		return !c.has(elem)
   125  	}
   126  	return true
   127  }
   128  
// srcFile represents a source file: a file name paired with its contents.
type srcFile struct {
	name string // file name
	src  []byte // contents of the file
}
   134  
// filterFn is the predicate type accepted by filter. It is called once per
// candidate element (spec, func, or file-level comment) and reports whether
// that element should be kept.
type filterFn func(codeElem) bool
   137  
// filter parses and filters Go source code from src, removing top
// level declarations using keep as predicate.
// For src parameter, please see docs for parser.ParseFile.
//
// keep is consulted for every spec of every general declaration (import,
// const, type, var), for every func declaration, and for every comment
// group associated with the file node itself. Elements for which keep
// returns false are dropped. The surviving AST is formatted, blank lines
// inside spec groups are removed, and unused imports are pruned by
// filterImports, whose result is returned.
func filter(src interface{}, keep filterFn) ([]byte, error) {
	// Parse the src into an ast
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
	if err != nil {
		return nil, err
	}
	// Built before any filtering so comments can later be restricted to
	// the nodes that survive.
	cmap := ast.NewCommentMap(fset, f, f.Comments)

	// Group const/type specs on adjacent lines
	// groupID is shared across all declarations in the file and is only
	// incremented when a blank-line gap is seen between two specs of the
	// same declaration.
	var groups specGroups = make(map[string]int)
	var groupID int

	// Filter in place: f.Decls reuses the backing array of decls. This is
	// safe because at most one element is appended per element read.
	decls := f.Decls
	f.Decls = f.Decls[:0]
	for _, decl := range decls {
		switch decl := decl.(type) {
		case *ast.GenDecl:
			// Filter imports, consts, types, vars
			// Same in-place filtering as above, applied to the specs.
			specs := decl.Specs
			decl.Specs = decl.Specs[:0]
			for i, spec := range specs {
				elem, err := newCodeElem(decl.Tok, spec)
				if err != nil {
					return nil, err
				}

				// Create new group if there are empty lines between this and the previous spec
				if i > 0 && fset.Position(specs[i-1].End()).Line < fset.Position(spec.Pos()).Line-1 {
					groupID++
				}

				// Check if we should keep this spec
				if keep(elem) {
					decl.Specs = append(decl.Specs, spec)
					// NOTE(review): the error returned by add is discarded
					// here, so a spec whose source add cannot format is
					// simply not tracked in groups.
					groups.add(elem.src, groupID)
				}
			}
			// Check if we should keep this decl
			// A declaration left with no specs is dropped entirely.
			if len(decl.Specs) > 0 {
				f.Decls = append(f.Decls, decl)
			}
		case *ast.FuncDecl:
			// Filter funcs
			elem, err := newCodeElem(token.FUNC, decl)
			if err != nil {
				return nil, err
			}
			if keep(elem) {
				f.Decls = append(f.Decls, decl)
			}
		}
	}

	// Filter file level comments
	// These are the comment groups the comment map associates with the
	// file node; each is offered to keep as a token.COMMENT element.
	if cmap[f] != nil {
		commentGroups := cmap[f]
		cmap[f] = cmap[f][:0]
		for _, cGrp := range commentGroups {
			if keep(codeElem{token.COMMENT, cGrp.Text()}) {
				cmap[f] = append(cmap[f], cGrp)
			}
		}
	}
	// Keep only comments belonging to nodes still present in f.
	f.Comments = cmap.Filter(f).Comments()

	// Generate code for the filtered ast
	var buf bytes.Buffer
	if err = format.Node(&buf, fset, f); err != nil {
		return nil, err
	}

	// Drop blank lines between kept specs that belong to the same group.
	groupedSrc, err := groups.filterEmptyLines(&buf)
	if err != nil {
		return nil, err
	}

	return filterImports(groupedSrc)
}
   220  
   221  // getCommonSet returns the set of consts, types, and funcs that are present in every file.
   222  func getCommonSet(files []srcFile) (*codeSet, error) {
   223  	if len(files) == 0 {
   224  		return nil, fmt.Errorf("no files provided")
   225  	}
   226  	// Use the first architecture file as the baseline
   227  	baseSet, err := getCodeSet(files[0].src)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  
   232  	// Compare baseline set with other architecture files: discard any element,
   233  	// that doesn't exist in other architecture files.
   234  	for _, f := range files[1:] {
   235  		set, err := getCodeSet(f.src)
   236  		if err != nil {
   237  			return nil, err
   238  		}
   239  
   240  		baseSet = baseSet.intersection(set)
   241  	}
   242  	return baseSet, nil
   243  }
   244  
   245  // getCodeSet returns the set of all top-level consts, types, and funcs from src.
   246  // src must be string, []byte, or io.Reader (see go/parser.ParseFile docs)
   247  func getCodeSet(src interface{}) (*codeSet, error) {
   248  	set := newCodeSet()
   249  
   250  	fset := token.NewFileSet()
   251  	f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
   252  	if err != nil {
   253  		return nil, err
   254  	}
   255  
   256  	for _, decl := range f.Decls {
   257  		switch decl := decl.(type) {
   258  		case *ast.GenDecl:
   259  			// Add const, and type declarations
   260  			if !(decl.Tok == token.CONST || decl.Tok == token.TYPE) {
   261  				break
   262  			}
   263  
   264  			for _, spec := range decl.Specs {
   265  				elem, err := newCodeElem(decl.Tok, spec)
   266  				if err != nil {
   267  					return nil, err
   268  				}
   269  
   270  				set.add(elem)
   271  			}
   272  		case *ast.FuncDecl:
   273  			// Add func declarations
   274  			elem, err := newCodeElem(token.FUNC, decl)
   275  			if err != nil {
   276  				return nil, err
   277  			}
   278  
   279  			set.add(elem)
   280  		}
   281  	}
   282  
   283  	// Add file level comments
   284  	cmap := ast.NewCommentMap(fset, f, f.Comments)
   285  	for _, cGrp := range cmap[f] {
   286  		set.add(codeElem{token.COMMENT, cGrp.Text()})
   287  	}
   288  
   289  	return set, nil
   290  }
   291  
   292  // importName returns the identifier (PackageName) for an imported package
   293  func importName(iSpec *ast.ImportSpec) (string, error) {
   294  	if iSpec.Name == nil {
   295  		name, err := strconv.Unquote(iSpec.Path.Value)
   296  		if err != nil {
   297  			return "", err
   298  		}
   299  		return path.Base(name), nil
   300  	}
   301  	return iSpec.Name.Name, nil
   302  }
   303  
   304  // specGroups tracks grouped const/type specs with a map of line: groupID pairs
   305  type specGroups map[string]int
   306  
   307  // add spec source to group
   308  func (s specGroups) add(src string, groupID int) error {
   309  	srcBytes, err := format.Source(bytes.TrimSpace([]byte(src)))
   310  	if err != nil {
   311  		return err
   312  	}
   313  	s[string(srcBytes)] = groupID
   314  	return nil
   315  }
   316  
   317  // filterEmptyLines removes empty lines within groups of const/type specs.
   318  // Returns the filtered source.
   319  func (s specGroups) filterEmptyLines(src io.Reader) ([]byte, error) {
   320  	scanner := bufio.NewScanner(src)
   321  	var out bytes.Buffer
   322  
   323  	var emptyLines bytes.Buffer
   324  	prevGroupID := -1 // Initialize to invalid group
   325  	for scanner.Scan() {
   326  		line := bytes.TrimSpace(scanner.Bytes())
   327  
   328  		if len(line) == 0 {
   329  			fmt.Fprintf(&emptyLines, "%s\n", scanner.Bytes())
   330  			continue
   331  		}
   332  
   333  		// Discard emptyLines if previous non-empty line belonged to the same
   334  		// group as this line
   335  		if src, err := format.Source(line); err == nil {
   336  			groupID, ok := s[string(src)]
   337  			if ok && groupID == prevGroupID {
   338  				emptyLines.Reset()
   339  			}
   340  			prevGroupID = groupID
   341  		}
   342  
   343  		emptyLines.WriteTo(&out)
   344  		fmt.Fprintf(&out, "%s\n", scanner.Bytes())
   345  	}
   346  	if err := scanner.Err(); err != nil {
   347  		return nil, err
   348  	}
   349  	return out.Bytes(), nil
   350  }
   351  
   352  // filterImports removes unused imports from fileSrc, and returns a formatted src.
   353  func filterImports(fileSrc []byte) ([]byte, error) {
   354  	fset := token.NewFileSet()
   355  	file, err := parser.ParseFile(fset, "", fileSrc, parser.ParseComments)
   356  	if err != nil {
   357  		return nil, err
   358  	}
   359  	cmap := ast.NewCommentMap(fset, file, file.Comments)
   360  
   361  	// create set of references to imported identifiers
   362  	keepImport := make(map[string]bool)
   363  	for _, u := range file.Unresolved {
   364  		keepImport[u.Name] = true
   365  	}
   366  
   367  	// filter import declarations
   368  	decls := file.Decls
   369  	file.Decls = file.Decls[:0]
   370  	for _, decl := range decls {
   371  		importDecl, ok := decl.(*ast.GenDecl)
   372  
   373  		// Keep non-import declarations
   374  		if !ok || importDecl.Tok != token.IMPORT {
   375  			file.Decls = append(file.Decls, decl)
   376  			continue
   377  		}
   378  
   379  		// Filter the import specs
   380  		specs := importDecl.Specs
   381  		importDecl.Specs = importDecl.Specs[:0]
   382  		for _, spec := range specs {
   383  			iSpec := spec.(*ast.ImportSpec)
   384  			name, err := importName(iSpec)
   385  			if err != nil {
   386  				return nil, err
   387  			}
   388  
   389  			if keepImport[name] {
   390  				importDecl.Specs = append(importDecl.Specs, iSpec)
   391  			}
   392  		}
   393  		if len(importDecl.Specs) > 0 {
   394  			file.Decls = append(file.Decls, importDecl)
   395  		}
   396  	}
   397  
   398  	// filter file.Imports
   399  	imports := file.Imports
   400  	file.Imports = file.Imports[:0]
   401  	for _, spec := range imports {
   402  		name, err := importName(spec)
   403  		if err != nil {
   404  			return nil, err
   405  		}
   406  
   407  		if keepImport[name] {
   408  			file.Imports = append(file.Imports, spec)
   409  		}
   410  	}
   411  	file.Comments = cmap.Filter(file).Comments()
   412  
   413  	var buf bytes.Buffer
   414  	err = format.Node(&buf, fset, file)
   415  	if err != nil {
   416  		return nil, err
   417  	}
   418  
   419  	return buf.Bytes(), nil
   420  }
   421  
   422  // merge extracts duplicate code from archFiles and merges it to mergeFile.
   423  // 1. Construct commonSet: the set of code that is idential in all archFiles.
   424  // 2. Write the code in commonSet to mergedFile.
   425  // 3. Remove the commonSet code from all archFiles.
   426  func merge(mergedFile string, archFiles ...string) error {
   427  	// extract and validate the GOOS part of the merged filename
   428  	goos, ok := getValidGOOS(mergedFile)
   429  	if !ok {
   430  		return fmt.Errorf("invalid GOOS in merged file name %s", mergedFile)
   431  	}
   432  
   433  	// Read architecture files
   434  	var inSrc []srcFile
   435  	for _, file := range archFiles {
   436  		src, err := ioutil.ReadFile(file)
   437  		if err != nil {
   438  			return fmt.Errorf("cannot read archfile %s: %w", file, err)
   439  		}
   440  
   441  		inSrc = append(inSrc, srcFile{file, src})
   442  	}
   443  
   444  	// 1. Construct the set of top-level declarations common for all files
   445  	commonSet, err := getCommonSet(inSrc)
   446  	if err != nil {
   447  		return err
   448  	}
   449  	if commonSet.isEmpty() {
   450  		// No common code => do not modify any files
   451  		return nil
   452  	}
   453  
   454  	// 2. Write the merged file
   455  	mergedSrc, err := filter(inSrc[0].src, commonSet.keepCommon)
   456  	if err != nil {
   457  		return err
   458  	}
   459  
   460  	f, err := os.Create(mergedFile)
   461  	if err != nil {
   462  		return err
   463  	}
   464  
   465  	buf := bufio.NewWriter(f)
   466  	fmt.Fprintln(buf, "// Code generated by mkmerge; DO NOT EDIT.")
   467  	fmt.Fprintln(buf)
   468  	fmt.Fprintf(buf, "//go:build %s\n", goos)
   469  	fmt.Fprintf(buf, "// +build %s\n", goos)
   470  	fmt.Fprintln(buf)
   471  	buf.Write(mergedSrc)
   472  
   473  	err = buf.Flush()
   474  	if err != nil {
   475  		return err
   476  	}
   477  	err = f.Close()
   478  	if err != nil {
   479  		return err
   480  	}
   481  
   482  	// 3. Remove duplicate declarations from the architecture files
   483  	for _, inFile := range inSrc {
   484  		src, err := filter(inFile.src, commonSet.keepArchSpecific)
   485  		if err != nil {
   486  			return err
   487  		}
   488  		err = ioutil.WriteFile(inFile.name, src, 0644)
   489  		if err != nil {
   490  			return err
   491  		}
   492  	}
   493  	return nil
   494  }
   495  
   496  func main() {
   497  	var mergedFile string
   498  	flag.StringVar(&mergedFile, "out", "", "Write merged code to `FILE`")
   499  	flag.Parse()
   500  
   501  	// Expand wildcards
   502  	var filenames []string
   503  	for _, arg := range flag.Args() {
   504  		matches, err := filepath.Glob(arg)
   505  		if err != nil {
   506  			fmt.Fprintf(os.Stderr, "Invalid command line argument %q: %v\n", arg, err)
   507  			os.Exit(1)
   508  		}
   509  		filenames = append(filenames, matches...)
   510  	}
   511  
   512  	if len(filenames) < 2 {
   513  		// No need to merge
   514  		return
   515  	}
   516  
   517  	err := merge(mergedFile, filenames...)
   518  	if err != nil {
   519  		fmt.Fprintf(os.Stderr, "Merge failed with error: %v\n", err)
   520  		os.Exit(1)
   521  	}
   522  }