github.com/AESNooper/go/src@v0.0.0-20220218095104-b56a4ab1bbbb/sort/genzfunc.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build ignore
     6  // +build ignore
     7  
     8  // This program is run via "go generate" (via a directive in sort.go)
     9  // to generate zfuncversion.go.
    10  //
    11  // It copies sort.go to zfuncversion.go, only retaining funcs which
    12  // take a "data Interface" parameter, and renaming each to have a
    13  // "_func" suffix and taking a "data lessSwap" instead. It then rewrites
    14  // each internal function call to the appropriate _func variants.
    15  
    16  package main
    17  
    18  import (
    19  	"bytes"
    20  	"go/ast"
    21  	"go/format"
    22  	"go/parser"
    23  	"go/token"
    24  	"log"
    25  	"os"
    26  	"regexp"
    27  )
    28  
    29  var fset = token.NewFileSet()
    30  
    31  func main() {
    32  	af, err := parser.ParseFile(fset, "sort.go", nil, 0)
    33  	if err != nil {
    34  		log.Fatal(err)
    35  	}
    36  	af.Doc = nil
    37  	af.Imports = nil
    38  	af.Comments = nil
    39  
    40  	var newDecl []ast.Decl
    41  	for _, d := range af.Decls {
    42  		fd, ok := d.(*ast.FuncDecl)
    43  		if !ok {
    44  			continue
    45  		}
    46  		if fd.Recv != nil || fd.Name.IsExported() {
    47  			continue
    48  		}
    49  		typ := fd.Type
    50  		if len(typ.Params.List) < 1 {
    51  			continue
    52  		}
    53  		arg0 := typ.Params.List[0]
    54  		arg0Name := arg0.Names[0].Name
    55  		arg0Type := arg0.Type.(*ast.Ident)
    56  		if arg0Name != "data" || arg0Type.Name != "Interface" {
    57  			continue
    58  		}
    59  		arg0Type.Name = "lessSwap"
    60  
    61  		newDecl = append(newDecl, fd)
    62  	}
    63  	af.Decls = newDecl
    64  	ast.Walk(visitFunc(rewriteCalls), af)
    65  
    66  	var out bytes.Buffer
    67  	if err := format.Node(&out, fset, af); err != nil {
    68  		log.Fatalf("format.Node: %v", err)
    69  	}
    70  
    71  	// Get rid of blank lines after removal of comments.
    72  	src := regexp.MustCompile(`\n{2,}`).ReplaceAll(out.Bytes(), []byte("\n"))
    73  
    74  	// Add comments to each func, for the lost reader.
    75  	// This is so much easier than adding comments via the AST
    76  	// and trying to get position info correct.
    77  	src = regexp.MustCompile(`(?m)^func (\w+)`).ReplaceAll(src, []byte("\n// Auto-generated variant of sort.go:$1\nfunc ${1}_func"))
    78  
    79  	// Final gofmt.
    80  	src, err = format.Source(src)
    81  	if err != nil {
    82  		log.Fatalf("format.Source: %v on\n%s", err, src)
    83  	}
    84  
    85  	out.Reset()
    86  	out.WriteString(`// Code generated from sort.go using genzfunc.go; DO NOT EDIT.
    87  
    88  // Copyright 2016 The Go Authors. All rights reserved.
    89  // Use of this source code is governed by a BSD-style
    90  // license that can be found in the LICENSE file.
    91  
    92  `)
    93  	out.Write(src)
    94  
    95  	const target = "zfuncversion.go"
    96  	if err := os.WriteFile(target, out.Bytes(), 0644); err != nil {
    97  		log.Fatal(err)
    98  	}
    99  }
   100  
   101  type visitFunc func(ast.Node) ast.Visitor
   102  
   103  func (f visitFunc) Visit(n ast.Node) ast.Visitor { return f(n) }
   104  
   105  func rewriteCalls(n ast.Node) ast.Visitor {
   106  	ce, ok := n.(*ast.CallExpr)
   107  	if ok {
   108  		rewriteCall(ce)
   109  	}
   110  	return visitFunc(rewriteCalls)
   111  }
   112  
   113  func rewriteCall(ce *ast.CallExpr) {
   114  	ident, ok := ce.Fun.(*ast.Ident)
   115  	if !ok {
   116  		// e.g. skip SelectorExpr (data.Less(..) calls)
   117  		return
   118  	}
   119  	// skip casts
   120  	if ident.Name == "int" || ident.Name == "uint" {
   121  		return
   122  	}
   123  	if len(ce.Args) < 1 {
   124  		return
   125  	}
   126  	ident.Name += "_func"
   127  }