// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
	"fmt"
	"go/ast"
)

// rewriteIdentList replaces each element of list, in place, with the
// result of rewriting it. The rewritten node must still be an *ast.Ident.
func rewriteIdentList(v func(ast.Node) ast.Node, list []*ast.Ident) {
	for i, x := range list {
		list[i] = Rewrite(v, x).(*ast.Ident)
	}
}

// rewriteExprList replaces each element of list, in place, with the
// result of rewriting it. The rewritten node must still be an ast.Expr.
func rewriteExprList(v func(ast.Node) ast.Node, list []ast.Expr) {
	for i, x := range list {
		list[i] = Rewrite(v, x).(ast.Expr)
	}
}

// rewriteStmtList replaces each element of list, in place, with the
// result of rewriting it. The rewritten node must still be an ast.Stmt.
func rewriteStmtList(v func(ast.Node) ast.Node, list []ast.Stmt) {
	for i, x := range list {
		list[i] = Rewrite(v, x).(ast.Stmt)
	}
}

// rewriteDeclList replaces each element of list, in place, with the
// result of rewriting it. The rewritten node must still be an ast.Decl.
func rewriteDeclList(v func(ast.Node) ast.Node, list []ast.Decl) {
	for i, x := range list {
		list[i] = Rewrite(v, x).(ast.Decl)
	}
}

