vitess.io/vitess@v0.16.2/go/tools/astfmtgen/main.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package main
    18  
    19  import (
    20  	"fmt"
    21  	"go/ast"
    22  	"go/format"
    23  	gotoken "go/token"
    24  	"go/types"
    25  	"log"
    26  	"os"
    27  	"path"
    28  	"strconv"
    29  	"strings"
    30  
    31  	"vitess.io/vitess/go/tools/common"
    32  
    33  	"golang.org/x/tools/go/ast/astutil"
    34  	"golang.org/x/tools/go/packages"
    35  )
    36  
    37  func main() {
    38  	packageName := os.Args[1]
    39  
    40  	config := &packages.Config{
    41  		Mode: packages.NeedName |
    42  			packages.NeedFiles |
    43  			packages.NeedCompiledGoFiles |
    44  			packages.NeedImports |
    45  			packages.NeedTypes |
    46  			packages.NeedSyntax |
    47  			packages.NeedTypesInfo,
    48  	}
    49  	pkgs, err := packages.Load(config, packageName)
    50  	if err != nil || common.PkgFailed(pkgs) {
    51  		log.Fatal("error loading packaged")
    52  	}
    53  	for _, pkg := range pkgs {
    54  		if pkg.Name == "sqlparser" {
    55  			rewriter := &Rewriter{pkg: pkg}
    56  			err := rewriter.Rewrite()
    57  			if err != nil {
    58  				log.Fatal(err.Error())
    59  			}
    60  		}
    61  	}
    62  }
    63  
    64  type Rewriter struct {
    65  	pkg     *packages.Package
    66  	astExpr *types.Interface
    67  }
    68  
    69  func (r *Rewriter) Rewrite() error {
    70  	scope := r.pkg.Types.Scope()
    71  	exprT := scope.Lookup("Expr").(*types.TypeName)
    72  	exprN := exprT.Type().(*types.Named).Underlying()
    73  	r.astExpr = exprN.(*types.Interface)
    74  
    75  	for i, file := range r.pkg.GoFiles {
    76  		dirname, filename := path.Split(file)
    77  		if filename == "ast_format.go" {
    78  			syntax := r.pkg.Syntax[i]
    79  			// Add fmt import since %d is handled by calling fmt.Sprintf("%d",...)
    80  			astutil.AddImport(r.pkg.Fset, syntax, "fmt")
    81  			astutil.Apply(syntax, r.replaceAstfmtCalls, nil)
    82  
    83  			f, err := os.Create(path.Join(dirname, "ast_format_fast.go"))
    84  			if err != nil {
    85  				return err
    86  			}
    87  			fmt.Fprintf(f, "// Code generated by ASTFmtGen. DO NOT EDIT.\n")
    88  			// format.Node is like printer.Fprintf but its output is formatted in
    89  			// the style of gofmt.
    90  			_ = format.Node(f, r.pkg.Fset, syntax)
    91  			f.Close()
    92  		}
    93  	}
    94  	return nil
    95  }
    96  
    97  func (r *Rewriter) replaceAstfmtCalls(cursor *astutil.Cursor) bool {
    98  	switch v := cursor.Node().(type) {
    99  	case *ast.Comment:
   100  		v.Text = strings.ReplaceAll(v.Text, " Format ", " formatFast ")
   101  	case *ast.FuncDecl:
   102  		if v.Name.Name == "Format" {
   103  			v.Name.Name = "formatFast"
   104  		}
   105  	case *ast.ExprStmt:
   106  		if call, ok := v.X.(*ast.CallExpr); ok {
   107  			switch r.methodName(call) {
   108  			case "astPrintf":
   109  				return r.rewriteAstPrintf(cursor, call)
   110  			case "literal":
   111  				callexpr := call.Fun.(*ast.SelectorExpr)
   112  				callexpr.Sel.Name = "WriteString"
   113  				return true
   114  			}
   115  		}
   116  	}
   117  	return true
   118  }
   119  
   120  func (r *Rewriter) methodName(n *ast.CallExpr) string {
   121  	if call, ok := n.Fun.(*ast.SelectorExpr); ok {
   122  		id := call.Sel
   123  		if id != nil && !r.pkg.TypesInfo.Types[id].IsType() {
   124  			return id.Name
   125  		}
   126  	}
   127  	return ""
   128  }
   129  
   130  func (r *Rewriter) rewriteLiteral(rcv ast.Expr, method string, arg ast.Expr) ast.Stmt {
   131  	expr := &ast.CallExpr{
   132  		Fun: &ast.SelectorExpr{
   133  			X:   rcv,
   134  			Sel: &ast.Ident{Name: method},
   135  		},
   136  		Args: []ast.Expr{arg},
   137  	}
   138  	return &ast.ExprStmt{X: expr}
   139  }
   140  
   141  func (r *Rewriter) rewriteAstPrintf(cursor *astutil.Cursor, expr *ast.CallExpr) bool {
   142  	callexpr := expr.Fun.(*ast.SelectorExpr)
   143  	lit := expr.Args[1].(*ast.BasicLit)
   144  	format, err := strconv.Unquote(lit.Value)
   145  	if err != nil {
   146  		panic("bad literal argument")
   147  	}
   148  
   149  	end := len(format)
   150  	fieldnum := 0
   151  	for i := 0; i < end; {
   152  		lasti := i
   153  		for i < end && format[i] != '%' {
   154  			i++
   155  		}
   156  		if i > lasti {
   157  			var arg ast.Expr
   158  			var method string
   159  			var lit = format[lasti:i]
   160  
   161  			if len(lit) == 1 {
   162  				method = "WriteByte"
   163  				arg = &ast.BasicLit{
   164  					Kind:  gotoken.CHAR,
   165  					Value: strconv.QuoteRune(rune(lit[0])),
   166  				}
   167  			} else {
   168  				method = "WriteString"
   169  				arg = &ast.BasicLit{
   170  					Kind:  gotoken.STRING,
   171  					Value: strconv.Quote(lit),
   172  				}
   173  			}
   174  
   175  			cursor.InsertBefore(r.rewriteLiteral(callexpr.X, method, arg))
   176  		}
   177  		if i >= end {
   178  			break
   179  		}
   180  		i++ // '%'
   181  		if format[i] == '#' {
   182  			i++
   183  		}
   184  
   185  		token := format[i]
   186  		switch token {
   187  		case 'c':
   188  			cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteByte", expr.Args[2+fieldnum]))
   189  		case 's':
   190  			cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteString", expr.Args[2+fieldnum]))
   191  		case 'l', 'r', 'v':
   192  			leftExpr := expr.Args[0]
   193  			leftExprT := r.pkg.TypesInfo.Types[leftExpr].Type
   194  
   195  			rightExpr := expr.Args[2+fieldnum]
   196  			rightExprT := r.pkg.TypesInfo.Types[rightExpr].Type
   197  
   198  			var call ast.Expr
   199  			if types.Implements(leftExprT, r.astExpr) && types.Implements(rightExprT, r.astExpr) {
   200  				call = &ast.CallExpr{
   201  					Fun: &ast.SelectorExpr{
   202  						X:   callexpr.X,
   203  						Sel: &ast.Ident{Name: "printExpr"},
   204  					},
   205  					Args: []ast.Expr{
   206  						leftExpr,
   207  						rightExpr,
   208  						&ast.Ident{
   209  							Name: strconv.FormatBool(token != 'r'),
   210  						},
   211  					},
   212  				}
   213  			} else {
   214  				call = &ast.CallExpr{
   215  					Fun: &ast.SelectorExpr{
   216  						X:   rightExpr,
   217  						Sel: &ast.Ident{Name: "formatFast"},
   218  					},
   219  					Args: []ast.Expr{callexpr.X},
   220  				}
   221  			}
   222  			cursor.InsertBefore(&ast.ExprStmt{X: call})
   223  		case 'd':
   224  			call := &ast.CallExpr{
   225  				Fun:  &ast.Ident{Name: "fmt.Sprintf"},
   226  				Args: []ast.Expr{&ast.BasicLit{Value: `"%d"`, Kind: gotoken.STRING}, expr.Args[2+fieldnum]},
   227  			}
   228  			cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteString", call))
   229  		default:
   230  			panic(fmt.Sprintf("unsupported escape %q", token))
   231  		}
   232  		fieldnum++
   233  		i++
   234  	}
   235  
   236  	cursor.Delete()
   237  	return true
   238  }