vitess.io/vitess@v0.16.2/go/tools/astfmtgen/main.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package main 18 19 import ( 20 "fmt" 21 "go/ast" 22 "go/format" 23 gotoken "go/token" 24 "go/types" 25 "log" 26 "os" 27 "path" 28 "strconv" 29 "strings" 30 31 "vitess.io/vitess/go/tools/common" 32 33 "golang.org/x/tools/go/ast/astutil" 34 "golang.org/x/tools/go/packages" 35 ) 36 37 func main() { 38 packageName := os.Args[1] 39 40 config := &packages.Config{ 41 Mode: packages.NeedName | 42 packages.NeedFiles | 43 packages.NeedCompiledGoFiles | 44 packages.NeedImports | 45 packages.NeedTypes | 46 packages.NeedSyntax | 47 packages.NeedTypesInfo, 48 } 49 pkgs, err := packages.Load(config, packageName) 50 if err != nil || common.PkgFailed(pkgs) { 51 log.Fatal("error loading packaged") 52 } 53 for _, pkg := range pkgs { 54 if pkg.Name == "sqlparser" { 55 rewriter := &Rewriter{pkg: pkg} 56 err := rewriter.Rewrite() 57 if err != nil { 58 log.Fatal(err.Error()) 59 } 60 } 61 } 62 } 63 64 type Rewriter struct { 65 pkg *packages.Package 66 astExpr *types.Interface 67 } 68 69 func (r *Rewriter) Rewrite() error { 70 scope := r.pkg.Types.Scope() 71 exprT := scope.Lookup("Expr").(*types.TypeName) 72 exprN := exprT.Type().(*types.Named).Underlying() 73 r.astExpr = exprN.(*types.Interface) 74 75 for i, file := range r.pkg.GoFiles { 76 dirname, filename := path.Split(file) 77 if filename == "ast_format.go" { 78 syntax := r.pkg.Syntax[i] 79 // Add fmt import since %d is handled by calling fmt.Sprintf("%d",...) 80 astutil.AddImport(r.pkg.Fset, syntax, "fmt") 81 astutil.Apply(syntax, r.replaceAstfmtCalls, nil) 82 83 f, err := os.Create(path.Join(dirname, "ast_format_fast.go")) 84 if err != nil { 85 return err 86 } 87 fmt.Fprintf(f, "// Code generated by ASTFmtGen. DO NOT EDIT.\n") 88 // format.Node is like printer.Fprintf but its output is formatted in 89 // the style of gofmt. 90 _ = format.Node(f, r.pkg.Fset, syntax) 91 f.Close() 92 } 93 } 94 return nil 95 } 96 97 func (r *Rewriter) replaceAstfmtCalls(cursor *astutil.Cursor) bool { 98 switch v := cursor.Node().(type) { 99 case *ast.Comment: 100 v.Text = strings.ReplaceAll(v.Text, " Format ", " formatFast ") 101 case *ast.FuncDecl: 102 if v.Name.Name == "Format" { 103 v.Name.Name = "formatFast" 104 } 105 case *ast.ExprStmt: 106 if call, ok := v.X.(*ast.CallExpr); ok { 107 switch r.methodName(call) { 108 case "astPrintf": 109 return r.rewriteAstPrintf(cursor, call) 110 case "literal": 111 callexpr := call.Fun.(*ast.SelectorExpr) 112 callexpr.Sel.Name = "WriteString" 113 return true 114 } 115 } 116 } 117 return true 118 } 119 120 func (r *Rewriter) methodName(n *ast.CallExpr) string { 121 if call, ok := n.Fun.(*ast.SelectorExpr); ok { 122 id := call.Sel 123 if id != nil && !r.pkg.TypesInfo.Types[id].IsType() { 124 return id.Name 125 } 126 } 127 return "" 128 } 129 130 func (r *Rewriter) rewriteLiteral(rcv ast.Expr, method string, arg ast.Expr) ast.Stmt { 131 expr := &ast.CallExpr{ 132 Fun: &ast.SelectorExpr{ 133 X: rcv, 134 Sel: &ast.Ident{Name: method}, 135 }, 136 Args: []ast.Expr{arg}, 137 } 138 return &ast.ExprStmt{X: expr} 139 } 140 141 func (r *Rewriter) rewriteAstPrintf(cursor *astutil.Cursor, expr *ast.CallExpr) bool { 142 callexpr := expr.Fun.(*ast.SelectorExpr) 143 lit := expr.Args[1].(*ast.BasicLit) 144 format, err := strconv.Unquote(lit.Value) 145 if err != nil { 146 panic("bad literal argument") 147 } 148 149 end := len(format) 150 fieldnum := 0 151 for i := 0; i < end; { 152 lasti := i 153 for i < end && format[i] != '%' { 154 i++ 155 } 156 if i > lasti { 157 var arg ast.Expr 158 var method string 159 var lit = format[lasti:i] 160 161 if len(lit) == 1 { 162 method = "WriteByte" 163 arg = &ast.BasicLit{ 164 Kind: gotoken.CHAR, 165 Value: strconv.QuoteRune(rune(lit[0])), 166 } 167 } else { 168 method = "WriteString" 169 arg = &ast.BasicLit{ 170 Kind: gotoken.STRING, 171 Value: strconv.Quote(lit), 172 } 173 } 174 175 cursor.InsertBefore(r.rewriteLiteral(callexpr.X, method, arg)) 176 } 177 if i >= end { 178 break 179 } 180 i++ // '%' 181 if format[i] == '#' { 182 i++ 183 } 184 185 token := format[i] 186 switch token { 187 case 'c': 188 cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteByte", expr.Args[2+fieldnum])) 189 case 's': 190 cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteString", expr.Args[2+fieldnum])) 191 case 'l', 'r', 'v': 192 leftExpr := expr.Args[0] 193 leftExprT := r.pkg.TypesInfo.Types[leftExpr].Type 194 195 rightExpr := expr.Args[2+fieldnum] 196 rightExprT := r.pkg.TypesInfo.Types[rightExpr].Type 197 198 var call ast.Expr 199 if types.Implements(leftExprT, r.astExpr) && types.Implements(rightExprT, r.astExpr) { 200 call = &ast.CallExpr{ 201 Fun: &ast.SelectorExpr{ 202 X: callexpr.X, 203 Sel: &ast.Ident{Name: "printExpr"}, 204 }, 205 Args: []ast.Expr{ 206 leftExpr, 207 rightExpr, 208 &ast.Ident{ 209 Name: strconv.FormatBool(token != 'r'), 210 }, 211 }, 212 } 213 } else { 214 call = &ast.CallExpr{ 215 Fun: &ast.SelectorExpr{ 216 X: rightExpr, 217 Sel: &ast.Ident{Name: "formatFast"}, 218 }, 219 Args: []ast.Expr{callexpr.X}, 220 } 221 } 222 cursor.InsertBefore(&ast.ExprStmt{X: call}) 223 case 'd': 224 call := &ast.CallExpr{ 225 Fun: &ast.Ident{Name: "fmt.Sprintf"}, 226 Args: []ast.Expr{&ast.BasicLit{Value: `"%d"`, Kind: gotoken.STRING}, expr.Args[2+fieldnum]}, 227 } 228 cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteString", call)) 229 default: 230 panic(fmt.Sprintf("unsupported escape %q", token)) 231 } 232 fieldnum++ 233 i++ 234 } 235 236 cursor.Delete() 237 return true 238 }