github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/cmd/gstack/main.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"flag"
     6  	"fmt"
     7  	"go/ast"
     8  	"go/parser"
     9  	"go/token"
    10  	"io/fs"
    11  	"os"
    12  	"path/filepath"
    13  
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/cmd/gstack/utils"
    15  )
    16  
    17  func usage() {
    18  	fmt.Fprintf(os.Stderr, "usage: gstack [path]\n")
    19  	flag.PrintDefaults()
    20  }
    21  
    22  func getCallExpressionsFromExpr(expr ast.Expr) (listOfCalls []*ast.CallExpr) {
    23  	switch expr := expr.(type) {
    24  	case *ast.SelectorExpr:
    25  		listOfCalls = getCallExpressionsFromExpr(expr.X)
    26  	case *ast.IndexExpr:
    27  		listOfCalls = getCallExpressionsFromExpr(expr.X)
    28  	case *ast.StarExpr:
    29  		listOfCalls = getCallExpressionsFromExpr(expr.X)
    30  	case *ast.BinaryExpr:
    31  		listOfCalls = getCallExpressionsFromExpr(expr.X)
    32  		listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Y)...)
    33  	case *ast.CallExpr:
    34  		listOfCalls = append(listOfCalls, expr)
    35  		listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Fun)...)
    36  		for _, arg := range expr.Args {
    37  			listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(arg)...)
    38  		}
    39  	case *ast.CompositeLit:
    40  		for _, elt := range expr.Elts {
    41  			listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(elt)...)
    42  		}
    43  	case *ast.UnaryExpr:
    44  		listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.X)...)
    45  	case *ast.KeyValueExpr:
    46  		listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Value)...)
    47  	case *ast.FuncLit:
    48  		listOfCalls = append(listOfCalls, getListOfCallExpressionsFromBlockStmt(expr.Body)...)
    49  	}
    50  
    51  	return listOfCalls
    52  }
    53  
    54  func getExprFromDeclStmt(statement *ast.DeclStmt) (listOfExpressions []ast.Expr) {
    55  	decl, ok := statement.Decl.(*ast.GenDecl)
    56  	if !ok {
    57  		return listOfExpressions
    58  	}
    59  	for _, spec := range decl.Specs {
    60  		if spec, ok := spec.(*ast.ValueSpec); ok {
    61  			listOfExpressions = append(listOfExpressions, spec.Values...)
    62  		}
    63  	}
    64  
    65  	return listOfExpressions
    66  }
    67  
    68  func getCallExpressionsFromStmt(statement ast.Stmt) (listOfCallExpressions []*ast.CallExpr) {
    69  	var body *ast.BlockStmt
    70  	var listOfExpressions []ast.Expr
    71  	switch stmt := statement.(type) {
    72  	case *ast.IfStmt:
    73  		body = stmt.Body
    74  	case *ast.SwitchStmt:
    75  		body = stmt.Body
    76  	case *ast.TypeSwitchStmt:
    77  		body = stmt.Body
    78  	case *ast.SelectStmt:
    79  		body = stmt.Body
    80  	case *ast.ForStmt:
    81  		body = stmt.Body
    82  	case *ast.GoStmt:
    83  		if fun, ok := stmt.Call.Fun.(*ast.FuncLit); ok {
    84  			listOfCallExpressions = append(listOfCallExpressions, getListOfCallExpressionsFromBlockStmt(fun.Body)...)
    85  		} else {
    86  			listOfCallExpressions = append(listOfCallExpressions, stmt.Call)
    87  		}
    88  	case *ast.RangeStmt:
    89  		body = stmt.Body
    90  	case *ast.DeclStmt:
    91  		listOfExpressions = append(listOfExpressions, getExprFromDeclStmt(stmt)...)
    92  		for _, expr := range listOfExpressions {
    93  			listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(expr)...)
    94  		}
    95  	case *ast.CommClause:
    96  		stmts := stmt.Body
    97  		for _, stmt := range stmts {
    98  			listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(stmt)...)
    99  		}
   100  	case *ast.ExprStmt:
   101  		listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(stmt.X)...)
   102  	case *ast.AssignStmt:
   103  		for _, rh := range stmt.Rhs {
   104  			listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(rh)...)
   105  		}
   106  	case *ast.ReturnStmt:
   107  		for _, result := range stmt.Results {
   108  			listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(result)...)
   109  		}
   110  	}
   111  	if body != nil {
   112  		listOfCallExpressions = append(
   113  			listOfCallExpressions,
   114  			getListOfCallExpressionsFromBlockStmt(body)...,
   115  		)
   116  	}
   117  
   118  	return listOfCallExpressions
   119  }
   120  
   121  func getListOfCallExpressionsFromBlockStmt(block *ast.BlockStmt) (listOfCallExpressions []*ast.CallExpr) {
   122  	for _, statement := range block.List {
   123  		listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(statement)...)
   124  	}
   125  
   126  	return listOfCallExpressions
   127  }
   128  
   129  func format(src []byte, path string, fset *token.FileSet, file *ast.File) ([]byte, error) {
   130  	var listOfArgs []utils.FunctionIDArg
   131  	for _, f := range file.Decls {
   132  		var listOfCalls []*ast.CallExpr
   133  		fn, ok := f.(*ast.FuncDecl)
   134  		if !ok {
   135  			continue
   136  		}
   137  		listOfCalls = getListOfCallExpressionsFromBlockStmt(fn.Body)
   138  		for _, call := range listOfCalls {
   139  			if function, ok := call.Fun.(*ast.SelectorExpr); ok && function.Sel.Name == "FunctionID" {
   140  				pack, ok := function.X.(*ast.Ident)
   141  				if !ok {
   142  					continue
   143  				}
   144  				if pack.Name == "stack" && len(call.Args) == 1 {
   145  					listOfArgs = append(listOfArgs, utils.FunctionIDArg{
   146  						FuncDecl: fn,
   147  						ArgPos:   call.Args[0].Pos(),
   148  						ArgEnd:   call.Args[0].End(),
   149  					})
   150  				}
   151  			}
   152  		}
   153  	}
   154  	if len(listOfArgs) != 0 {
   155  		fixed, err := utils.FixSource(fset, path, src, listOfArgs)
   156  		if err != nil {
   157  			return nil, err
   158  		}
   159  
   160  		return fixed, nil
   161  	}
   162  
   163  	return src, nil
   164  }
   165  
   166  func processFile(src []byte, path string, fset *token.FileSet, file *ast.File, info os.FileInfo) error {
   167  	formatted, err := format(src, path, fset, file)
   168  	if err != nil {
   169  		return err
   170  	}
   171  	if !bytes.Equal(src, formatted) {
   172  		err = utils.WriteFile(path, formatted, info.Mode().Perm())
   173  		if err != nil {
   174  			return err
   175  		}
   176  	}
   177  
   178  	return nil
   179  }
   180  
   181  func main() {
   182  	flag.Usage = usage
   183  	flag.Parse()
   184  	args := flag.Args()
   185  
   186  	if len(args) != 1 {
   187  		flag.Usage()
   188  
   189  		return
   190  	}
   191  	_, err := os.Stat(args[0])
   192  	if err != nil {
   193  		panic(err)
   194  	}
   195  
   196  	fileSystem := os.DirFS(args[0])
   197  
   198  	err = fs.WalkDir(fileSystem, ".", func(path string, d fs.DirEntry, err error) error {
   199  		fset := token.NewFileSet()
   200  		if err != nil {
   201  			return err
   202  		}
   203  		if d.IsDir() {
   204  			return nil
   205  		}
   206  		if filepath.Ext(path) == ".go" {
   207  			info, err := os.Stat(path)
   208  			if err != nil {
   209  				return err
   210  			}
   211  			src, err := utils.ReadFile(path, info)
   212  			if err != nil {
   213  				return err
   214  			}
   215  			file, err := parser.ParseFile(fset, path, nil, 0)
   216  			if err != nil {
   217  				return err
   218  			}
   219  
   220  			return processFile(src, path, fset, file, info)
   221  		}
   222  
   223  		return nil
   224  	})
   225  	if err != nil {
   226  		panic(err)
   227  	}
   228  }