github.com/jd-ly/tools@v0.5.7/internal/lsp/source/util.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/printer"
    13  	"go/token"
    14  	"go/types"
    15  	"path/filepath"
    16  	"regexp"
    17  	"sort"
    18  	"strconv"
    19  	"strings"
    20  
    21  	"github.com/jd-ly/tools/internal/lsp/protocol"
    22  	"github.com/jd-ly/tools/internal/span"
    23  	errors "golang.org/x/xerrors"
    24  )
    25  
// MappedRange provides mapped protocol.Range for a span.Range, accounting for
// UTF-16 code points.
//
// It pairs a token.Pos-based range with the protocol.ColumnMapper of the file
// containing it, so the same range can be reported as a span (Span,
// SpanRange) or as an LSP protocol.Range (Range).
type MappedRange struct {
	// spanRange is the underlying range expressed in token.Pos terms.
	spanRange span.Range
	// m converts spans of the containing file to protocol ranges and also
	// supplies that file's URI.
	m *protocol.ColumnMapper

	// protocolRange is the result of converting the spanRange using the mapper.
	// It is computed on-demand.
	//
	// NOTE(review): Range has a value receiver, so a result it computes is never
	// stored back into the caller's MappedRange; this field is only non-nil if
	// something populates it explicitly. Confirm whether it can be removed.
	protocolRange *protocol.Range
}
    36  
    37  // NewMappedRange returns a MappedRange for the given start and end token.Pos.
    38  func NewMappedRange(fset *token.FileSet, m *protocol.ColumnMapper, start, end token.Pos) MappedRange {
    39  	return MappedRange{
    40  		spanRange: span.Range{
    41  			FileSet:   fset,
    42  			Start:     start,
    43  			End:       end,
    44  			Converter: m.Converter,
    45  		},
    46  		m: m,
    47  	}
    48  }
    49  
    50  func (s MappedRange) Range() (protocol.Range, error) {
    51  	if s.protocolRange == nil {
    52  		spn, err := s.spanRange.Span()
    53  		if err != nil {
    54  			return protocol.Range{}, err
    55  		}
    56  		prng, err := s.m.Range(spn)
    57  		if err != nil {
    58  			return protocol.Range{}, err
    59  		}
    60  		s.protocolRange = &prng
    61  	}
    62  	return *s.protocolRange, nil
    63  }
    64  
    65  func (s MappedRange) Span() (span.Span, error) {
    66  	return s.spanRange.Span()
    67  }
    68  
    69  func (s MappedRange) SpanRange() span.Range {
    70  	return s.spanRange
    71  }
    72  
    73  func (s MappedRange) URI() span.URI {
    74  	return s.m.URI
    75  }
    76  
    77  // GetParsedFile is a convenience function that extracts the Package and
    78  // ParsedGoFile for a file in a Snapshot. pkgPolicy is one of NarrowestPackage/
    79  // WidestPackage.
    80  func GetParsedFile(ctx context.Context, snapshot Snapshot, fh FileHandle, pkgPolicy PackageFilter) (Package, *ParsedGoFile, error) {
    81  	pkg, err := snapshot.PackageForFile(ctx, fh.URI(), TypecheckWorkspace, pkgPolicy)
    82  	if err != nil {
    83  		return nil, nil, err
    84  	}
    85  	pgh, err := pkg.File(fh.URI())
    86  	return pkg, pgh, err
    87  }
    88  
    89  func IsGenerated(ctx context.Context, snapshot Snapshot, uri span.URI) bool {
    90  	fh, err := snapshot.GetFile(ctx, uri)
    91  	if err != nil {
    92  		return false
    93  	}
    94  	pgf, err := snapshot.ParseGo(ctx, fh, ParseHeader)
    95  	if err != nil {
    96  		return false
    97  	}
    98  	tok := snapshot.FileSet().File(pgf.File.Pos())
    99  	if tok == nil {
   100  		return false
   101  	}
   102  	for _, commentGroup := range pgf.File.Comments {
   103  		for _, comment := range commentGroup.List {
   104  			if matched := generatedRx.MatchString(comment.Text); matched {
   105  				// Check if comment is at the beginning of the line in source.
   106  				if pos := tok.Position(comment.Slash); pos.Column == 1 {
   107  					return true
   108  				}
   109  			}
   110  		}
   111  	}
   112  	return false
   113  }
   114  
   115  func nodeToProtocolRange(snapshot Snapshot, pkg Package, n ast.Node) (protocol.Range, error) {
   116  	mrng, err := posToMappedRange(snapshot, pkg, n.Pos(), n.End())
   117  	if err != nil {
   118  		return protocol.Range{}, err
   119  	}
   120  	return mrng.Range()
   121  }
   122  
   123  func objToMappedRange(snapshot Snapshot, pkg Package, obj types.Object) (MappedRange, error) {
   124  	if pkgName, ok := obj.(*types.PkgName); ok {
   125  		// An imported Go package has a package-local, unqualified name.
   126  		// When the name matches the imported package name, there is no
   127  		// identifier in the import spec with the local package name.
   128  		//
   129  		// For example:
   130  		// 		import "go/ast" 	// name "ast" matches package name
   131  		// 		import a "go/ast"  	// name "a" does not match package name
   132  		//
   133  		// When the identifier does not appear in the source, have the range
   134  		// of the object be the import path, including quotes.
   135  		if pkgName.Imported().Name() == pkgName.Name() {
   136  			return posToMappedRange(snapshot, pkg, obj.Pos(), obj.Pos()+token.Pos(len(pkgName.Imported().Path())+2))
   137  		}
   138  	}
   139  	return nameToMappedRange(snapshot, pkg, obj.Pos(), obj.Name())
   140  }
   141  
   142  func nameToMappedRange(snapshot Snapshot, pkg Package, pos token.Pos, name string) (MappedRange, error) {
   143  	return posToMappedRange(snapshot, pkg, pos, pos+token.Pos(len(name)))
   144  }
   145  
   146  func posToMappedRange(snapshot Snapshot, pkg Package, pos, end token.Pos) (MappedRange, error) {
   147  	logicalFilename := snapshot.FileSet().File(pos).Position(pos).Filename
   148  	pgf, _, err := findFileInDeps(pkg, span.URIFromPath(logicalFilename))
   149  	if err != nil {
   150  		return MappedRange{}, err
   151  	}
   152  	if !pos.IsValid() {
   153  		return MappedRange{}, errors.Errorf("invalid position for %v", pos)
   154  	}
   155  	if !end.IsValid() {
   156  		return MappedRange{}, errors.Errorf("invalid position for %v", end)
   157  	}
   158  	return NewMappedRange(snapshot.FileSet(), pgf.Mapper, pos, end), nil
   159  }
   160  
