github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/cmd/gstack/main.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "flag" 6 "fmt" 7 "go/ast" 8 "go/parser" 9 "go/token" 10 "io/fs" 11 "os" 12 "path/filepath" 13 14 "github.com/ydb-platform/ydb-go-sdk/v3/internal/cmd/gstack/utils" 15 ) 16 17 func usage() { 18 fmt.Fprintf(os.Stderr, "usage: gstack [path]\n") 19 flag.PrintDefaults() 20 } 21 22 func getCallExpressionsFromExpr(expr ast.Expr) (listOfCalls []*ast.CallExpr) { 23 switch expr := expr.(type) { 24 case *ast.SelectorExpr: 25 listOfCalls = getCallExpressionsFromExpr(expr.X) 26 case *ast.IndexExpr: 27 listOfCalls = getCallExpressionsFromExpr(expr.X) 28 case *ast.StarExpr: 29 listOfCalls = getCallExpressionsFromExpr(expr.X) 30 case *ast.BinaryExpr: 31 listOfCalls = getCallExpressionsFromExpr(expr.X) 32 listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Y)...) 33 case *ast.CallExpr: 34 listOfCalls = append(listOfCalls, expr) 35 listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Fun)...) 36 for _, arg := range expr.Args { 37 listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(arg)...) 38 } 39 case *ast.CompositeLit: 40 for _, elt := range expr.Elts { 41 listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(elt)...) 42 } 43 case *ast.UnaryExpr: 44 listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.X)...) 45 case *ast.KeyValueExpr: 46 listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Value)...) 47 case *ast.FuncLit: 48 listOfCalls = append(listOfCalls, getListOfCallExpressionsFromBlockStmt(expr.Body)...) 49 } 50 51 return listOfCalls 52 } 53 54 func getExprFromDeclStmt(statement *ast.DeclStmt) (listOfExpressions []ast.Expr) { 55 decl, ok := statement.Decl.(*ast.GenDecl) 56 if !ok { 57 return listOfExpressions 58 } 59 for _, spec := range decl.Specs { 60 if spec, ok := spec.(*ast.ValueSpec); ok { 61 listOfExpressions = append(listOfExpressions, spec.Values...) 62 } 63 } 64 65 return listOfExpressions 66 } 67 68 func getCallExpressionsFromStmt(statement ast.Stmt) (listOfCallExpressions []*ast.CallExpr) { 69 var body *ast.BlockStmt 70 var listOfExpressions []ast.Expr 71 switch stmt := statement.(type) { 72 case *ast.IfStmt: 73 body = stmt.Body 74 case *ast.SwitchStmt: 75 body = stmt.Body 76 case *ast.TypeSwitchStmt: 77 body = stmt.Body 78 case *ast.SelectStmt: 79 body = stmt.Body 80 case *ast.ForStmt: 81 body = stmt.Body 82 case *ast.GoStmt: 83 if fun, ok := stmt.Call.Fun.(*ast.FuncLit); ok { 84 listOfCallExpressions = append(listOfCallExpressions, getListOfCallExpressionsFromBlockStmt(fun.Body)...) 85 } else { 86 listOfCallExpressions = append(listOfCallExpressions, stmt.Call) 87 } 88 case *ast.RangeStmt: 89 body = stmt.Body 90 case *ast.DeclStmt: 91 listOfExpressions = append(listOfExpressions, getExprFromDeclStmt(stmt)...) 92 for _, expr := range listOfExpressions { 93 listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(expr)...) 94 } 95 case *ast.CommClause: 96 stmts := stmt.Body 97 for _, stmt := range stmts { 98 listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(stmt)...) 99 } 100 case *ast.ExprStmt: 101 listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(stmt.X)...) 102 case *ast.AssignStmt: 103 for _, rh := range stmt.Rhs { 104 listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(rh)...) 105 } 106 case *ast.ReturnStmt: 107 for _, result := range stmt.Results { 108 listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(result)...) 109 } 110 } 111 if body != nil { 112 listOfCallExpressions = append( 113 listOfCallExpressions, 114 getListOfCallExpressionsFromBlockStmt(body)..., 115 ) 116 } 117 118 return listOfCallExpressions 119 } 120 121 func getListOfCallExpressionsFromBlockStmt(block *ast.BlockStmt) (listOfCallExpressions []*ast.CallExpr) { 122 for _, statement := range block.List { 123 listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(statement)...) 124 } 125 126 return listOfCallExpressions 127 } 128 129 func format(src []byte, path string, fset *token.FileSet, file *ast.File) ([]byte, error) { 130 var listOfArgs []utils.FunctionIDArg 131 for _, f := range file.Decls { 132 var listOfCalls []*ast.CallExpr 133 fn, ok := f.(*ast.FuncDecl) 134 if !ok { 135 continue 136 } 137 listOfCalls = getListOfCallExpressionsFromBlockStmt(fn.Body) 138 for _, call := range listOfCalls { 139 if function, ok := call.Fun.(*ast.SelectorExpr); ok && function.Sel.Name == "FunctionID" { 140 pack, ok := function.X.(*ast.Ident) 141 if !ok { 142 continue 143 } 144 if pack.Name == "stack" && len(call.Args) == 1 { 145 listOfArgs = append(listOfArgs, utils.FunctionIDArg{ 146 FuncDecl: fn, 147 ArgPos: call.Args[0].Pos(), 148 ArgEnd: call.Args[0].End(), 149 }) 150 } 151 } 152 } 153 } 154 if len(listOfArgs) != 0 { 155 fixed, err := utils.FixSource(fset, path, src, listOfArgs) 156 if err != nil { 157 return nil, err 158 } 159 160 return fixed, nil 161 } 162 163 return src, nil 164 } 165 166 func processFile(src []byte, path string, fset *token.FileSet, file *ast.File, info os.FileInfo) error { 167 formatted, err := format(src, path, fset, file) 168 if err != nil { 169 return err 170 } 171 if !bytes.Equal(src, formatted) { 172 err = utils.WriteFile(path, formatted, info.Mode().Perm()) 173 if err != nil { 174 return err 175 } 176 } 177 178 return nil 179 } 180 181 func main() { 182 flag.Usage = usage 183 flag.Parse() 184 args := flag.Args() 185 186 if len(args) != 1 { 187 flag.Usage() 188 189 return 190 } 191 _, err := os.Stat(args[0]) 192 if err != nil { 193 panic(err) 194 } 195 196 fileSystem := os.DirFS(args[0]) 197 198 err = fs.WalkDir(fileSystem, ".", func(path string, d fs.DirEntry, err error) error { 199 fset := token.NewFileSet() 200 if err != nil { 201 return err 202 } 203 if d.IsDir() { 204 return nil 205 } 206 if filepath.Ext(path) == ".go" { 207 info, err := os.Stat(path) 208 if err != nil { 209 return err 210 } 211 src, err := utils.ReadFile(path, info) 212 if err != nil { 213 return err 214 } 215 file, err := parser.ParseFile(fset, path, nil, 0) 216 if err != nil { 217 return err 218 } 219 220 return processFile(src, path, fset, file, info) 221 } 222 223 return nil 224 }) 225 if err != nil { 226 panic(err) 227 } 228 }