github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/go/ast/astutil/util.go (about) 1 package astutil 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/token" 7 "reflect" 8 "strings" 9 10 "golang.org/x/tools/go/ast/astutil" 11 ) 12 13 func IsIdent(expr ast.Expr, ident string) bool { 14 id, ok := expr.(*ast.Ident) 15 return ok && id.Name == ident 16 } 17 18 // isBlank returns whether id is the blank identifier "_". 19 // If id == nil, the answer is false. 20 func IsBlank(id ast.Expr) bool { 21 ident, _ := id.(*ast.Ident) 22 return ident != nil && ident.Name == "_" 23 } 24 25 // Deprecated: use code.IsIntegerLiteral instead. 26 func IsIntLiteral(expr ast.Expr, literal string) bool { 27 lit, ok := expr.(*ast.BasicLit) 28 return ok && lit.Kind == token.INT && lit.Value == literal 29 } 30 31 // Deprecated: use IsIntLiteral instead 32 func IsZero(expr ast.Expr) bool { 33 return IsIntLiteral(expr, "0") 34 } 35 36 func Preamble(f *ast.File) string { 37 cutoff := f.Package 38 if f.Doc != nil { 39 cutoff = f.Doc.Pos() 40 } 41 var out []string 42 for _, cmt := range f.Comments { 43 if cmt.Pos() >= cutoff { 44 break 45 } 46 out = append(out, cmt.Text()) 47 } 48 return strings.Join(out, "\n") 49 } 50 51 func GroupSpecs(fset *token.FileSet, specs []ast.Spec) [][]ast.Spec { 52 if len(specs) == 0 { 53 return nil 54 } 55 groups := make([][]ast.Spec, 1) 56 groups[0] = append(groups[0], specs[0]) 57 58 for _, spec := range specs[1:] { 59 g := groups[len(groups)-1] 60 if fset.PositionFor(spec.Pos(), false).Line-1 != 61 fset.PositionFor(g[len(g)-1].End(), false).Line { 62 63 groups = append(groups, nil) 64 } 65 66 groups[len(groups)-1] = append(groups[len(groups)-1], spec) 67 } 68 69 return groups 70 } 71 72 // Unparen returns e with any enclosing parentheses stripped. 73 func Unparen(e ast.Expr) ast.Expr { 74 for { 75 p, ok := e.(*ast.ParenExpr) 76 if !ok { 77 return e 78 } 79 e = p.X 80 } 81 } 82 83 // CopyExpr creates a deep copy of an expression. 84 // It doesn't support copying FuncLits and returns ok == false when encountering one. 85 func CopyExpr(node ast.Expr) (ast.Expr, bool) { 86 switch node := node.(type) { 87 case *ast.BasicLit: 88 cp := *node 89 return &cp, true 90 case *ast.BinaryExpr: 91 cp := *node 92 var ok1, ok2 bool 93 cp.X, ok1 = CopyExpr(cp.X) 94 cp.Y, ok2 = CopyExpr(cp.Y) 95 return &cp, ok1 && ok2 96 case *ast.CallExpr: 97 var ok bool 98 cp := *node 99 cp.Fun, ok = CopyExpr(cp.Fun) 100 if !ok { 101 return nil, false 102 } 103 cp.Args = make([]ast.Expr, len(node.Args)) 104 for i, v := range node.Args { 105 cp.Args[i], ok = CopyExpr(v) 106 if !ok { 107 return nil, false 108 } 109 } 110 return &cp, true 111 case *ast.CompositeLit: 112 var ok bool 113 cp := *node 114 cp.Type, ok = CopyExpr(cp.Type) 115 if !ok { 116 return nil, false 117 } 118 cp.Elts = make([]ast.Expr, len(node.Elts)) 119 for i, v := range node.Elts { 120 cp.Elts[i], ok = CopyExpr(v) 121 if !ok { 122 return nil, false 123 } 124 } 125 return &cp, true 126 case *ast.Ident: 127 cp := *node 128 return &cp, true 129 case *ast.IndexExpr: 130 var ok1, ok2 bool 131 cp := *node 132 cp.X, ok1 = CopyExpr(cp.X) 133 cp.Index, ok2 = CopyExpr(cp.Index) 134 return &cp, ok1 && ok2 135 case *ast.IndexListExpr: 136 var ok bool 137 cp := *node 138 cp.X, ok = CopyExpr(cp.X) 139 if !ok { 140 return nil, false 141 } 142 for i, v := range node.Indices { 143 cp.Indices[i], ok = CopyExpr(v) 144 if !ok { 145 return nil, false 146 } 147 } 148 return &cp, true 149 case *ast.KeyValueExpr: 150 var ok1, ok2 bool 151 cp := *node 152 cp.Key, ok1 = CopyExpr(cp.Key) 153 cp.Value, ok2 = CopyExpr(cp.Value) 154 return &cp, ok1 && ok2 155 case *ast.ParenExpr: 156 var ok bool 157 cp := *node 158 cp.X, ok = CopyExpr(cp.X) 159 return &cp, ok 160 case *ast.SelectorExpr: 161 var ok bool 162 cp := *node 163 cp.X, ok = CopyExpr(cp.X) 164 if !ok { 165 return nil, false 166 } 167 sel, ok := CopyExpr(cp.Sel) 168 if !ok { 169 // this is impossible 170 return nil, false 171 } 172 cp.Sel = sel.(*ast.Ident) 173 return &cp, true 174 case *ast.SliceExpr: 175 var ok1, ok2, ok3, ok4 bool 176 cp := *node 177 cp.X, ok1 = CopyExpr(cp.X) 178 cp.Low, ok2 = CopyExpr(cp.Low) 179 cp.High, ok3 = CopyExpr(cp.High) 180 cp.Max, ok4 = CopyExpr(cp.Max) 181 return &cp, ok1 && ok2 && ok3 && ok4 182 case *ast.StarExpr: 183 var ok bool 184 cp := *node 185 cp.X, ok = CopyExpr(cp.X) 186 return &cp, ok 187 case *ast.TypeAssertExpr: 188 var ok1, ok2 bool 189 cp := *node 190 cp.X, ok1 = CopyExpr(cp.X) 191 cp.Type, ok2 = CopyExpr(cp.Type) 192 return &cp, ok1 && ok2 193 case *ast.UnaryExpr: 194 var ok bool 195 cp := *node 196 cp.X, ok = CopyExpr(cp.X) 197 return &cp, ok 198 case *ast.MapType: 199 var ok1, ok2 bool 200 cp := *node 201 cp.Key, ok1 = CopyExpr(cp.Key) 202 cp.Value, ok2 = CopyExpr(cp.Value) 203 return &cp, ok1 && ok2 204 case *ast.ArrayType: 205 var ok1, ok2 bool 206 cp := *node 207 cp.Len, ok1 = CopyExpr(cp.Len) 208 cp.Elt, ok2 = CopyExpr(cp.Elt) 209 return &cp, ok1 && ok2 210 case *ast.Ellipsis: 211 var ok bool 212 cp := *node 213 cp.Elt, ok = CopyExpr(cp.Elt) 214 return &cp, ok 215 case *ast.InterfaceType: 216 cp := *node 217 return &cp, true 218 case *ast.StructType: 219 cp := *node 220 return &cp, true 221 case *ast.FuncLit, *ast.FuncType: 222 // TODO(dh): implement copying of function literals and types. 223 return nil, false 224 case *ast.ChanType: 225 var ok bool 226 cp := *node 227 cp.Value, ok = CopyExpr(cp.Value) 228 return &cp, ok 229 case nil: 230 return nil, true 231 default: 232 panic(fmt.Sprintf("unreachable: %T", node)) 233 } 234 } 235 236 func Equal(a, b ast.Node) bool { 237 if a == b { 238 return true 239 } 240 if a == nil || b == nil { 241 return false 242 } 243 if reflect.TypeOf(a) != reflect.TypeOf(b) { 244 return false 245 } 246 247 switch a := a.(type) { 248 case *ast.BasicLit: 249 b := b.(*ast.BasicLit) 250 return a.Kind == b.Kind && a.Value == b.Value 251 case *ast.BinaryExpr: 252 b := b.(*ast.BinaryExpr) 253 return Equal(a.X, b.X) && a.Op == b.Op && Equal(a.Y, b.Y) 254 case *ast.CallExpr: 255 b := b.(*ast.CallExpr) 256 if len(a.Args) != len(b.Args) { 257 return false 258 } 259 for i, arg := range a.Args { 260 if !Equal(arg, b.Args[i]) { 261 return false 262 } 263 } 264 return Equal(a.Fun, b.Fun) && 265 (a.Ellipsis == token.NoPos && b.Ellipsis == token.NoPos || a.Ellipsis != token.NoPos && b.Ellipsis != token.NoPos) 266 case *ast.CompositeLit: 267 b := b.(*ast.CompositeLit) 268 if len(a.Elts) != len(b.Elts) { 269 return false 270 } 271 for i, elt := range b.Elts { 272 if !Equal(elt, b.Elts[i]) { 273 return false 274 } 275 } 276 return Equal(a.Type, b.Type) && a.Incomplete == b.Incomplete 277 case *ast.Ident: 278 b := b.(*ast.Ident) 279 return a.Name == b.Name 280 case *ast.IndexExpr: 281 b := b.(*ast.IndexExpr) 282 return Equal(a.X, b.X) && Equal(a.Index, b.Index) 283 case *ast.IndexListExpr: 284 b := b.(*ast.IndexListExpr) 285 if len(a.Indices) != len(b.Indices) { 286 return false 287 } 288 for i, v := range a.Indices { 289 if !Equal(v, b.Indices[i]) { 290 return false 291 } 292 } 293 return Equal(a.X, b.X) 294 case *ast.KeyValueExpr: 295 b := b.(*ast.KeyValueExpr) 296 return Equal(a.Key, b.Key) && Equal(a.Value, b.Value) 297 case *ast.ParenExpr: 298 b := b.(*ast.ParenExpr) 299 return Equal(a.X, b.X) 300 case *ast.SelectorExpr: 301 b := b.(*ast.SelectorExpr) 302 return Equal(a.X, b.X) && Equal(a.Sel, b.Sel) 303 case *ast.SliceExpr: 304 b := b.(*ast.SliceExpr) 305 return Equal(a.X, b.X) && Equal(a.Low, b.Low) && Equal(a.High, b.High) && Equal(a.Max, b.Max) && a.Slice3 == b.Slice3 306 case *ast.StarExpr: 307 b := b.(*ast.StarExpr) 308 return Equal(a.X, b.X) 309 case *ast.TypeAssertExpr: 310 b := b.(*ast.TypeAssertExpr) 311 return Equal(a.X, b.X) && Equal(a.Type, b.Type) 312 case *ast.UnaryExpr: 313 b := b.(*ast.UnaryExpr) 314 return a.Op == b.Op && Equal(a.X, b.X) 315 case *ast.MapType: 316 b := b.(*ast.MapType) 317 return Equal(a.Key, b.Key) && Equal(a.Value, b.Value) 318 case *ast.ArrayType: 319 b := b.(*ast.ArrayType) 320 return Equal(a.Len, b.Len) && Equal(a.Elt, b.Elt) 321 case *ast.Ellipsis: 322 b := b.(*ast.Ellipsis) 323 return Equal(a.Elt, b.Elt) 324 case *ast.InterfaceType: 325 b := b.(*ast.InterfaceType) 326 return a.Incomplete == b.Incomplete && Equal(a.Methods, b.Methods) 327 case *ast.StructType: 328 b := b.(*ast.StructType) 329 return a.Incomplete == b.Incomplete && Equal(a.Fields, b.Fields) 330 case *ast.FuncLit: 331 // TODO(dh): support function literals 332 return false 333 case *ast.ChanType: 334 b := b.(*ast.ChanType) 335 return a.Dir == b.Dir && (a.Arrow == token.NoPos && b.Arrow == token.NoPos || a.Arrow != token.NoPos && b.Arrow != token.NoPos) 336 case *ast.FieldList: 337 b := b.(*ast.FieldList) 338 if len(a.List) != len(b.List) { 339 return false 340 } 341 for i, fieldA := range a.List { 342 if !Equal(fieldA, b.List[i]) { 343 return false 344 } 345 } 346 return true 347 case *ast.Field: 348 b := b.(*ast.Field) 349 if len(a.Names) != len(b.Names) { 350 return false 351 } 352 for j, name := range a.Names { 353 if !Equal(name, b.Names[j]) { 354 return false 355 } 356 } 357 if !Equal(a.Type, b.Type) || !Equal(a.Tag, b.Tag) { 358 return false 359 } 360 return true 361 default: 362 panic(fmt.Sprintf("unreachable: %T", a)) 363 } 364 } 365 366 func NegateDeMorgan(expr ast.Expr, recursive bool) ast.Expr { 367 switch expr := expr.(type) { 368 case *ast.BinaryExpr: 369 var out ast.BinaryExpr 370 switch expr.Op { 371 case token.EQL: 372 out.X = expr.X 373 out.Op = token.NEQ 374 out.Y = expr.Y 375 case token.LSS: 376 out.X = expr.X 377 out.Op = token.GEQ 378 out.Y = expr.Y 379 case token.GTR: 380 out.X = expr.X 381 out.Op = token.LEQ 382 out.Y = expr.Y 383 case token.NEQ: 384 out.X = expr.X 385 out.Op = token.EQL 386 out.Y = expr.Y 387 case token.LEQ: 388 out.X = expr.X 389 out.Op = token.GTR 390 out.Y = expr.Y 391 case token.GEQ: 392 out.X = expr.X 393 out.Op = token.LSS 394 out.Y = expr.Y 395 396 case token.LAND: 397 out.X = NegateDeMorgan(expr.X, recursive) 398 out.Op = token.LOR 399 out.Y = NegateDeMorgan(expr.Y, recursive) 400 case token.LOR: 401 out.X = NegateDeMorgan(expr.X, recursive) 402 out.Op = token.LAND 403 out.Y = NegateDeMorgan(expr.Y, recursive) 404 } 405 return &out 406 407 case *ast.ParenExpr: 408 if recursive { 409 return &ast.ParenExpr{ 410 X: NegateDeMorgan(expr.X, recursive), 411 } 412 } else { 413 return &ast.UnaryExpr{ 414 Op: token.NOT, 415 X: expr, 416 } 417 } 418 419 case *ast.UnaryExpr: 420 if expr.Op == token.NOT { 421 return expr.X 422 } else { 423 return &ast.UnaryExpr{ 424 Op: token.NOT, 425 X: expr, 426 } 427 } 428 429 default: 430 return &ast.UnaryExpr{ 431 Op: token.NOT, 432 X: expr, 433 } 434 } 435 } 436 437 func SimplifyParentheses(node ast.Expr) ast.Expr { 438 var changed bool 439 // XXX accept list of ops to operate on 440 // XXX copy AST node, don't modify in place 441 post := func(c *astutil.Cursor) bool { 442 out := c.Node() 443 if paren, ok := c.Node().(*ast.ParenExpr); ok { 444 out = paren.X 445 } 446 447 if binop, ok := out.(*ast.BinaryExpr); ok { 448 if right, ok := binop.Y.(*ast.BinaryExpr); ok && binop.Op == right.Op { 449 // XXX also check that Op is associative 450 451 root := binop 452 pivot := root.Y.(*ast.BinaryExpr) 453 root.Y = pivot.X 454 pivot.X = root 455 root = pivot 456 out = root 457 } 458 } 459 460 if out != c.Node() { 461 changed = true 462 c.Replace(out) 463 } 464 return true 465 } 466 467 for changed = true; changed; { 468 changed = false 469 node = astutil.Apply(node, nil, post).(ast.Expr) 470 } 471 472 return node 473 }