// generatedRx matches a comment marking a file as generated. It accepts cgo's
// generated-file comment as well as the standard form
// (`// Code generated ... DO NOT EDIT.`) described at
//	https://golang.org/s/generatedcode
// The pattern is unanchored and deliberately loose: any comment text containing
// "// " followed later by "DO NOT EDIT" matches, so the optional trailing period
// does not constrain the match. IsGenerated additionally requires the comment
// to start in column 1.
var generatedRx = regexp.MustCompile(`// .*DO NOT EDIT\.?`)
   164  
   165  func DetectLanguage(langID, filename string) FileKind {
   166  	switch langID {
   167  	case "go":
   168  		return Go
   169  	case "go.mod":
   170  		return Mod
   171  	case "go.sum":
   172  		return Sum
   173  	}
   174  	// Fallback to detecting the language based on the file extension.
   175  	switch filepath.Ext(filename) {
   176  	case ".mod":
   177  		return Mod
   178  	case ".sum":
   179  		return Sum
   180  	default: // fallback to Go
   181  		return Go
   182  	}
   183  }
   184  
   185  func (k FileKind) String() string {
   186  	switch k {
   187  	case Mod:
   188  		return "go.mod"
   189  	case Sum:
   190  		return "go.sum"
   191  	default:
   192  		return "go"
   193  	}
   194  }
   195  
   196  // nodeAtPos returns the index and the node whose position is contained inside
   197  // the node list.
   198  func nodeAtPos(nodes []ast.Node, pos token.Pos) (ast.Node, int) {
   199  	if nodes == nil {
   200  		return nil, -1
   201  	}
   202  	for i, node := range nodes {
   203  		if node.Pos() <= pos && pos <= node.End() {
   204  			return node, i
   205  		}
   206  	}
   207  	return nil, -1
   208  }
   209  
   210  // IsInterface returns if a types.Type is an interface
   211  func IsInterface(T types.Type) bool {
   212  	return T != nil && types.IsInterface(T)
   213  }
   214  
   215  // FormatNode returns the "pretty-print" output for an ast node.
   216  func FormatNode(fset *token.FileSet, n ast.Node) string {
   217  	var buf strings.Builder
   218  	if err := printer.Fprint(&buf, fset, n); err != nil {
   219  		return ""
   220  	}
   221  	return buf.String()
   222  }
   223  
   224  // Deref returns a pointer's element type, traversing as many levels as needed.
   225  // Otherwise it returns typ.
   226  func Deref(typ types.Type) types.Type {
   227  	for {
   228  		p, ok := typ.Underlying().(*types.Pointer)
   229  		if !ok {
   230  			return typ
   231  		}
   232  		typ = p.Elem()
   233  	}
   234  }
   235  
   236  func SortDiagnostics(d []*Diagnostic) {
   237  	sort.Slice(d, func(i int, j int) bool {
   238  		return CompareDiagnostic(d[i], d[j]) < 0
   239  	})
   240  }
   241  
   242  func CompareDiagnostic(a, b *Diagnostic) int {
   243  	if r := protocol.CompareRange(a.Range, b.Range); r != 0 {
   244  		return r
   245  	}
   246  	if a.Source < b.Source {
   247  		return -1
   248  	}
   249  	if a.Message < b.Message {
   250  		return -1
   251  	}
   252  	if a.Message == b.Message {
   253  		return 0
   254  	}
   255  	return 1
   256  }
   257  
   258  // FindPosInPackage finds the parsed file for a position in a given search
   259  // package.
   260  func FindPosInPackage(snapshot Snapshot, searchpkg Package, pos token.Pos) (*ParsedGoFile, Package, error) {
   261  	tok := snapshot.FileSet().File(pos)
   262  	if tok == nil {
   263  		return nil, nil, errors.Errorf("no file for pos in package %s", searchpkg.ID())
   264  	}
   265  	uri := span.URIFromPath(tok.Name())
   266  
   267  	pgf, pkg, err := findFileInDeps(searchpkg, uri)
   268  	if err != nil {
   269  		return nil, nil, err
   270  	}
   271  	return pgf, pkg, nil
   272  }
   273  
   274  // findFileInDeps finds uri in pkg or its dependencies.
   275  func findFileInDeps(pkg Package, uri span.URI) (*ParsedGoFile, Package, error) {
   276  	queue := []Package{pkg}
   277  	seen := make(map[string]bool)
   278  
   279  	for len(queue) > 0 {
   280  		pkg := queue[0]
   281  		queue = queue[1:]
   282  		seen[pkg.ID()] = true
   283  
   284  		if pgf, err := pkg.File(uri); err == nil {
   285  			return pgf, pkg, nil
   286  		}
   287  		for _, dep := range pkg.Imports() {
   288  			if !seen[dep.ID()] {
   289  				queue = append(queue, dep)
   290  			}
   291  		}
   292  	}
   293  	return nil, nil, errors.Errorf("no file for %s in package %s", uri, pkg.ID())
   294  }
   295  
   296  // MarshalArgs encodes the given arguments to json.RawMessages. This function
   297  // is used to construct arguments to a protocol.Command.
   298  //
   299  // Example usage:
   300  //
   301  //   jsonArgs, err := EncodeArgs(1, "hello", true, StructuredArg{42, 12.6})
   302  //
   303  func MarshalArgs(args ...interface{}) ([]json.RawMessage, error) {
   304  	var out []json.RawMessage
   305  	for _, arg := range args {
   306  		argJSON, err := json.Marshal(arg)
   307  		if err != nil {
   308  			return nil, err
   309  		}
   310  		out = append(out, argJSON)
   311  	}
   312  	return out, nil
   313  }
   314  
   315  // UnmarshalArgs decodes the given json.RawMessages to the variables provided
   316  // by args. Each element of args should be a pointer.
   317  //
   318  // Example usage:
   319  //
   320  //   var (
   321  //       num int
   322  //       str string
   323  //       bul bool
   324  //       structured StructuredArg
   325  //   )
   326  //   err := UnmarshalArgs(args, &num, &str, &bul, &structured)
   327  //
   328  func UnmarshalArgs(jsonArgs []json.RawMessage, args ...interface{}) error {
   329  	if len(args) != len(jsonArgs) {
   330  		return fmt.Errorf("DecodeArgs: expected %d input arguments, got %d JSON arguments", len(args), len(jsonArgs))
   331  	}
   332  	for i, arg := range args {
   333  		if err := json.Unmarshal(jsonArgs[i], arg); err != nil {
   334  			return err
   335  		}
   336  	}
   337  	return nil
   338  }
   339  
   340  // ImportPath returns the unquoted import path of s,
   341  // or "" if the path is not properly quoted.
   342  func ImportPath(s *ast.ImportSpec) string {
   343  	t, err := strconv.Unquote(s.Path.Value)
   344  	if err != nil {
   345  		return ""
   346  	}
   347  	return t
   348  }
   349  
   350  // NodeContains returns true if a node encloses a given position pos.
   351  func NodeContains(n ast.Node, pos token.Pos) bool {
   352  	return n != nil && n.Pos() <= pos && pos <= n.End()
   353  }
   354  
   355  // CollectScopes returns all scopes in an ast path, ordered as innermost scope
   356  // first.
   357  func CollectScopes(info *types.Info, path []ast.Node, pos token.Pos) []*types.Scope {
   358  	// scopes[i], where i<len(path), is the possibly nil Scope of path[i].
   359  	var scopes []*types.Scope
   360  	for _, n := range path {
   361  		// Include *FuncType scope if pos is inside the function body.
   362  		switch node := n.(type) {
   363  		case *ast.FuncDecl:
   364  			if node.Body != nil && NodeContains(node.Body, pos) {
   365  				n = node.Type
   366  			}
   367  		case *ast.FuncLit:
   368  			if node.Body != nil && NodeContains(node.Body, pos) {
   369  				n = node.Type
   370  			}
   371  		}
   372  		scopes = append(scopes, info.Scopes[n])
   373  	}
   374  	return scopes
   375  }
   376  
   377  // Qualifier returns a function that appropriately formats a types.PkgName
   378  // appearing in a *ast.File.
   379  func Qualifier(f *ast.File, pkg *types.Package, info *types.Info) types.Qualifier {
   380  	// Construct mapping of import paths to their defined or implicit names.
   381  	imports := make(map[*types.Package]string)
   382  	for _, imp := range f.Imports {
   383  		var obj types.Object
   384  		if imp.Name != nil {
   385  			obj = info.Defs[imp.Name]
   386  		} else {
   387  			obj = info.Implicits[imp]
   388  		}
   389  		if pkgname, ok := obj.(*types.PkgName); ok {
   390  			imports[pkgname.Imported()] = pkgname.Name()
   391  		}
   392  	}
   393  	// Define qualifier to replace full package paths with names of the imports.
   394  	return func(p *types.Package) string {
   395  		if p == pkg {
   396  			return ""
   397  		}
   398  		if name, ok := imports[p]; ok {
   399  			return name
   400  		}
   401  		return p.Name()
   402  	}
   403  }
   404  
   405  // isDirective reports whether c is a comment directive.
   406  //
   407  // Copied and adapted from go/src/go/ast/ast.go.
   408  func isDirective(c string) bool {
   409  	if len(c) < 3 {
   410  		return false
   411  	}
   412  	if c[1] != '/' {
   413  		return false
   414  	}
   415  	//-style comment (no newline at the end)
   416  	c = c[2:]
   417  	if len(c) == 0 {
   418  		// empty line
   419  		return false
   420  	}
   421  	// "//line " is a line directive.
   422  	// (The // has been removed.)
   423  	if strings.HasPrefix(c, "line ") {
   424  		return true
   425  	}
   426  
   427  	// "//[a-z0-9]+:[a-z0-9]"
   428  	// (The // has been removed.)
   429  	colon := strings.Index(c, ":")
   430  	if colon <= 0 || colon+1 >= len(c) {
   431  		return false
   432  	}
   433  	for i := 0; i <= colon+1; i++ {
   434  		if i == colon {
   435  			continue
   436  		}
   437  		b := c[i]
   438  		if !('a' <= b && b <= 'z' || '0' <= b && b <= '9') {
   439  			return false
   440  		}
   441  	}
   442  	return true
   443  }
   444  
