github.com/jhump/golang-x-tools@v0.0.0-20220218190644-4958d6d39439/internal/lsp/cache/parse.go (about) 1 // Copyright 2019 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package cache 6 7 import ( 8 "bytes" 9 "context" 10 "fmt" 11 "go/ast" 12 "go/parser" 13 "go/scanner" 14 "go/token" 15 "go/types" 16 "path/filepath" 17 "reflect" 18 "strconv" 19 "strings" 20 21 "github.com/jhump/golang-x-tools/internal/event" 22 "github.com/jhump/golang-x-tools/internal/lsp/debug/tag" 23 "github.com/jhump/golang-x-tools/internal/lsp/diff" 24 "github.com/jhump/golang-x-tools/internal/lsp/diff/myers" 25 "github.com/jhump/golang-x-tools/internal/lsp/protocol" 26 "github.com/jhump/golang-x-tools/internal/lsp/source" 27 "github.com/jhump/golang-x-tools/internal/memoize" 28 "github.com/jhump/golang-x-tools/internal/span" 29 errors "golang.org/x/xerrors" 30 ) 31 32 // parseKey uniquely identifies a parsed Go file. 33 type parseKey struct { 34 file source.FileIdentity 35 mode source.ParseMode 36 } 37 38 type parseGoHandle struct { 39 handle *memoize.Handle 40 file source.FileHandle 41 mode source.ParseMode 42 } 43 44 type parseGoData struct { 45 parsed *source.ParsedGoFile 46 47 // If true, we adjusted the AST to make it type check better, and 48 // it may not match the source code. 
49 fixed bool 50 err error // any other errors 51 } 52 53 func (s *snapshot) parseGoHandle(ctx context.Context, fh source.FileHandle, mode source.ParseMode) *parseGoHandle { 54 key := parseKey{ 55 file: fh.FileIdentity(), 56 mode: mode, 57 } 58 if pgh := s.getGoFile(key); pgh != nil { 59 return pgh 60 } 61 parseHandle := s.generation.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} { 62 snapshot := arg.(*snapshot) 63 return parseGo(ctx, snapshot.FileSet(), fh, mode) 64 }, nil) 65 66 pgh := &parseGoHandle{ 67 handle: parseHandle, 68 file: fh, 69 mode: mode, 70 } 71 return s.addGoFile(key, pgh) 72 } 73 74 func (pgh *parseGoHandle) String() string { 75 return pgh.File().URI().Filename() 76 } 77 78 func (pgh *parseGoHandle) File() source.FileHandle { 79 return pgh.file 80 } 81 82 func (pgh *parseGoHandle) Mode() source.ParseMode { 83 return pgh.mode 84 } 85 86 func (s *snapshot) ParseGo(ctx context.Context, fh source.FileHandle, mode source.ParseMode) (*source.ParsedGoFile, error) { 87 pgh := s.parseGoHandle(ctx, fh, mode) 88 pgf, _, err := s.parseGo(ctx, pgh) 89 return pgf, err 90 } 91 92 func (s *snapshot) parseGo(ctx context.Context, pgh *parseGoHandle) (*source.ParsedGoFile, bool, error) { 93 if pgh.mode == source.ParseExported { 94 panic("only type checking should use Exported") 95 } 96 d, err := pgh.handle.Get(ctx, s.generation, s) 97 if err != nil { 98 return nil, false, err 99 } 100 data := d.(*parseGoData) 101 return data.parsed, data.fixed, data.err 102 } 103 104 type astCacheKey struct { 105 pkg packageHandleKey 106 uri span.URI 107 } 108 109 func (s *snapshot) astCacheData(ctx context.Context, spkg source.Package, pos token.Pos) (*astCacheData, error) { 110 pkg := spkg.(*pkg) 111 pkgHandle := s.getPackage(pkg.m.ID, pkg.mode) 112 if pkgHandle == nil { 113 return nil, fmt.Errorf("could not reconstruct package handle for %v", pkg.m.ID) 114 } 115 tok := s.FileSet().File(pos) 116 if tok == nil { 117 return nil, fmt.Errorf("no file for pos %v", 
pos) 118 } 119 pgf, err := pkg.File(span.URIFromPath(tok.Name())) 120 if err != nil { 121 return nil, err 122 } 123 astHandle := s.generation.Bind(astCacheKey{pkgHandle.key, pgf.URI}, func(ctx context.Context, arg memoize.Arg) interface{} { 124 return buildASTCache(pgf) 125 }, nil) 126 127 d, err := astHandle.Get(ctx, s.generation, s) 128 if err != nil { 129 return nil, err 130 } 131 data := d.(*astCacheData) 132 if data.err != nil { 133 return nil, data.err 134 } 135 return data, nil 136 } 137 138 func (s *snapshot) PosToDecl(ctx context.Context, spkg source.Package, pos token.Pos) (ast.Decl, error) { 139 data, err := s.astCacheData(ctx, spkg, pos) 140 if err != nil { 141 return nil, err 142 } 143 return data.posToDecl[pos], nil 144 } 145 146 func (s *snapshot) PosToField(ctx context.Context, spkg source.Package, pos token.Pos) (*ast.Field, error) { 147 data, err := s.astCacheData(ctx, spkg, pos) 148 if err != nil { 149 return nil, err 150 } 151 return data.posToField[pos], nil 152 } 153 154 type astCacheData struct { 155 err error 156 157 posToDecl map[token.Pos]ast.Decl 158 posToField map[token.Pos]*ast.Field 159 } 160 161 // buildASTCache builds caches to aid in quickly going from the typed 162 // world to the syntactic world. 163 func buildASTCache(pgf *source.ParsedGoFile) *astCacheData { 164 var ( 165 // path contains all ancestors, including n. 166 path []ast.Node 167 // decls contains all ancestors that are decls. 
168 decls []ast.Decl 169 ) 170 171 data := &astCacheData{ 172 posToDecl: make(map[token.Pos]ast.Decl), 173 posToField: make(map[token.Pos]*ast.Field), 174 } 175 176 ast.Inspect(pgf.File, func(n ast.Node) bool { 177 if n == nil { 178 lastP := path[len(path)-1] 179 path = path[:len(path)-1] 180 if len(decls) > 0 && decls[len(decls)-1] == lastP { 181 decls = decls[:len(decls)-1] 182 } 183 return false 184 } 185 186 path = append(path, n) 187 188 switch n := n.(type) { 189 case *ast.Field: 190 addField := func(f ast.Node) { 191 if f.Pos().IsValid() { 192 data.posToField[f.Pos()] = n 193 if len(decls) > 0 { 194 data.posToDecl[f.Pos()] = decls[len(decls)-1] 195 } 196 } 197 } 198 199 // Add mapping for *ast.Field itself. This handles embedded 200 // fields which have no associated *ast.Ident name. 201 addField(n) 202 203 // Add mapping for each field name since you can have 204 // multiple names for the same type expression. 205 for _, name := range n.Names { 206 addField(name) 207 } 208 209 // Also map "X" in "...X" to the containing *ast.Field. This 210 // makes it easy to format variadic signature params 211 // properly. 
212 if elips, ok := n.Type.(*ast.Ellipsis); ok && elips.Elt != nil { 213 addField(elips.Elt) 214 } 215 case *ast.FuncDecl: 216 decls = append(decls, n) 217 218 if n.Name != nil && n.Name.Pos().IsValid() { 219 data.posToDecl[n.Name.Pos()] = n 220 } 221 case *ast.GenDecl: 222 decls = append(decls, n) 223 224 for _, spec := range n.Specs { 225 switch spec := spec.(type) { 226 case *ast.TypeSpec: 227 if spec.Name != nil && spec.Name.Pos().IsValid() { 228 data.posToDecl[spec.Name.Pos()] = n 229 } 230 case *ast.ValueSpec: 231 for _, id := range spec.Names { 232 if id != nil && id.Pos().IsValid() { 233 data.posToDecl[id.Pos()] = n 234 } 235 } 236 } 237 } 238 } 239 240 return true 241 }) 242 243 return data 244 } 245 246 func parseGo(ctx context.Context, fset *token.FileSet, fh source.FileHandle, mode source.ParseMode) *parseGoData { 247 ctx, done := event.Start(ctx, "cache.parseGo", tag.File.Of(fh.URI().Filename())) 248 defer done() 249 250 ext := filepath.Ext(fh.URI().Filename()) 251 if ext != ".go" && ext != "" { // files generated by cgo have no extension 252 return &parseGoData{err: errors.Errorf("cannot parse non-Go file %s", fh.URI())} 253 } 254 src, err := fh.Read() 255 if err != nil { 256 return &parseGoData{err: err} 257 } 258 259 parserMode := parser.AllErrors | parser.ParseComments 260 if mode == source.ParseHeader { 261 parserMode = parser.ImportsOnly | parser.ParseComments 262 } 263 264 file, err := parser.ParseFile(fset, fh.URI().Filename(), src, parserMode) 265 var parseErr scanner.ErrorList 266 if err != nil { 267 // We passed a byte slice, so the only possible error is a parse error. 268 parseErr = err.(scanner.ErrorList) 269 } 270 271 tok := fset.File(file.Pos()) 272 if tok == nil { 273 // file.Pos is the location of the package declaration. If there was 274 // none, we can't find the token.File that ParseFile created, and we 275 // have no choice but to recreate it. 
276 tok = fset.AddFile(fh.URI().Filename(), -1, len(src)) 277 tok.SetLinesForContent(src) 278 } 279 280 fixed := false 281 // If there were parse errors, attempt to fix them up. 282 if parseErr != nil { 283 // Fix any badly parsed parts of the AST. 284 fixed = fixAST(ctx, file, tok, src) 285 286 for i := 0; i < 10; i++ { 287 // Fix certain syntax errors that render the file unparseable. 288 newSrc := fixSrc(file, tok, src) 289 if newSrc == nil { 290 break 291 } 292 293 // If we thought there was something to fix 10 times in a row, 294 // it is likely we got stuck in a loop somehow. Log out a diff 295 // of the last changes we made to aid in debugging. 296 if i == 9 { 297 edits, err := myers.ComputeEdits(fh.URI(), string(src), string(newSrc)) 298 if err != nil { 299 event.Error(ctx, "error generating fixSrc diff", err, tag.File.Of(tok.Name())) 300 } else { 301 unified := diff.ToUnified("before", "after", string(src), edits) 302 event.Log(ctx, fmt.Sprintf("fixSrc loop - last diff:\n%v", unified), tag.File.Of(tok.Name())) 303 } 304 } 305 306 newFile, _ := parser.ParseFile(fset, fh.URI().Filename(), newSrc, parserMode) 307 if newFile != nil { 308 // Maintain the original parseError so we don't try formatting the doctored file. 309 file = newFile 310 src = newSrc 311 tok = fset.File(file.Pos()) 312 313 fixed = fixAST(ctx, file, tok, src) 314 } 315 } 316 } 317 318 return &parseGoData{ 319 parsed: &source.ParsedGoFile{ 320 URI: fh.URI(), 321 Mode: mode, 322 Src: src, 323 File: file, 324 Tok: tok, 325 Mapper: &protocol.ColumnMapper{ 326 URI: fh.URI(), 327 Converter: span.NewTokenConverter(fset, tok), 328 Content: src, 329 }, 330 ParseErr: parseErr, 331 }, 332 fixed: fixed, 333 } 334 } 335 336 // An unexportedFilter removes as much unexported AST from a set of Files as possible. 337 type unexportedFilter struct { 338 uses map[string]bool 339 } 340 341 // Filter records uses of unexported identifiers and filters out all other 342 // unexported declarations. 
343 func (f *unexportedFilter) Filter(files []*ast.File) { 344 // Iterate to fixed point -- unexported types can include other unexported types. 345 oldLen := len(f.uses) 346 for { 347 for _, file := range files { 348 f.recordUses(file) 349 } 350 if len(f.uses) == oldLen { 351 break 352 } 353 oldLen = len(f.uses) 354 } 355 356 for _, file := range files { 357 var newDecls []ast.Decl 358 for _, decl := range file.Decls { 359 if f.filterDecl(decl) { 360 newDecls = append(newDecls, decl) 361 } 362 } 363 file.Decls = newDecls 364 file.Scope = nil 365 file.Unresolved = nil 366 file.Comments = nil 367 trimAST(file) 368 } 369 } 370 371 func (f *unexportedFilter) keep(ident *ast.Ident) bool { 372 return ast.IsExported(ident.Name) || f.uses[ident.Name] 373 } 374 375 func (f *unexportedFilter) filterDecl(decl ast.Decl) bool { 376 switch decl := decl.(type) { 377 case *ast.FuncDecl: 378 if ident := recvIdent(decl); ident != nil && !f.keep(ident) { 379 return false 380 } 381 return f.keep(decl.Name) 382 case *ast.GenDecl: 383 if decl.Tok == token.CONST { 384 // Constants can involve iota, and iota is hard to deal with. 
385 return true 386 } 387 var newSpecs []ast.Spec 388 for _, spec := range decl.Specs { 389 if f.filterSpec(spec) { 390 newSpecs = append(newSpecs, spec) 391 } 392 } 393 decl.Specs = newSpecs 394 return len(newSpecs) != 0 395 case *ast.BadDecl: 396 return false 397 } 398 panic(fmt.Sprintf("unknown ast.Decl %T", decl)) 399 } 400 401 func (f *unexportedFilter) filterSpec(spec ast.Spec) bool { 402 switch spec := spec.(type) { 403 case *ast.ImportSpec: 404 return true 405 case *ast.ValueSpec: 406 var newNames []*ast.Ident 407 for _, name := range spec.Names { 408 if f.keep(name) { 409 newNames = append(newNames, name) 410 } 411 } 412 spec.Names = newNames 413 return len(spec.Names) != 0 414 case *ast.TypeSpec: 415 if !f.keep(spec.Name) { 416 return false 417 } 418 switch typ := spec.Type.(type) { 419 case *ast.StructType: 420 f.filterFieldList(typ.Fields) 421 case *ast.InterfaceType: 422 f.filterFieldList(typ.Methods) 423 } 424 return true 425 } 426 panic(fmt.Sprintf("unknown ast.Spec %T", spec)) 427 } 428 429 func (f *unexportedFilter) filterFieldList(fields *ast.FieldList) { 430 var newFields []*ast.Field 431 for _, field := range fields.List { 432 if len(field.Names) == 0 { 433 // Keep embedded fields: they can export methods and fields. 434 newFields = append(newFields, field) 435 } 436 for _, name := range field.Names { 437 if f.keep(name) { 438 newFields = append(newFields, field) 439 break 440 } 441 } 442 } 443 fields.List = newFields 444 } 445 446 func (f *unexportedFilter) recordUses(file *ast.File) { 447 for _, decl := range file.Decls { 448 switch decl := decl.(type) { 449 case *ast.FuncDecl: 450 // Ignore methods on dropped types. 451 if ident := recvIdent(decl); ident != nil && !f.keep(ident) { 452 break 453 } 454 // Ignore functions with dropped names. 
455 if !f.keep(decl.Name) { 456 break 457 } 458 f.recordFuncType(decl.Type) 459 case *ast.GenDecl: 460 for _, spec := range decl.Specs { 461 switch spec := spec.(type) { 462 case *ast.ValueSpec: 463 for i, name := range spec.Names { 464 // Don't mess with constants -- iota is hard. 465 if f.keep(name) || decl.Tok == token.CONST { 466 f.recordIdents(spec.Type) 467 if len(spec.Values) > i { 468 f.recordIdents(spec.Values[i]) 469 } 470 } 471 } 472 case *ast.TypeSpec: 473 switch typ := spec.Type.(type) { 474 case *ast.StructType: 475 f.recordFieldUses(false, typ.Fields) 476 case *ast.InterfaceType: 477 f.recordFieldUses(false, typ.Methods) 478 } 479 } 480 } 481 } 482 } 483 } 484 485 // recvIdent returns the identifier of a method receiver, e.g. *int. 486 func recvIdent(decl *ast.FuncDecl) *ast.Ident { 487 if decl.Recv == nil || len(decl.Recv.List) == 0 { 488 return nil 489 } 490 x := decl.Recv.List[0].Type 491 if star, ok := x.(*ast.StarExpr); ok { 492 x = star.X 493 } 494 if ident, ok := x.(*ast.Ident); ok { 495 return ident 496 } 497 return nil 498 } 499 500 // recordIdents records unexported identifiers in an Expr in uses. 501 // These may be types, e.g. in map[key]value, function names, e.g. in foo(), 502 // or simple variable references. References that will be discarded, such 503 // as those in function literal bodies, are ignored. 504 func (f *unexportedFilter) recordIdents(x ast.Expr) { 505 ast.Inspect(x, func(n ast.Node) bool { 506 if n == nil { 507 return false 508 } 509 if complit, ok := n.(*ast.CompositeLit); ok { 510 // We clear out composite literal contents; just record their type. 511 f.recordIdents(complit.Type) 512 return false 513 } 514 if flit, ok := n.(*ast.FuncLit); ok { 515 f.recordFuncType(flit.Type) 516 return false 517 } 518 if ident, ok := n.(*ast.Ident); ok && !ast.IsExported(ident.Name) { 519 f.uses[ident.Name] = true 520 } 521 return true 522 }) 523 } 524 525 // recordFuncType records the types mentioned by a function type. 
526 func (f *unexportedFilter) recordFuncType(x *ast.FuncType) { 527 f.recordFieldUses(true, x.Params) 528 f.recordFieldUses(true, x.Results) 529 } 530 531 // recordFieldUses records unexported identifiers used in fields, which may be 532 // struct members, interface members, or function parameter/results. 533 func (f *unexportedFilter) recordFieldUses(isParams bool, fields *ast.FieldList) { 534 if fields == nil { 535 return 536 } 537 for _, field := range fields.List { 538 if isParams { 539 // Parameter types of retained functions need to be retained. 540 f.recordIdents(field.Type) 541 continue 542 } 543 if ft, ok := field.Type.(*ast.FuncType); ok { 544 // Function declarations in interfaces need all their types retained. 545 f.recordFuncType(ft) 546 continue 547 } 548 if len(field.Names) == 0 { 549 // Embedded fields might contribute exported names. 550 f.recordIdents(field.Type) 551 } 552 for _, name := range field.Names { 553 // We only need normal fields if they're exported. 554 if ast.IsExported(name.Name) { 555 f.recordIdents(field.Type) 556 break 557 } 558 } 559 } 560 } 561 562 // ProcessErrors records additional uses from errors, returning the new uses 563 // and any unexpected errors. 564 func (f *unexportedFilter) ProcessErrors(errors []types.Error) (map[string]bool, []types.Error) { 565 var unexpected []types.Error 566 missing := map[string]bool{} 567 for _, err := range errors { 568 if strings.Contains(err.Msg, "missing return") { 569 continue 570 } 571 const undeclared = "undeclared name: " 572 if strings.HasPrefix(err.Msg, undeclared) { 573 missing[strings.TrimPrefix(err.Msg, undeclared)] = true 574 f.uses[strings.TrimPrefix(err.Msg, undeclared)] = true 575 continue 576 } 577 unexpected = append(unexpected, err) 578 } 579 return missing, unexpected 580 } 581 582 // trimAST clears any part of the AST not relevant to type checking 583 // expressions at pos. 
584 func trimAST(file *ast.File) { 585 ast.Inspect(file, func(n ast.Node) bool { 586 if n == nil { 587 return false 588 } 589 switch n := n.(type) { 590 case *ast.FuncDecl: 591 n.Body = nil 592 case *ast.BlockStmt: 593 n.List = nil 594 case *ast.CaseClause: 595 n.Body = nil 596 case *ast.CommClause: 597 n.Body = nil 598 case *ast.CompositeLit: 599 // types.Info.Types for long slice/array literals are particularly 600 // expensive. Try to clear them out. 601 at, ok := n.Type.(*ast.ArrayType) 602 if !ok { 603 // Composite literal. No harm removing all its fields. 604 n.Elts = nil 605 break 606 } 607 // Removing the elements from an ellipsis array changes its type. 608 // Try to set the length explicitly so we can continue. 609 if _, ok := at.Len.(*ast.Ellipsis); ok { 610 length, ok := arrayLength(n) 611 if !ok { 612 break 613 } 614 at.Len = &ast.BasicLit{ 615 Kind: token.INT, 616 Value: fmt.Sprint(length), 617 ValuePos: at.Len.Pos(), 618 } 619 } 620 n.Elts = nil 621 } 622 return true 623 }) 624 } 625 626 // arrayLength returns the length of some simple forms of ellipsis array literal. 627 // Notably, it handles the tables in golang.org/x/text. 628 func arrayLength(array *ast.CompositeLit) (int, bool) { 629 litVal := func(expr ast.Expr) (int, bool) { 630 lit, ok := expr.(*ast.BasicLit) 631 if !ok { 632 return 0, false 633 } 634 val, err := strconv.ParseInt(lit.Value, 10, 64) 635 if err != nil { 636 return 0, false 637 } 638 return int(val), true 639 } 640 largestKey := -1 641 for _, elt := range array.Elts { 642 kve, ok := elt.(*ast.KeyValueExpr) 643 if !ok { 644 continue 645 } 646 switch key := kve.Key.(type) { 647 case *ast.BasicLit: 648 if val, ok := litVal(key); ok && largestKey < val { 649 largestKey = val 650 } 651 case *ast.BinaryExpr: 652 // golang.org/x/text uses subtraction (and only subtraction) in its indices. 
653 if key.Op != token.SUB { 654 break 655 } 656 x, ok := litVal(key.X) 657 if !ok { 658 break 659 } 660 y, ok := litVal(key.Y) 661 if !ok { 662 break 663 } 664 if val := x - y; largestKey < val { 665 largestKey = val 666 } 667 } 668 } 669 if largestKey != -1 { 670 return largestKey + 1, true 671 } 672 return len(array.Elts), true 673 } 674 675 // fixAST inspects the AST and potentially modifies any *ast.BadStmts so that it can be 676 // type-checked more effectively. 677 // 678 // If fixAST returns true, the resulting AST is considered "fixed", meaning 679 // positions have been mangled, and type checker errors may not make sense. 680 func fixAST(ctx context.Context, n ast.Node, tok *token.File, src []byte) (fixed bool) { 681 var err error 682 walkASTWithParent(n, func(n, parent ast.Node) bool { 683 switch n := n.(type) { 684 case *ast.BadStmt: 685 if fixed = fixDeferOrGoStmt(n, parent, tok, src); fixed { 686 // Recursively fix in our fixed node. 687 _ = fixAST(ctx, parent, tok, src) 688 } else { 689 err = errors.Errorf("unable to parse defer or go from *ast.BadStmt: %v", err) 690 } 691 return false 692 case *ast.BadExpr: 693 if fixed = fixArrayType(n, parent, tok, src); fixed { 694 // Recursively fix in our fixed node. 695 _ = fixAST(ctx, parent, tok, src) 696 return false 697 } 698 699 // Fix cases where parser interprets if/for/switch "init" 700 // statement as "cond" expression, e.g.: 701 // 702 // // "i := foo" is init statement, not condition. 
703 // for i := foo 704 // 705 fixInitStmt(n, parent, tok, src) 706 707 return false 708 case *ast.SelectorExpr: 709 // Fix cases where a keyword prefix results in a phantom "_" selector, e.g.: 710 // 711 // foo.var<> // want to complete to "foo.variance" 712 // 713 fixPhantomSelector(n, tok, src) 714 return true 715 716 case *ast.BlockStmt: 717 switch parent.(type) { 718 case *ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.SelectStmt: 719 // Adjust closing curly brace of empty switch/select 720 // statements so we can complete inside them. 721 fixEmptySwitch(n, tok, src) 722 } 723 724 return true 725 default: 726 return true 727 } 728 }) 729 return fixed 730 } 731 732 // walkASTWithParent walks the AST rooted at n. The semantics are 733 // similar to ast.Inspect except it does not call f(nil). 734 func walkASTWithParent(n ast.Node, f func(n ast.Node, parent ast.Node) bool) { 735 var ancestors []ast.Node 736 ast.Inspect(n, func(n ast.Node) (recurse bool) { 737 defer func() { 738 if recurse { 739 ancestors = append(ancestors, n) 740 } 741 }() 742 743 if n == nil { 744 ancestors = ancestors[:len(ancestors)-1] 745 return false 746 } 747 748 var parent ast.Node 749 if len(ancestors) > 0 { 750 parent = ancestors[len(ancestors)-1] 751 } 752 753 return f(n, parent) 754 }) 755 } 756 757 // fixSrc attempts to modify the file's source code to fix certain 758 // syntax errors that leave the rest of the file unparsed. 759 func fixSrc(f *ast.File, tok *token.File, src []byte) (newSrc []byte) { 760 walkASTWithParent(f, func(n, parent ast.Node) bool { 761 if newSrc != nil { 762 return false 763 } 764 765 switch n := n.(type) { 766 case *ast.BlockStmt: 767 newSrc = fixMissingCurlies(f, n, parent, tok, src) 768 case *ast.SelectorExpr: 769 newSrc = fixDanglingSelector(n, tok, src) 770 } 771 772 return newSrc == nil 773 }) 774 775 return newSrc 776 } 777 778 // fixMissingCurlies adds in curly braces for block statements that 779 // are missing curly braces. 
For example: 780 // 781 // if foo 782 // 783 // becomes 784 // 785 // if foo {} 786 func fixMissingCurlies(f *ast.File, b *ast.BlockStmt, parent ast.Node, tok *token.File, src []byte) []byte { 787 // If the "{" is already in the source code, there isn't anything to 788 // fix since we aren't missing curlies. 789 if b.Lbrace.IsValid() { 790 braceOffset, err := source.Offset(tok, b.Lbrace) 791 if err != nil { 792 return nil 793 } 794 if braceOffset < len(src) && src[braceOffset] == '{' { 795 return nil 796 } 797 } 798 799 parentLine := tok.Line(parent.Pos()) 800 801 if parentLine >= tok.LineCount() { 802 // If we are the last line in the file, no need to fix anything. 803 return nil 804 } 805 806 // Insert curlies at the end of parent's starting line. The parent 807 // is the statement that contains the block, e.g. *ast.IfStmt. The 808 // block's Pos()/End() can't be relied upon because they are based 809 // on the (missing) curly braces. We assume the statement is a 810 // single line for now and try sticking the curly braces at the end. 811 insertPos := tok.LineStart(parentLine+1) - 1 812 813 // Scootch position backwards until it's not in a comment. For example: 814 // 815 // if foo<> // some amazing comment | 816 // someOtherCode() 817 // 818 // insertPos will be located at "|", so we back it out of the comment. 819 didSomething := true 820 for didSomething { 821 didSomething = false 822 for _, c := range f.Comments { 823 if c.Pos() < insertPos && insertPos <= c.End() { 824 insertPos = c.Pos() 825 didSomething = true 826 } 827 } 828 } 829 830 // Bail out if line doesn't end in an ident or ".". 
This is to avoid 831 // cases like below where we end up making things worse by adding 832 // curlies: 833 // 834 // if foo && 835 // bar<> 836 switch precedingToken(insertPos, tok, src) { 837 case token.IDENT, token.PERIOD: 838 // ok 839 default: 840 return nil 841 } 842 843 var buf bytes.Buffer 844 buf.Grow(len(src) + 3) 845 offset, err := source.Offset(tok, insertPos) 846 if err != nil { 847 return nil 848 } 849 buf.Write(src[:offset]) 850 851 // Detect if we need to insert a semicolon to fix "for" loop situations like: 852 // 853 // for i := foo(); foo<> 854 // 855 // Just adding curlies is not sufficient to make things parse well. 856 if fs, ok := parent.(*ast.ForStmt); ok { 857 if _, ok := fs.Cond.(*ast.BadExpr); !ok { 858 if xs, ok := fs.Post.(*ast.ExprStmt); ok { 859 if _, ok := xs.X.(*ast.BadExpr); ok { 860 buf.WriteByte(';') 861 } 862 } 863 } 864 } 865 866 // Insert "{}" at insertPos. 867 buf.WriteByte('{') 868 buf.WriteByte('}') 869 buf.Write(src[offset:]) 870 return buf.Bytes() 871 } 872 873 // fixEmptySwitch moves empty switch/select statements' closing curly 874 // brace down one line. This allows us to properly detect incomplete 875 // "case" and "default" keywords as inside the switch statement. For 876 // example: 877 // 878 // switch { 879 // def<> 880 // } 881 // 882 // gets parsed like: 883 // 884 // switch { 885 // } 886 // 887 // Later we manually pull out the "def" token, but we need to detect 888 // that our "<>" position is inside the switch block. To do that we 889 // move the curly brace so it looks like: 890 // 891 // switch { 892 // 893 // } 894 // 895 func fixEmptySwitch(body *ast.BlockStmt, tok *token.File, src []byte) { 896 // We only care about empty switch statements. 897 if len(body.List) > 0 || !body.Rbrace.IsValid() { 898 return 899 } 900 901 // If the right brace is actually in the source code at the 902 // specified position, don't mess with it. 
903 braceOffset, err := source.Offset(tok, body.Rbrace) 904 if err != nil { 905 return 906 } 907 if braceOffset < len(src) && src[braceOffset] == '}' { 908 return 909 } 910 911 braceLine := tok.Line(body.Rbrace) 912 if braceLine >= tok.LineCount() { 913 // If we are the last line in the file, no need to fix anything. 914 return 915 } 916 917 // Move the right brace down one line. 918 body.Rbrace = tok.LineStart(braceLine + 1) 919 } 920 921 // fixDanglingSelector inserts real "_" selector expressions in place 922 // of phantom "_" selectors. For example: 923 // 924 // func _() { 925 // x.<> 926 // } 927 // var x struct { i int } 928 // 929 // To fix completion at "<>", we insert a real "_" after the "." so the 930 // following declaration of "x" can be parsed and type checked 931 // normally. 932 func fixDanglingSelector(s *ast.SelectorExpr, tok *token.File, src []byte) []byte { 933 if !isPhantomUnderscore(s.Sel, tok, src) { 934 return nil 935 } 936 937 if !s.X.End().IsValid() { 938 return nil 939 } 940 941 insertOffset, err := source.Offset(tok, s.X.End()) 942 if err != nil { 943 return nil 944 } 945 // Insert directly after the selector's ".". 946 insertOffset++ 947 if src[insertOffset-1] != '.' { 948 return nil 949 } 950 951 var buf bytes.Buffer 952 buf.Grow(len(src) + 1) 953 buf.Write(src[:insertOffset]) 954 buf.WriteByte('_') 955 buf.Write(src[insertOffset:]) 956 return buf.Bytes() 957 } 958 959 // fixPhantomSelector tries to fix selector expressions with phantom 960 // "_" selectors. In particular, we check if the selector is a 961 // keyword, and if so we swap in an *ast.Ident with the keyword text. For example: 962 // 963 // foo.var 964 // 965 // yields a "_" selector instead of "var" since "var" is a keyword. 966 // 967 // TODO(rfindley): should this constitute an ast 'fix'? 
968 func fixPhantomSelector(sel *ast.SelectorExpr, tok *token.File, src []byte) { 969 if !isPhantomUnderscore(sel.Sel, tok, src) { 970 return 971 } 972 973 // Only consider selectors directly abutting the selector ".". This 974 // avoids false positives in cases like: 975 // 976 // foo. // don't think "var" is our selector 977 // var bar = 123 978 // 979 if sel.Sel.Pos() != sel.X.End()+1 { 980 return 981 } 982 983 maybeKeyword := readKeyword(sel.Sel.Pos(), tok, src) 984 if maybeKeyword == "" { 985 return 986 } 987 988 replaceNode(sel, sel.Sel, &ast.Ident{ 989 Name: maybeKeyword, 990 NamePos: sel.Sel.Pos(), 991 }) 992 } 993 994 // isPhantomUnderscore reports whether the given ident is a phantom 995 // underscore. The parser sometimes inserts phantom underscores when 996 // it encounters otherwise unparseable situations. 997 func isPhantomUnderscore(id *ast.Ident, tok *token.File, src []byte) bool { 998 if id == nil || id.Name != "_" { 999 return false 1000 } 1001 1002 // Phantom underscore means the underscore is not actually in the 1003 // program text. 1004 offset, err := source.Offset(tok, id.Pos()) 1005 if err != nil { 1006 return false 1007 } 1008 return len(src) <= offset || src[offset] != '_' 1009 } 1010 1011 // fixInitStmt fixes cases where the parser misinterprets an 1012 // if/for/switch "init" statement as the "cond" conditional. In cases 1013 // like "if i := 0" the user hasn't typed the semicolon yet so the 1014 // parser is looking for the conditional expression. However, "i := 0" 1015 // are not valid expressions, so we get a BadExpr. 1016 // 1017 // fixInitStmt returns valid AST for the original source. 1018 func fixInitStmt(bad *ast.BadExpr, parent ast.Node, tok *token.File, src []byte) { 1019 if !bad.Pos().IsValid() || !bad.End().IsValid() { 1020 return 1021 } 1022 1023 // Try to extract a statement from the BadExpr. 
1024 start, err := source.Offset(tok, bad.Pos()) 1025 if err != nil { 1026 return 1027 } 1028 end, err := source.Offset(tok, bad.End()-1) 1029 if err != nil { 1030 return 1031 } 1032 stmtBytes := src[start : end+1] 1033 stmt, err := parseStmt(bad.Pos(), stmtBytes) 1034 if err != nil { 1035 return 1036 } 1037 1038 // If the parent statement doesn't already have an "init" statement, 1039 // move the extracted statement into the "init" field and insert a 1040 // dummy expression into the required "cond" field. 1041 switch p := parent.(type) { 1042 case *ast.IfStmt: 1043 if p.Init != nil { 1044 return 1045 } 1046 p.Init = stmt 1047 p.Cond = &ast.Ident{ 1048 Name: "_", 1049 NamePos: stmt.End(), 1050 } 1051 case *ast.ForStmt: 1052 if p.Init != nil { 1053 return 1054 } 1055 p.Init = stmt 1056 p.Cond = &ast.Ident{ 1057 Name: "_", 1058 NamePos: stmt.End(), 1059 } 1060 case *ast.SwitchStmt: 1061 if p.Init != nil { 1062 return 1063 } 1064 p.Init = stmt 1065 p.Tag = nil 1066 } 1067 } 1068 1069 // readKeyword reads the keyword starting at pos, if any. 1070 func readKeyword(pos token.Pos, tok *token.File, src []byte) string { 1071 var kwBytes []byte 1072 offset, err := source.Offset(tok, pos) 1073 if err != nil { 1074 return "" 1075 } 1076 for i := offset; i < len(src); i++ { 1077 // Use a simplified identifier check since keywords are always lowercase ASCII. 1078 if src[i] < 'a' || src[i] > 'z' { 1079 break 1080 } 1081 kwBytes = append(kwBytes, src[i]) 1082 1083 // Stop search at arbitrarily chosen too-long-for-a-keyword length. 1084 if len(kwBytes) > 15 { 1085 return "" 1086 } 1087 } 1088 1089 if kw := string(kwBytes); token.Lookup(kw).IsKeyword() { 1090 return kw 1091 } 1092 1093 return "" 1094 } 1095 1096 // fixArrayType tries to parse an *ast.BadExpr into an *ast.ArrayType. 1097 // go/parser often turns lone array types like "[]int" into BadExprs 1098 // if it isn't expecting a type. 
1099 func fixArrayType(bad *ast.BadExpr, parent ast.Node, tok *token.File, src []byte) bool { 1100 // Our expected input is a bad expression that looks like "[]someExpr". 1101 1102 from := bad.Pos() 1103 to := bad.End() 1104 1105 if !from.IsValid() || !to.IsValid() { 1106 return false 1107 } 1108 1109 exprBytes := make([]byte, 0, int(to-from)+3) 1110 // Avoid doing tok.Offset(to) since that panics if badExpr ends at EOF. 1111 // It also panics if the position is not in the range of the file, and 1112 // badExprs may not necessarily have good positions, so check first. 1113 fromOffset, err := source.Offset(tok, from) 1114 if err != nil { 1115 return false 1116 } 1117 toOffset, err := source.Offset(tok, to-1) 1118 if err != nil { 1119 return false 1120 } 1121 exprBytes = append(exprBytes, src[fromOffset:toOffset+1]...) 1122 exprBytes = bytes.TrimSpace(exprBytes) 1123 1124 // If our expression ends in "]" (e.g. "[]"), add a phantom selector 1125 // so we can complete directly after the "[]". 1126 if len(exprBytes) > 0 && exprBytes[len(exprBytes)-1] == ']' { 1127 exprBytes = append(exprBytes, '_') 1128 } 1129 1130 // Add "{}" to turn our ArrayType into a CompositeLit. This is to 1131 // handle the case of "[...]int" where we must make it a composite 1132 // literal to be parseable. 1133 exprBytes = append(exprBytes, '{', '}') 1134 1135 expr, err := parseExpr(from, exprBytes) 1136 if err != nil { 1137 return false 1138 } 1139 1140 cl, _ := expr.(*ast.CompositeLit) 1141 if cl == nil { 1142 return false 1143 } 1144 1145 at, _ := cl.Type.(*ast.ArrayType) 1146 if at == nil { 1147 return false 1148 } 1149 1150 return replaceNode(parent, bad, at) 1151 } 1152 1153 // precedingToken scans src to find the token preceding pos. 
func precedingToken(pos token.Pos, tok *token.File, src []byte) token.Token {
	s := &scanner.Scanner{}
	s.Init(tok, src, nil, 0)

	// Scan forward, remembering the last token seen strictly before pos.
	var lastTok token.Token
	for {
		p, t, _ := s.Scan()
		if t == token.EOF || p >= pos {
			break
		}

		lastTok = t
	}
	return lastTok
}

// fixDeferOrGoStmt tries to parse an *ast.BadStmt into a defer or a go statement.
//
// go/parser packages a statement of the form "defer x." as an *ast.BadStmt because
// it does not include a call expression. This means that go/types skips type-checking
// this statement entirely, and we can't use the type information when completing.
// Here, we try to generate a fake *ast.DeferStmt or *ast.GoStmt to put into the AST,
// instead of the *ast.BadStmt.
func fixDeferOrGoStmt(bad *ast.BadStmt, parent ast.Node, tok *token.File, src []byte) bool {
	// Check if we have a bad statement containing either a "go" or "defer".
	s := &scanner.Scanner{}
	s.Init(tok, src, nil, 0)

	// Scan forward to the first token at or after the bad statement's start.
	var (
		pos token.Pos
		tkn token.Token
	)
	for {
		if tkn == token.EOF {
			return false
		}
		if pos >= bad.From {
			break
		}
		pos, tkn, _ = s.Scan()
	}

	// Build the statement skeleton matching the keyword we landed on;
	// anything other than "defer" or "go" means this fix doesn't apply.
	var stmt ast.Stmt
	switch tkn {
	case token.DEFER:
		stmt = &ast.DeferStmt{
			Defer: pos,
		}
	case token.GO:
		stmt = &ast.GoStmt{
			Go: pos,
		}
	default:
		return false
	}

	// Scan the expression following the keyword, tracking brace nesting
	// and recording positions where a phantom "_" selector must be
	// inserted after a dangling ".".
	var (
		from, to, last   token.Pos
		lastToken        token.Token
		braceDepth       int
		phantomSelectors []token.Pos
	)
FindTo:
	for {
		to, tkn, _ = s.Scan()

		if from == token.NoPos {
			from = to
		}

		switch tkn {
		case token.EOF:
			break FindTo
		case token.SEMICOLON:
			// If we aren't in nested braces, end of statement means
			// end of expression.
			if braceDepth == 0 {
				break FindTo
			}
		case token.LBRACE:
			braceDepth++
		}

		// This handles the common dangling selector case. For example in
		//
		// defer fmt.
		// y := 1
		//
		// we notice the dangling period and end our expression.
		//
		// If the previous token was a "." and we are looking at a "}",
		// the period is likely a dangling selector and needs a phantom
		// "_". Likewise if the current token is on a different line than
		// the period, the period is likely a dangling selector.
		if lastToken == token.PERIOD && (tkn == token.RBRACE || tok.Line(to) > tok.Line(last)) {
			// Insert phantom "_" selector after the dangling ".".
			phantomSelectors = append(phantomSelectors, last+1)
			// If we aren't in a block then end the expression after the ".".
			if braceDepth == 0 {
				to = last + 1
				break
			}
		}

		lastToken = tkn
		last = to

		switch tkn {
		case token.RBRACE:
			braceDepth--
			if braceDepth <= 0 {
				if braceDepth == 0 {
					// +1 to include the "}" itself.
					to += 1
				}
				break FindTo
			}
		}
	}

	// Validate the scanned range before slicing into src.
	fromOffset, err := source.Offset(tok, from)
	if err != nil {
		return false
	}
	if !from.IsValid() || fromOffset >= len(src) {
		return false
	}

	toOffset, err := source.Offset(tok, to)
	if err != nil {
		return false
	}
	if !to.IsValid() || toOffset >= len(src) {
		return false
	}

	// Insert any phantom selectors needed to prevent dangling "." from messing
	// up the AST.
	exprBytes := make([]byte, 0, int(to-from)+len(phantomSelectors))
	for i, b := range src[fromOffset:toOffset] {
		if len(phantomSelectors) > 0 && from+token.Pos(i) == phantomSelectors[0] {
			exprBytes = append(exprBytes, '_')
			phantomSelectors = phantomSelectors[1:]
		}
		exprBytes = append(exprBytes, b)
	}

	// A phantom selector recorded at the very end of the range is
	// appended after the loop.
	if len(phantomSelectors) > 0 {
		exprBytes = append(exprBytes, '_')
	}

	expr, err := parseExpr(from, exprBytes)
	if err != nil {
		return false
	}

	// Package the expression into a fake *ast.CallExpr and re-insert
	// into the function.
	call := &ast.CallExpr{
		Fun:    expr,
		Lparen: to,
		Rparen: to,
	}

	// Attach the call to whichever statement skeleton was built above.
	switch stmt := stmt.(type) {
	case *ast.DeferStmt:
		stmt.Call = call
	case *ast.GoStmt:
		stmt.Call = call
	}

	return replaceNode(parent, bad, stmt)
}

// parseStmt parses the statement in src and updates its position to
// start at pos. It returns an error if src does not contain a
// parseable statement.
func parseStmt(pos token.Pos, src []byte) (ast.Stmt, error) {
	// Wrap our expression to make it a valid Go file we can pass to ParseFile.
	fileSrc := bytes.Join([][]byte{
		[]byte("package fake;func _(){"),
		src,
		[]byte("}"),
	}, nil)

	// Use ParseFile instead of ParseExpr because ParseFile has
	// best-effort behavior, whereas ParseExpr fails hard on any error.
	fakeFile, err := parser.ParseFile(token.NewFileSet(), "", fileSrc, 0)
	if fakeFile == nil {
		return nil, errors.Errorf("error reading fake file source: %v", err)
	}

	// Extract our expression node from inside the fake file.
	if len(fakeFile.Decls) == 0 {
		return nil, errors.Errorf("error parsing fake file: %v", err)
	}

	// The fake file contains exactly one declaration: the wrapper func.
	fakeDecl, _ := fakeFile.Decls[0].(*ast.FuncDecl)
	if fakeDecl == nil || len(fakeDecl.Body.List) == 0 {
		return nil, errors.Errorf("no statement in %s: %v", src, err)
	}

	stmt := fakeDecl.Body.List[0]

	// parser.ParseFile returns undefined positions.
	// Adjust them for the current file.
	offsetPositions(stmt, pos-1-(stmt.Pos()-1))

	return stmt, nil
}

// parseExpr parses the expression in src and updates its position to
// start at pos. It returns an error if src does not parse to a single
// expression statement.
func parseExpr(pos token.Pos, src []byte) (ast.Expr, error) {
	stmt, err := parseStmt(pos, src)
	if err != nil {
		return nil, err
	}

	exprStmt, ok := stmt.(*ast.ExprStmt)
	if !ok {
		return nil, errors.Errorf("no expr in %s: %v", src, err)
	}

	return exprStmt.X, nil
}

// tokenPosType is used to identify token.Pos struct fields via reflection.
var tokenPosType = reflect.TypeOf(token.NoPos)

// offsetPositions applies an offset to the positions in an ast.Node.
// It walks the node and shifts every settable, valid token.Pos field
// by offset.
func offsetPositions(n ast.Node, offset token.Pos) {
	ast.Inspect(n, func(n ast.Node) bool {
		if n == nil {
			return false
		}

		// AST nodes are pointers to structs; inspect their fields.
		v := reflect.ValueOf(n).Elem()

		switch v.Kind() {
		case reflect.Struct:
			for i := 0; i < v.NumField(); i++ {
				f := v.Field(i)
				if f.Type() != tokenPosType {
					continue
				}

				// Unexported fields can't be set via reflection.
				if !f.CanSet() {
					continue
				}

				// Don't offset invalid positions: they should stay invalid.
				if !token.Pos(f.Int()).IsValid() {
					continue
				}

				f.SetInt(f.Int() + int64(offset))
			}
		}

		return true
	})
}

// replaceNode updates parent's child oldChild to be newChild. It
// returns whether it replaced successfully.
1418 func replaceNode(parent, oldChild, newChild ast.Node) bool { 1419 if parent == nil || oldChild == nil || newChild == nil { 1420 return false 1421 } 1422 1423 parentVal := reflect.ValueOf(parent).Elem() 1424 if parentVal.Kind() != reflect.Struct { 1425 return false 1426 } 1427 1428 newChildVal := reflect.ValueOf(newChild) 1429 1430 tryReplace := func(v reflect.Value) bool { 1431 if !v.CanSet() || !v.CanInterface() { 1432 return false 1433 } 1434 1435 // If the existing value is oldChild, we found our child. Make 1436 // sure our newChild is assignable and then make the swap. 1437 if v.Interface() == oldChild && newChildVal.Type().AssignableTo(v.Type()) { 1438 v.Set(newChildVal) 1439 return true 1440 } 1441 1442 return false 1443 } 1444 1445 // Loop over parent's struct fields. 1446 for i := 0; i < parentVal.NumField(); i++ { 1447 f := parentVal.Field(i) 1448 1449 switch f.Kind() { 1450 // Check interface and pointer fields. 1451 case reflect.Interface, reflect.Ptr: 1452 if tryReplace(f) { 1453 return true 1454 } 1455 1456 // Search through any slice fields. 1457 case reflect.Slice: 1458 for i := 0; i < f.Len(); i++ { 1459 if tryReplace(f.Index(i)) { 1460 return true 1461 } 1462 } 1463 } 1464 } 1465 1466 return false 1467 }