github.com/aclements/go-misc@v0.0.0-20240129233631-2f6ede80790c/rtcheck/rewrite.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  )
    11  
    12  func rewriteIdentList(v func(ast.Node) ast.Node, list []*ast.Ident) {
    13  	for i, x := range list {
    14  		list[i] = Rewrite(v, x).(*ast.Ident)
    15  	}
    16  }
    17  
    18  func rewriteExprList(v func(ast.Node) ast.Node, list []ast.Expr) {
    19  	for i, x := range list {
    20  		list[i] = Rewrite(v, x).(ast.Expr)
    21  	}
    22  }
    23  
    24  func rewriteStmtList(v func(ast.Node) ast.Node, list []ast.Stmt) {
    25  	for i, x := range list {
    26  		list[i] = Rewrite(v, x).(ast.Stmt)
    27  	}
    28  }
    29  
    30  func rewriteDeclList(v func(ast.Node) ast.Node, list []ast.Decl) {
    31  	for i, x := range list {
    32  		list[i] = Rewrite(v, x).(ast.Decl)
    33  	}
    34  }
    35  
    36  func Rewrite(v func(ast.Node) ast.Node, node ast.Node) ast.Node {
    37  	node = v(node)
    38  
    39  	// rewrite children
    40  	// (the order of the cases matches the order
    41  	// of the corresponding node types in ast.go)
    42  	switch n := node.(type) {
    43  	// Comments and fields
    44  	case *ast.Comment:
    45  		// nothing to do
    46  
    47  	case *ast.CommentGroup:
    48  		for i, c := range n.List {
    49  			n.List[i] = Rewrite(v, c).(*ast.Comment)
    50  		}
    51  
    52  	case *ast.Field:
    53  		if n.Doc != nil {
    54  			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
    55  		}
    56  		rewriteIdentList(v, n.Names)
    57  		n.Type = Rewrite(v, n.Type).(ast.Expr)
    58  		if n.Tag != nil {
    59  			n.Tag = Rewrite(v, n.Tag).(*ast.BasicLit)
    60  		}
    61  		if n.Comment != nil {
    62  			n.Comment = Rewrite(v, n.Comment).(*ast.CommentGroup)
    63  		}
    64  
    65  	case *ast.FieldList:
    66  		for i, f := range n.List {
    67  			n.List[i] = Rewrite(v, f).(*ast.Field)
    68  		}
    69  
    70  	// Expressions
    71  	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
    72  		// nothing to do
    73  
    74  	case *ast.Ellipsis:
    75  		if n.Elt != nil {
    76  			n.Elt = Rewrite(v, n.Elt).(ast.Expr)
    77  		}
    78  
    79  	case *ast.FuncLit:
    80  		n.Type = Rewrite(v, n.Type).(*ast.FuncType)
    81  		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)
    82  
    83  	case *ast.CompositeLit:
    84  		if n.Type != nil {
    85  			n.Type = Rewrite(v, n.Type).(ast.Expr)
    86  		}
    87  		rewriteExprList(v, n.Elts)
    88  
    89  	case *ast.ParenExpr:
    90  		n.X = Rewrite(v, n.X).(ast.Expr)
    91  
    92  	case *ast.SelectorExpr:
    93  		n.X = Rewrite(v, n.X).(ast.Expr)
    94  		n.Sel = Rewrite(v, n.Sel).(*ast.Ident)
    95  
    96  	case *ast.IndexExpr:
    97  		n.X = Rewrite(v, n.X).(ast.Expr)
    98  		n.Index = Rewrite(v, n.Index).(ast.Expr)
    99  
   100  	case *ast.SliceExpr:
   101  		n.X = Rewrite(v, n.X).(ast.Expr)
   102  		if n.Low != nil {
   103  			n.Low = Rewrite(v, n.Low).(ast.Expr)
   104  		}
   105  		if n.High != nil {
   106  			n.High = Rewrite(v, n.High).(ast.Expr)
   107  		}
   108  		if n.Max != nil {
   109  			n.Max = Rewrite(v, n.Max).(ast.Expr)
   110  		}
   111  
   112  	case *ast.TypeAssertExpr:
   113  		n.X = Rewrite(v, n.X).(ast.Expr)
   114  		if n.Type != nil {
   115  			n.Type = Rewrite(v, n.Type).(ast.Expr)
   116  		}
   117  
   118  	case *ast.CallExpr:
   119  		n.Fun = Rewrite(v, n.Fun).(ast.Expr)
   120  		rewriteExprList(v, n.Args)
   121  
   122  	case *ast.StarExpr:
   123  		n.X = Rewrite(v, n.X).(ast.Expr)
   124  
   125  	case *ast.UnaryExpr:
   126  		n.X = Rewrite(v, n.X).(ast.Expr)
   127  
   128  	case *ast.BinaryExpr:
   129  		n.X = Rewrite(v, n.X).(ast.Expr)
   130  		n.Y = Rewrite(v, n.Y).(ast.Expr)
   131  
   132  	case *ast.KeyValueExpr:
   133  		n.Key = Rewrite(v, n.Key).(ast.Expr)
   134  		n.Value = Rewrite(v, n.Value).(ast.Expr)
   135  
   136  	// Types
   137  	case *ast.ArrayType:
   138  		if n.Len != nil {
   139  			n.Len = Rewrite(v, n.Len).(ast.Expr)
   140  		}
   141  		n.Elt = Rewrite(v, n.Elt).(ast.Expr)
   142  
   143  	case *ast.StructType:
   144  		n.Fields = Rewrite(v, n.Fields).(*ast.FieldList)
   145  
   146  	case *ast.FuncType:
   147  		if n.Params != nil {
   148  			n.Params = Rewrite(v, n.Params).(*ast.FieldList)
   149  		}
   150  		if n.Results != nil {
   151  			n.Results = Rewrite(v, n.Results).(*ast.FieldList)
   152  		}
   153  
   154  	case *ast.InterfaceType:
   155  		n.Methods = Rewrite(v, n.Methods).(*ast.FieldList)
   156  
   157  	case *ast.MapType:
   158  		n.Key = Rewrite(v, n.Key).(ast.Expr)
   159  		n.Value = Rewrite(v, n.Value).(ast.Expr)
   160  
   161  	case *ast.ChanType:
   162  		n.Value = Rewrite(v, n.Value).(ast.Expr)
   163  
   164  	// Statements
   165  	case *ast.BadStmt:
   166  		// nothing to do
   167  
   168  	case *ast.DeclStmt:
   169  		n.Decl = Rewrite(v, n.Decl).(ast.Decl)
   170  
   171  	case *ast.EmptyStmt:
   172  		// nothing to do
   173  
   174  	case *ast.LabeledStmt:
   175  		n.Label = Rewrite(v, n.Label).(*ast.Ident)
   176  		n.Stmt = Rewrite(v, n.Stmt).(ast.Stmt)
   177  
   178  	case *ast.ExprStmt:
   179  		n.X = Rewrite(v, n.X).(ast.Expr)
   180  
   181  	case *ast.SendStmt:
   182  		n.Chan = Rewrite(v, n.Chan).(ast.Expr)
   183  		n.Value = Rewrite(v, n.Value).(ast.Expr)
   184  
   185  	case *ast.IncDecStmt:
   186  		n.X = Rewrite(v, n.X).(ast.Expr)
   187  
   188  	case *ast.AssignStmt:
   189  		rewriteExprList(v, n.Lhs)
   190  		rewriteExprList(v, n.Rhs)
   191  
   192  	case *ast.GoStmt:
   193  		n.Call = Rewrite(v, n.Call).(*ast.CallExpr)
   194  
   195  	case *ast.DeferStmt:
   196  		n.Call = Rewrite(v, n.Call).(*ast.CallExpr)
   197  
   198  	case *ast.ReturnStmt:
   199  		rewriteExprList(v, n.Results)
   200  
   201  	case *ast.BranchStmt:
   202  		if n.Label != nil {
   203  			n.Label = Rewrite(v, n.Label).(*ast.Ident)
   204  		}
   205  
   206  	case *ast.BlockStmt:
   207  		rewriteStmtList(v, n.List)
   208  
   209  	case *ast.IfStmt:
   210  		if n.Init != nil {
   211  			n.Init = Rewrite(v, n.Init).(ast.Stmt)
   212  		}
   213  		n.Cond = Rewrite(v, n.Cond).(ast.Expr)
   214  		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)
   215  		if n.Else != nil {
   216  			n.Else = Rewrite(v, n.Else).(ast.Stmt)
   217  		}
   218  
   219  	case *ast.CaseClause:
   220  		rewriteExprList(v, n.List)
   221  		rewriteStmtList(v, n.Body)
   222  
   223  	case *ast.SwitchStmt:
   224  		if n.Init != nil {
   225  			n.Init = Rewrite(v, n.Init).(ast.Stmt)
   226  		}
   227  		if n.Tag != nil {
   228  			n.Tag = Rewrite(v, n.Tag).(ast.Expr)
   229  		}
   230  		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)
   231  
   232  	case *ast.TypeSwitchStmt:
   233  		if n.Init != nil {
   234  			n.Init = Rewrite(v, n.Init).(ast.Stmt)
   235  		}
   236  		n.Assign = Rewrite(v, n.Assign).(ast.Stmt)
   237  		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)
   238  
   239  	case *ast.CommClause:
   240  		if n.Comm != nil {
   241  			n.Comm = Rewrite(v, n.Comm).(ast.Stmt)
   242  		}
   243  		rewriteStmtList(v, n.Body)
   244  
   245  	case *ast.SelectStmt:
   246  		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)
   247  
   248  	case *ast.ForStmt:
   249  		if n.Init != nil {
   250  			n.Init = Rewrite(v, n.Init).(ast.Stmt)
   251  		}
   252  		if n.Cond != nil {
   253  			n.Cond = Rewrite(v, n.Cond).(ast.Expr)
   254  		}
   255  		if n.Post != nil {
   256  			n.Post = Rewrite(v, n.Post).(ast.Stmt)
   257  		}
   258  		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)
   259  
   260  	case *ast.RangeStmt:
   261  		if n.Key != nil {
   262  			n.Key = Rewrite(v, n.Key).(ast.Expr)
   263  		}
   264  		if n.Value != nil {
   265  			n.Value = Rewrite(v, n.Value).(ast.Expr)
   266  		}
   267  		n.X = Rewrite(v, n.X).(ast.Expr)
   268  		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)
   269  
   270  	// Declarations
   271  	case *ast.ImportSpec:
   272  		if n.Doc != nil {
   273  			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
   274  		}
   275  		if n.Name != nil {
   276  			n.Name = Rewrite(v, n.Name).(*ast.Ident)
   277  		}
   278  		n.Path = Rewrite(v, n.Path).(*ast.BasicLit)
   279  		if n.Comment != nil {
   280  			n.Comment = Rewrite(v, n.Comment).(*ast.CommentGroup)
   281  		}
   282  
   283  	case *ast.ValueSpec:
   284  		if n.Doc != nil {
   285  			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
   286  		}
   287  		rewriteIdentList(v, n.Names)
   288  		if n.Type != nil {
   289  			n.Type = Rewrite(v, n.Type).(ast.Expr)
   290  		}
   291  		rewriteExprList(v, n.Values)
   292  		if n.Comment != nil {
   293  			n.Comment = Rewrite(v, n.Comment).(*ast.CommentGroup)
   294  		}
   295  
   296  	case *ast.TypeSpec:
   297  		if n.Doc != nil {
   298  			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
   299  		}
   300  		n.Name = Rewrite(v, n.Name).(*ast.Ident)
   301  		n.Type = Rewrite(v, n.Type).(ast.Expr)
   302  		if n.Comment != nil {
   303  			n.Comment = Rewrite(v, n.Comment).(*ast.CommentGroup)
   304  		}
   305  
   306  	case *ast.BadDecl:
   307  		// nothing to do
   308  
   309  	case *ast.GenDecl:
   310  		if n.Doc != nil {
   311  			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
   312  		}
   313  		for i, s := range n.Specs {
   314  			n.Specs[i] = Rewrite(v, s).(ast.Spec)
   315  		}
   316  
   317  	case *ast.FuncDecl:
   318  		if n.Doc != nil {
   319  			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
   320  		}
   321  		if n.Recv != nil {
   322  			n.Recv = Rewrite(v, n.Recv).(*ast.FieldList)
   323  		}
   324  		n.Name = Rewrite(v, n.Name).(*ast.Ident)
   325  		n.Type = Rewrite(v, n.Type).(*ast.FuncType)
   326  		if n.Body != nil {
   327  			n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)
   328  		}
   329  
   330  	// Files and packages
   331  	case *ast.File:
   332  		if n.Doc != nil {
   333  			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
   334  		}
   335  		n.Name = Rewrite(v, n.Name).(*ast.Ident)
   336  		rewriteDeclList(v, n.Decls)
   337  		// don't rewrite n.Comments - they have been
   338  		// visited already through the individual
   339  		// nodes
   340  
   341  	case *ast.Package:
   342  		for i, f := range n.Files {
   343  			n.Files[i] = Rewrite(v, f).(*ast.File)
   344  		}
   345  
   346  	default:
   347  		panic(fmt.Sprintf("rewrite: unexpected node type %T", n))
   348  	}
   349  
   350  	return node
   351  }