// Rewrite applies v to node and then recursively rewrites the children of
// the node that v returned. Each child field is replaced in place with the
// result of rewriting it. Rewrite returns the node produced by v.
//
// v is applied in pre-order: a node is passed to v before any of its
// children. If v returns a replacement node, the replacement's children
// are the ones traversed, not the original's.
//
// v must return a non-nil node whose dynamic type fits the parent field it
// is stored into. For example, it must return an *ast.Ident where the
// parent field is an *ast.Ident. Otherwise Rewrite panics on the failed
// type assertion. Rewrite also panics on a node type it does not know how
// to traverse, including a nil node.
//
// Child fields guarded by a nil check below are skipped when nil and are
// never passed to v. All other child fields must be non-nil.
func Rewrite(v func(ast.Node) ast.Node, node ast.Node) ast.Node {
	node = v(node)

	// Rewrite children.
	// (The order of the cases matches the order of the
	// corresponding node types in go/ast's ast.go.)
	switch n := node.(type) {
	// Comments and fields
	case *ast.Comment:
		// nothing to do

	case *ast.CommentGroup:
		for i, c := range n.List {
			n.List[i] = Rewrite(v, c).(*ast.Comment)
		}

	case *ast.Field:
		if n.Doc != nil {
			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
		}
		rewriteIdentList(v, n.Names)
		// Defensive: tolerate a nil Type instead of handing nil to v
		// and panicking in the default case.
		if n.Type != nil {
			n.Type = Rewrite(v, n.Type).(ast.Expr)
		}
		if n.Tag != nil {
			n.Tag = Rewrite(v, n.Tag).(*ast.BasicLit)
		}
		if n.Comment != nil {
			n.Comment = Rewrite(v, n.Comment).(*ast.CommentGroup)
		}

	case *ast.FieldList:
		for i, f := range n.List {
			n.List[i] = Rewrite(v, f).(*ast.Field)
		}

	// Expressions
	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
		// nothing to do

	case *ast.Ellipsis:
		if n.Elt != nil {
			n.Elt = Rewrite(v, n.Elt).(ast.Expr)
		}

	case *ast.FuncLit:
		n.Type = Rewrite(v, n.Type).(*ast.FuncType)
		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)

	case *ast.CompositeLit:
		if n.Type != nil {
			n.Type = Rewrite(v, n.Type).(ast.Expr)
		}
		rewriteExprList(v, n.Elts)

	case *ast.ParenExpr:
		n.X = Rewrite(v, n.X).(ast.Expr)

	case *ast.SelectorExpr:
		n.X = Rewrite(v, n.X).(ast.Expr)
		n.Sel = Rewrite(v, n.Sel).(*ast.Ident)

	case *ast.IndexExpr:
		n.X = Rewrite(v, n.X).(ast.Expr)
		n.Index = Rewrite(v, n.Index).(ast.Expr)

	case *ast.IndexListExpr:
		// Generic instantiation with two or more type arguments,
		// e.g. Pair[int, string]. Without this case such code
		// reached the default branch and panicked.
		n.X = Rewrite(v, n.X).(ast.Expr)
		rewriteExprList(v, n.Indices)

	case *ast.SliceExpr:
		n.X = Rewrite(v, n.X).(ast.Expr)
		if n.Low != nil {
			n.Low = Rewrite(v, n.Low).(ast.Expr)
		}
		if n.High != nil {
			n.High = Rewrite(v, n.High).(ast.Expr)
		}
		if n.Max != nil {
			n.Max = Rewrite(v, n.Max).(ast.Expr)
		}

	case *ast.TypeAssertExpr:
		n.X = Rewrite(v, n.X).(ast.Expr)
		// Type is nil for the x.(type) form used in type switches.
		if n.Type != nil {
			n.Type = Rewrite(v, n.Type).(ast.Expr)
		}

	case *ast.CallExpr:
		n.Fun = Rewrite(v, n.Fun).(ast.Expr)
		rewriteExprList(v, n.Args)

	case *ast.StarExpr:
		n.X = Rewrite(v, n.X).(ast.Expr)

	case *ast.UnaryExpr:
		n.X = Rewrite(v, n.X).(ast.Expr)

	case *ast.BinaryExpr:
		n.X = Rewrite(v, n.X).(ast.Expr)
		n.Y = Rewrite(v, n.Y).(ast.Expr)

	case *ast.KeyValueExpr:
		n.Key = Rewrite(v, n.Key).(ast.Expr)
		n.Value = Rewrite(v, n.Value).(ast.Expr)

	// Types
	case *ast.ArrayType:
		if n.Len != nil {
			n.Len = Rewrite(v, n.Len).(ast.Expr)
		}
		n.Elt = Rewrite(v, n.Elt).(ast.Expr)

	case *ast.StructType:
		n.Fields = Rewrite(v, n.Fields).(*ast.FieldList)

	case *ast.FuncType:
		// TypeParams is nil for non-generic functions.
		if n.TypeParams != nil {
			n.TypeParams = Rewrite(v, n.TypeParams).(*ast.FieldList)
		}
		if n.Params != nil {
			n.Params = Rewrite(v, n.Params).(*ast.FieldList)
		}
		if n.Results != nil {
			n.Results = Rewrite(v, n.Results).(*ast.FieldList)
		}

	case *ast.InterfaceType:
		n.Methods = Rewrite(v, n.Methods).(*ast.FieldList)

	case *ast.MapType:
		n.Key = Rewrite(v, n.Key).(ast.Expr)
		n.Value = Rewrite(v, n.Value).(ast.Expr)

	case *ast.ChanType:
		n.Value = Rewrite(v, n.Value).(ast.Expr)

	// Statements
	case *ast.BadStmt:
		// nothing to do

	case *ast.DeclStmt:
		n.Decl = Rewrite(v, n.Decl).(ast.Decl)

	case *ast.EmptyStmt:
		// nothing to do

	case *ast.LabeledStmt:
		n.Label = Rewrite(v, n.Label).(*ast.Ident)
		n.Stmt = Rewrite(v, n.Stmt).(ast.Stmt)

	case *ast.ExprStmt:
		n.X = Rewrite(v, n.X).(ast.Expr)

	case *ast.SendStmt:
		n.Chan = Rewrite(v, n.Chan).(ast.Expr)
		n.Value = Rewrite(v, n.Value).(ast.Expr)

	case *ast.IncDecStmt:
		n.X = Rewrite(v, n.X).(ast.Expr)

	case *ast.AssignStmt:
		rewriteExprList(v, n.Lhs)
		rewriteExprList(v, n.Rhs)

	case *ast.GoStmt:
		n.Call = Rewrite(v, n.Call).(*ast.CallExpr)

	case *ast.DeferStmt:
		n.Call = Rewrite(v, n.Call).(*ast.CallExpr)

	case *ast.ReturnStmt:
		rewriteExprList(v, n.Results)

	case *ast.BranchStmt:
		if n.Label != nil {
			n.Label = Rewrite(v, n.Label).(*ast.Ident)
		}

	case *ast.BlockStmt:
		rewriteStmtList(v, n.List)

	case *ast.IfStmt:
		if n.Init != nil {
			n.Init = Rewrite(v, n.Init).(ast.Stmt)
		}
		n.Cond = Rewrite(v, n.Cond).(ast.Expr)
		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)
		if n.Else != nil {
			n.Else = Rewrite(v, n.Else).(ast.Stmt)
		}

	case *ast.CaseClause:
		rewriteExprList(v, n.List)
		rewriteStmtList(v, n.Body)

	case *ast.SwitchStmt:
		if n.Init != nil {
			n.Init = Rewrite(v, n.Init).(ast.Stmt)
		}
		if n.Tag != nil {
			n.Tag = Rewrite(v, n.Tag).(ast.Expr)
		}
		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)

	case *ast.TypeSwitchStmt:
		if n.Init != nil {
			n.Init = Rewrite(v, n.Init).(ast.Stmt)
		}
		n.Assign = Rewrite(v, n.Assign).(ast.Stmt)
		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)

	case *ast.CommClause:
		if n.Comm != nil {
			n.Comm = Rewrite(v, n.Comm).(ast.Stmt)
		}
		rewriteStmtList(v, n.Body)

	case *ast.SelectStmt:
		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)

	case *ast.ForStmt:
		if n.Init != nil {
			n.Init = Rewrite(v, n.Init).(ast.Stmt)
		}
		if n.Cond != nil {
			n.Cond = Rewrite(v, n.Cond).(ast.Expr)
		}
		if n.Post != nil {
			n.Post = Rewrite(v, n.Post).(ast.Stmt)
		}
		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)

	case *ast.RangeStmt:
		if n.Key != nil {
			n.Key = Rewrite(v, n.Key).(ast.Expr)
		}
		if n.Value != nil {
			n.Value = Rewrite(v, n.Value).(ast.Expr)
		}
		n.X = Rewrite(v, n.X).(ast.Expr)
		n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)

	// Declarations
	case *ast.ImportSpec:
		if n.Doc != nil {
			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
		}
		if n.Name != nil {
			n.Name = Rewrite(v, n.Name).(*ast.Ident)
		}
		n.Path = Rewrite(v, n.Path).(*ast.BasicLit)
		if n.Comment != nil {
			n.Comment = Rewrite(v, n.Comment).(*ast.CommentGroup)
		}

	case *ast.ValueSpec:
		if n.Doc != nil {
			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
		}
		rewriteIdentList(v, n.Names)
		if n.Type != nil {
			n.Type = Rewrite(v, n.Type).(ast.Expr)
		}
		rewriteExprList(v, n.Values)
		if n.Comment != nil {
			n.Comment = Rewrite(v, n.Comment).(*ast.CommentGroup)
		}

	case *ast.TypeSpec:
		if n.Doc != nil {
			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
		}
		n.Name = Rewrite(v, n.Name).(*ast.Ident)
		// TypeParams is nil for non-generic type declarations.
		if n.TypeParams != nil {
			n.TypeParams = Rewrite(v, n.TypeParams).(*ast.FieldList)
		}
		n.Type = Rewrite(v, n.Type).(ast.Expr)
		if n.Comment != nil {
			n.Comment = Rewrite(v, n.Comment).(*ast.CommentGroup)
		}

	case *ast.BadDecl:
		// nothing to do

	case *ast.GenDecl:
		if n.Doc != nil {
			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
		}
		for i, s := range n.Specs {
			n.Specs[i] = Rewrite(v, s).(ast.Spec)
		}

	case *ast.FuncDecl:
		if n.Doc != nil {
			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
		}
		if n.Recv != nil {
			n.Recv = Rewrite(v, n.Recv).(*ast.FieldList)
		}
		n.Name = Rewrite(v, n.Name).(*ast.Ident)
		n.Type = Rewrite(v, n.Type).(*ast.FuncType)
		if n.Body != nil {
			n.Body = Rewrite(v, n.Body).(*ast.BlockStmt)
		}

	// Files and packages
	case *ast.File:
		if n.Doc != nil {
			n.Doc = Rewrite(v, n.Doc).(*ast.CommentGroup)
		}
		n.Name = Rewrite(v, n.Name).(*ast.Ident)
		rewriteDeclList(v, n.Decls)
		// don't rewrite n.Comments - they have been
		// visited already through the individual
		// nodes

	case *ast.Package:
		// Assigning to keys already present is safe during range.
		for i, f := range n.Files {
			n.Files[i] = Rewrite(v, f).(*ast.File)
		}

	default:
		panic(fmt.Sprintf("rewrite: unexpected node type %T", n))
	}

	return node
}