// honorSymlinks toggles whether or not symlinks are considered when comparing
// file or directory URIs. While it is false, CompareURI compares URIs as plain
// strings and InDir performs only a lexical check. When true, CompareURI
// defers to span.CompareURI and InDir also retries with the symlink-resolved
// forms of its arguments.
const honorSymlinks = false
   448  
   449  func CompareURI(left, right span.URI) int {
   450  	if honorSymlinks {
   451  		return span.CompareURI(left, right)
   452  	}
   453  	if left == right {
   454  		return 0
   455  	}
   456  	if left < right {
   457  		return -1
   458  	}
   459  	return 1
   460  }
   461  
   462  // InDir checks whether path is in the file tree rooted at dir.
   463  // InDir makes some effort to succeed even in the presence of symbolic links.
   464  //
   465  // Copied and slightly adjusted from go/src/cmd/go/internal/search/search.go.
   466  func InDir(dir, path string) bool {
   467  	if inDirLex(dir, path) {
   468  		return true
   469  	}
   470  	if !honorSymlinks {
   471  		return false
   472  	}
   473  	xpath, err := filepath.EvalSymlinks(path)
   474  	if err != nil || xpath == path {
   475  		xpath = ""
   476  	} else {
   477  		if inDirLex(dir, xpath) {
   478  			return true
   479  		}
   480  	}
   481  
   482  	xdir, err := filepath.EvalSymlinks(dir)
   483  	if err == nil && xdir != dir {
   484  		if inDirLex(xdir, path) {
   485  			return true
   486  		}
   487  		if xpath != "" {
   488  			if inDirLex(xdir, xpath) {
   489  				return true
   490  			}
   491  		}
   492  	}
   493  	return false
   494  }
   495  
   496  // inDirLex is like inDir but only checks the lexical form of the file names.
   497  // It does not consider symbolic links.
   498  //
   499  // Copied from go/src/cmd/go/internal/search/search.go.
   500  func inDirLex(dir, path string) bool {
   501  	pv := strings.ToUpper(filepath.VolumeName(path))
   502  	dv := strings.ToUpper(filepath.VolumeName(dir))
   503  	path = path[len(pv):]
   504  	dir = dir[len(dv):]
   505  	switch {
   506  	default:
   507  		return false
   508  	case pv != dv:
   509  		return false
   510  	case len(path) == len(dir):
   511  		if path == dir {
   512  			return true
   513  		}
   514  		return false
   515  	case dir == "":
   516  		return path != ""
   517  	case len(path) > len(dir):
   518  		if dir[len(dir)-1] == filepath.Separator {
   519  			if path[:len(dir)] == dir {
   520  				return path[len(dir):] != ""
   521  			}
   522  			return false
   523  		}
   524  		if path[len(dir)] == filepath.Separator && path[:len(dir)] == dir {
   525  			if len(path) == len(dir)+1 {
   526  				return true
   527  			}
   528  			return path[len(dir)+1:] != ""
   529  		}
   530  		return false
   531  	}
   532  }