github.com/wedaly/gospelunk@v0.0.0-20240506220214-89e2d4a79789/pkg/list/list.go (about)

     1  package list
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"path/filepath"
     7  	"sort"
     8  	"strings"
     9  
    10  	"golang.org/x/tools/go/packages"
    11  
    12  	"github.com/wedaly/gospelunk/pkg/file"
    13  )
    14  
    15  type Options struct {
    16  	IncludeStructFields     bool
    17  	IncludeInterfaceMethods bool
    18  	IncludePrivate          bool
    19  	IncludeTests            bool
    20  	OnlyImports             bool
    21  }
    22  
    23  type Result struct {
    24  	Defs []Definition
    25  }
    26  
    27  type Package struct {
    28  	Name string
    29  	ID   string
    30  }
    31  
    32  type Definition struct {
    33  	file.Loc
    34  	Name string
    35  	Pkg  Package
    36  }
    37  
    38  func List(patterns []string, opts Options) (Result, error) {
    39  	var result Result
    40  
    41  	pkgs, err := loadGoPackages(patterns, opts)
    42  	if err != nil {
    43  		return result, err
    44  	}
    45  
    46  	seenFiles := make(map[string]struct{})
    47  	for _, pkg := range pkgs {
    48  		goPaths := make(map[string]struct{}, len(pkg.GoFiles))
    49  		for _, p := range pkg.GoFiles {
    50  			goPaths[p] = struct{}{}
    51  		}
    52  
    53  		for _, astFile := range pkg.Syntax {
    54  			path := pkg.Fset.Position(astFile.Pos()).Filename
    55  
    56  			if _, ok := goPaths[path]; !ok {
    57  				// Likely a compiled file from cgo. Ignore it.
    58  				continue
    59  			}
    60  
    61  			if _, ok := seenFiles[path]; ok {
    62  				// When opts.IncludeTests is true, the pkgs list will contain both the original pkg
    63  				// as well as the pkg compiled for tests. Deduplicate the file paths to avoid duplicating
    64  				// non-test definitions.
    65  				continue
    66  			} else {
    67  				seenFiles[path] = struct{}{}
    68  			}
    69  
    70  			ast.Inspect(astFile, func(node ast.Node) bool {
    71  				switch x := node.(type) {
    72  				case *ast.ValueSpec:
    73  					loadDefsFromValueSpec(pkg, opts, x, &result.Defs)
    74  					return false
    75  
    76  				case *ast.TypeSpec:
    77  					loadDefsFromTypeSpec(pkg, opts, x, &result.Defs)
    78  					return false
    79  
    80  				case *ast.FuncDecl:
    81  					loadDefsFromFuncDecl(pkg, opts, x, &result.Defs)
    82  					return false
    83  
    84  				default:
    85  					return true
    86  				}
    87  			})
    88  		}
    89  	}
    90  
    91  	sort.Slice(result.Defs, func(i, j int) bool {
    92  		a, b := result.Defs[i], result.Defs[j]
    93  		if a.Path != b.Path {
    94  			return a.Path < b.Path
    95  		} else if a.Line != b.Line {
    96  			return a.Line < b.Line
    97  		} else if a.Column != b.Column {
    98  			return a.Column < b.Column
    99  		} else {
   100  			return a.Name < b.Name
   101  		}
   102  	})
   103  
   104  	return result, nil
   105  }
   106  
   107  func loadGoPackages(patterns []string, opts Options) ([]*packages.Package, error) {
   108  	cfg := &packages.Config{
   109  		Mode:  packages.NeedName | packages.NeedFiles | packages.NeedSyntax | packages.NeedTypes,
   110  		Tests: opts.IncludeTests,
   111  	}
   112  
   113  	if opts.OnlyImports {
   114  		cfg.Mode |= (packages.NeedImports | packages.NeedDeps)
   115  	}
   116  
   117  	// Workaround for a quirk of the Go build system.
   118  	// When specifying a package using "file=" syntax, the result differs depending
   119  	// on whether the current working directory is inside the Go module.
   120  	// If inside the module, the package includes syntax trees for all files in the package.
   121  	// If outside the module, the package includes only syntax trees for the specific file.
   122  	// We want the same behavior in either case, so set the directory to the one containing
   123  	// the requested file to guarantee that the current working directory is in the module.
   124  	if len(patterns) == 1 && strings.HasPrefix(patterns[0], "file=") {
   125  		_, path, _ := strings.Cut(patterns[0], "=")
   126  		cfg.Dir = filepath.Dir(path)
   127  	}
   128  
   129  	pkgs, err := packages.Load(cfg, patterns...)
   130  	if err != nil {
   131  		return nil, fmt.Errorf("packages.Load: %w", err)
   132  	}
   133  
   134  	if opts.OnlyImports {
   135  		pkgs = uniqueImports(pkgs)
   136  	}
   137  
   138  	return pkgs, nil
   139  }
   140  
   141  func uniqueImports(pkgs []*packages.Package) []*packages.Package {
   142  	uniqueImports := make(map[string]*packages.Package, len(pkgs))
   143  	for _, pkg := range pkgs {
   144  		for _, importedPkg := range pkg.Imports {
   145  			if _, ok := uniqueImports[importedPkg.ID]; !ok {
   146  				uniqueImports[importedPkg.ID] = importedPkg
   147  			}
   148  		}
   149  	}
   150  
   151  	result := make([]*packages.Package, 0, len(uniqueImports))
   152  	for _, pkg := range uniqueImports {
   153  		result = append(result, pkg)
   154  	}
   155  
   156  	sort.Slice(result, func(i, j int) bool {
   157  		return result[i].ID < result[j].ID
   158  	})
   159  
   160  	return result
   161  }
   162  
   163  func loadDefsFromValueSpec(pkg *packages.Package, opts Options, valueSpec *ast.ValueSpec, defs *[]Definition) {
   164  	position := pkg.Fset.Position(valueSpec.Pos())
   165  	for _, nameIdent := range valueSpec.Names {
   166  		if nameIdent != nil && (opts.IncludePrivate || nameIdent.IsExported()) {
   167  			valueName := nameIdent.Name
   168  			*defs = append(*defs, Definition{
   169  				Name: valueName,
   170  				Pkg: Package{
   171  					ID:   pkg.ID,
   172  					Name: pkg.Name,
   173  				},
   174  				Loc: file.Loc{
   175  					Path:   position.Filename,
   176  					Line:   position.Line,
   177  					Column: position.Column,
   178  				},
   179  			})
   180  		}
   181  	}
   182  }
   183  
   184  func loadDefsFromTypeSpec(pkg *packages.Package, opts Options, typeSpec *ast.TypeSpec, defs *[]Definition) {
   185  	if typeSpec.Name == nil || (!opts.IncludePrivate && !typeSpec.Name.IsExported()) {
   186  		return
   187  	}
   188  
   189  	position := pkg.Fset.Position(typeSpec.Pos())
   190  	typeName := typeSpec.Name.Name
   191  
   192  	switch x := typeSpec.Type.(type) {
   193  	case *ast.StructType:
   194  		loadDefsFromStructType(pkg, opts, typeName, x, defs)
   195  	case *ast.InterfaceType:
   196  		loadDefsFromInterfaceType(pkg, opts, typeName, x, defs)
   197  	}
   198  
   199  	*defs = append(*defs, Definition{
   200  		Name: typeName,
   201  		Pkg: Package{
   202  			ID:   pkg.ID,
   203  			Name: pkg.Name,
   204  		},
   205  
   206  		Loc: file.Loc{
   207  			Path:   position.Filename,
   208  			Line:   position.Line,
   209  			Column: position.Column,
   210  		},
   211  	})
   212  }
   213  
   214  func loadDefsFromStructType(pkg *packages.Package, opts Options, typeName string, structType *ast.StructType, defs *[]Definition) {
   215  	if !opts.IncludeStructFields {
   216  		return
   217  	}
   218  
   219  	for _, field := range structType.Fields.List {
   220  		position := pkg.Fset.Position(field.Pos())
   221  		for _, nameIdent := range field.Names {
   222  			if nameIdent != nil && (opts.IncludePrivate || nameIdent.IsExported()) {
   223  				fieldName := nameIdent.Name
   224  				*defs = append(*defs, Definition{
   225  					Name: fmt.Sprintf("%s.%s", typeName, fieldName),
   226  					Pkg: Package{
   227  						ID:   pkg.ID,
   228  						Name: pkg.Name,
   229  					},
   230  					Loc: file.Loc{
   231  						Path:   position.Filename,
   232  						Line:   position.Line,
   233  						Column: position.Column,
   234  					},
   235  				})
   236  			}
   237  		}
   238  	}
   239  }
   240  
   241  func loadDefsFromInterfaceType(pkg *packages.Package, opts Options, typeName string, interfaceType *ast.InterfaceType, defs *[]Definition) {
   242  	if !opts.IncludeInterfaceMethods {
   243  		return
   244  	}
   245  
   246  	for _, method := range interfaceType.Methods.List {
   247  		position := pkg.Fset.Position(method.Pos())
   248  		for _, nameIdent := range method.Names {
   249  			if nameIdent != nil && (opts.IncludePrivate || nameIdent.IsExported()) {
   250  				methodName := nameIdent.Name
   251  				*defs = append(*defs, Definition{
   252  					Name: fmt.Sprintf("%s.%s", typeName, methodName),
   253  					Pkg: Package{
   254  						ID:   pkg.ID,
   255  						Name: pkg.Name,
   256  					},
   257  					Loc: file.Loc{
   258  						Path:   position.Filename,
   259  						Line:   position.Line,
   260  						Column: position.Column,
   261  					},
   262  				})
   263  			}
   264  		}
   265  	}
   266  }
   267  
   268  func loadDefsFromFuncDecl(pkg *packages.Package, opts Options, funcDecl *ast.FuncDecl, defs *[]Definition) {
   269  	if funcDecl.Name == nil || (!opts.IncludePrivate && !funcDecl.Name.IsExported()) {
   270  		return
   271  	}
   272  	position := pkg.Fset.Position(funcDecl.Pos())
   273  	name := funcDecl.Name.Name
   274  	if funcDecl.Recv != nil {
   275  		name = fmt.Sprintf("%s.%s", findFuncRecvName(funcDecl), name)
   276  	}
   277  	*defs = append(*defs, Definition{
   278  		Name: name,
   279  		Pkg: Package{
   280  			ID:   pkg.ID,
   281  			Name: pkg.Name,
   282  		},
   283  		Loc: file.Loc{
   284  			Path:   position.Filename,
   285  			Line:   position.Line,
   286  			Column: position.Column,
   287  		},
   288  	})
   289  }
   290  
   291  func findFuncRecvName(funcDecl *ast.FuncDecl) string {
   292  	var typeName string
   293  	for _, field := range funcDecl.Recv.List {
   294  		ast.Inspect(field.Type, func(node ast.Node) bool {
   295  			if ident, ok := node.(*ast.Ident); ok {
   296  				typeName = ident.Name
   297  				return false
   298  			}
   299  			return true
   300  		})
   301  	}
   302  	return typeName
   303  }