github.com/samlitowitz/goimportcycle@v1.0.9/internal/ast/primitive_builder.go (about)

     1  package ast
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/token"
     8  	"path/filepath"
     9  
    10  	"github.com/samlitowitz/goimportcycle/internal"
    11  )
    12  
    13  type PrimitiveBuilder struct {
    14  	modulePath    string
    15  	moduleRootDir string
    16  
    17  	packagesByUID map[string]*internal.Package
    18  	filesByUID    map[string]*internal.File
    19  
    20  	curPkg  *internal.Package
    21  	curFile *internal.File
    22  }
    23  
    24  func NewPrimitiveBuilder(modulePath, moduleRootDir string) *PrimitiveBuilder {
    25  	return &PrimitiveBuilder{
    26  		modulePath:    modulePath,
    27  		moduleRootDir: moduleRootDir,
    28  
    29  		packagesByUID: make(map[string]*internal.Package),
    30  		filesByUID:    make(map[string]*internal.File),
    31  	}
    32  }
    33  
    34  func (builder *PrimitiveBuilder) MarkupImportCycles() error {
    35  	stk := NewFileStack()
    36  	for _, baseFile := range builder.filesByUID {
    37  		stk.Push(baseFile)
    38  		err := builder.markupImportCycles(baseFile, stk)
    39  		if err != nil {
    40  			return err
    41  		}
    42  		stk.Pop()
    43  	}
    44  	return nil
    45  }
    46  
    47  func (builder *PrimitiveBuilder) markupImportCycles(
    48  	baseFile *internal.File,
    49  	stk *FileStack,
    50  ) error {
    51  	for _, imp := range baseFile.Imports {
    52  		for _, typ := range imp.ReferencedTypes {
    53  			refFile := typ.File
    54  			if stk.Contains(refFile) {
    55  				for i := stk.Len() - 1; i > 0; i-- {
    56  					curFile := stk.At(i)
    57  					curFile.InImportCycle = true
    58  					curFile.Package.InImportCycle = true
    59  					imp.InImportCycle = true
    60  					imp.ReferencedFilesInCycle[refFile.UID()] = refFile
    61  					if curFile.UID() == baseFile.UID() {
    62  						return nil
    63  					}
    64  				}
    65  			}
    66  			stk.Push(refFile)
    67  			err := builder.markupImportCycles(refFile, stk)
    68  			if err != nil {
    69  				return err
    70  			}
    71  			stk.Pop()
    72  		}
    73  	}
    74  	for _, refFile := range baseFile.ReferencedFiles() {
    75  		// If you eventually import yourself, it's a cycle
    76  		if stk.Contains(refFile) {
    77  			for i := stk.Len() - 1; i > 0; i-- {
    78  				curFile := stk.At(i)
    79  				curFile.InImportCycle = true
    80  				curFile.Package.InImportCycle = true
    81  				if curFile.UID() == baseFile.UID() {
    82  					return nil
    83  				}
    84  			}
    85  		}
    86  		stk.Push(refFile)
    87  		err := builder.markupImportCycles(refFile, stk)
    88  		if err != nil {
    89  			return err
    90  		}
    91  		stk.Pop()
    92  	}
    93  	return nil
    94  }
    95  
    96  func (builder *PrimitiveBuilder) AddNode(node ast.Node) error {
    97  	switch node := node.(type) {
    98  	case *Package:
    99  		return builder.addPackage(node)
   100  
   101  	case *File:
   102  		return builder.addFile(node)
   103  
   104  	case *ImportSpec:
   105  		return builder.addImport(node)
   106  
   107  	case *FuncDecl:
   108  		return builder.addFuncDecl(node)
   109  
   110  	case *ast.GenDecl:
   111  		return builder.addGenDecl(node)
   112  
   113  	case *SelectorExpr:
   114  		return builder.addSelectorExpr(node)
   115  	}
   116  
   117  	return nil
   118  }
   119  
   120  func (builder *PrimitiveBuilder) Files() []*internal.File {
   121  	files := make([]*internal.File, 0, len(builder.filesByUID))
   122  	for _, file := range builder.filesByUID {
   123  		files = append(files, file)
   124  	}
   125  	return files
   126  }
   127  
   128  func (builder *PrimitiveBuilder) Packages() []*internal.Package {
   129  	pkgs := make([]*internal.Package, 0, len(builder.packagesByUID))
   130  	for _, pkg := range builder.packagesByUID {
   131  		pkgs = append(pkgs, pkg)
   132  	}
   133  	return pkgs
   134  }
   135  
   136  func (builder *PrimitiveBuilder) addPackage(node *Package) error {
   137  	newPkg := buildPackage(
   138  		builder.modulePath,
   139  		builder.moduleRootDir,
   140  		node.DirName,
   141  		node.Name,
   142  		len(node.Files),
   143  	)
   144  	newPkgUID := newPkg.UID()
   145  
   146  	pkg, pkgExists := builder.packagesByUID[newPkgUID]
   147  
   148  	if pkgExists && !pkg.IsStub {
   149  		return fmt.Errorf("add package: duplicate package: %s", newPkg.ImportPath())
   150  	}
   151  
   152  	// replace stub with the real thing
   153  	if pkgExists && pkg.IsStub {
   154  		// Does this work or do we need a full traversal?
   155  		copyPackage(pkg, newPkg)
   156  		for _, file := range pkg.Files {
   157  			if !file.IsStub {
   158  				continue
   159  			}
   160  			if len(file.Decls) > 0 {
   161  				continue
   162  			}
   163  			// remove stub files with no declarations
   164  			delete(pkg.Files, file.UID())
   165  			delete(builder.filesByUID, file.UID())
   166  		}
   167  	}
   168  
   169  	// totally new package
   170  	if !pkgExists {
   171  		builder.packagesByUID[newPkgUID] = newPkg
   172  	}
   173  
   174  	builder.curPkg = builder.packagesByUID[newPkgUID]
   175  	return nil
   176  }
   177  
   178  func (builder *PrimitiveBuilder) addFile(node *File) error {
   179  	if builder.curPkg == nil {
   180  		return fmt.Errorf("add file: no package defined: %s", node.AbsPath)
   181  	}
   182  	file := &internal.File{
   183  		Package:  builder.packagesByUID[builder.curPkg.UID()],
   184  		FileName: filepath.Base(node.AbsPath),
   185  		AbsPath:  node.AbsPath,
   186  		Imports:  make(map[string]*internal.Import),
   187  		Decls:    make(map[string]*internal.Decl),
   188  	}
   189  	fileUID := file.UID()
   190  	if _, ok := builder.filesByUID[fileUID]; ok {
   191  		return fmt.Errorf("add file: duplicate file: %s", node.AbsPath)
   192  	}
   193  
   194  	pkgUID := builder.curPkg.UID()
   195  	builder.filesByUID[fileUID] = file
   196  	builder.packagesByUID[pkgUID].Files[fileUID] = builder.filesByUID[fileUID]
   197  	builder.curFile = builder.filesByUID[fileUID]
   198  
   199  	return nil
   200  }
   201  
   202  func (builder *PrimitiveBuilder) addImport(node *ImportSpec) error {
   203  	if builder.curPkg == nil {
   204  		return fmt.Errorf("add import: no package defined: %s \"%s\"", node.Name.String(), node.Path.Value)
   205  	}
   206  	if builder.curFile == nil {
   207  		return fmt.Errorf("add import: no file defined: %s \"%s\"", node.Name.String(), node.Path.Value)
   208  	}
   209  	imp := &internal.Import{
   210  		Name:                   node.Name.String(),
   211  		Path:                   node.Path.Value,
   212  		ReferencedTypes:        make(map[string]*internal.Decl),
   213  		ReferencedFilesInCycle: make(map[string]*internal.File),
   214  	}
   215  	if node.IsAliased {
   216  		imp.Name = node.Alias
   217  	}
   218  	impUID := imp.UID()
   219  	if _, ok := builder.curFile.Imports[impUID]; ok {
   220  		return fmt.Errorf("add import: duplicate import: %s \"%s\"", node.Name.String(), node.Path.Value)
   221  	}
   222  
   223  	// if the package exists, use it, otherwise use a stub
   224  	pkg := buildPackage(builder.modulePath, builder.moduleRootDir, imp.Path, imp.Name, 1)
   225  	pkg.IsStub = true
   226  	if _, ok := builder.packagesByUID[pkg.UID()]; ok {
   227  		pkg = builder.packagesByUID[pkg.UID()]
   228  	} else {
   229  		fileStub := buildStubFile(pkg)
   230  		pkg.Files[fileStub.UID()] = fileStub
   231  		builder.filesByUID[fileStub.UID()] = fileStub
   232  		builder.packagesByUID[pkg.UID()] = pkg
   233  	}
   234  
   235  	imp.Package = pkg
   236  
   237  	builder.curFile.Imports[impUID] = imp
   238  	return nil
   239  }
   240  
   241  func (builder *PrimitiveBuilder) addFuncDecl(node *FuncDecl) error {
   242  	if builder.curPkg == nil {
   243  		return fmt.Errorf("add func decl: no package defined: %s", node.QualifiedName)
   244  	}
   245  	if builder.curFile == nil {
   246  		return fmt.Errorf("add func decl: no file defined: %s", node.QualifiedName)
   247  	}
   248  	if node.Name.String() == "" {
   249  		return fmt.Errorf("add func decl: invalid function name")
   250  	}
   251  	declUID := node.QualifiedName
   252  	if _, ok := builder.curFile.Decls[declUID]; ok {
   253  		return fmt.Errorf("add func decl: duplicate declaration: %s", node.QualifiedName)
   254  	}
   255  	// TODO: receiver methods should never be received and should be skipped
   256  	var receiverDecl *internal.Decl
   257  	for _, file := range builder.curPkg.Files {
   258  		if _, ok := file.Decls[node.ReceiverName]; ok {
   259  			receiverDecl = file.Decls[node.ReceiverName]
   260  			break
   261  		}
   262  	}
   263  	decl := &internal.Decl{
   264  		File:         builder.curFile,
   265  		ReceiverDecl: receiverDecl,
   266  		Name:         node.Name.String(),
   267  	}
   268  	decl = builder.fixupStubDecl(decl)
   269  	builder.curFile.Decls[declUID] = decl
   270  
   271  	return nil
   272  }
   273  
   274  func (builder *PrimitiveBuilder) addGenDecl(node *ast.GenDecl) error {
   275  	if builder.curPkg == nil {
   276  		return errors.New("add gen decl: no package defined")
   277  	}
   278  	if builder.curFile == nil {
   279  		return errors.New("add gen decl: no file defined")
   280  	}
   281  	for _, spec := range node.Specs {
   282  		switch spec := spec.(type) {
   283  		case *ast.TypeSpec:
   284  			if node.Tok != token.TYPE {
   285  				return errors.New("add gen decl: invalid declaration")
   286  			}
   287  			if _, ok := builder.curFile.Decls[spec.Name.String()]; ok {
   288  				return errors.New("add gen decl: duplicate declaration")
   289  			}
   290  			if spec.Name.String() == "" {
   291  				return errors.New("add gen decl: invalid name")
   292  			}
   293  			decl := &internal.Decl{
   294  				File:         builder.curFile,
   295  				ReceiverDecl: nil,
   296  				Name:         spec.Name.String(),
   297  			}
   298  			decl = builder.fixupStubDecl(decl)
   299  			builder.curFile.Decls[decl.UID()] = decl
   300  
   301  		case *ast.ValueSpec:
   302  			if node.Tok != token.CONST && node.Tok != token.VAR {
   303  				return errors.New("add gen decl: invalid declaration")
   304  			}
   305  			for _, name := range spec.Names {
   306  				if _, ok := builder.curFile.Decls[name.String()]; ok {
   307  					return errors.New("add gen decl: duplicate declaration")
   308  				}
   309  				if name.String() == "" {
   310  					return errors.New("add gen decl: invalid constant or variable name")
   311  				}
   312  				decl := &internal.Decl{
   313  					File:         builder.curFile,
   314  					ReceiverDecl: nil,
   315  					Name:         name.String(),
   316  				}
   317  				decl = builder.fixupStubDecl(decl)
   318  				builder.curFile.Decls[decl.UID()] = decl
   319  			}
   320  
   321  		default:
   322  			return errors.New("add gen decl: unhandled spec type")
   323  		}
   324  	}
   325  	return nil
   326  }
   327  
   328  func (builder *PrimitiveBuilder) addSelectorExpr(node *SelectorExpr) error {
   329  	if builder.curPkg == nil {
   330  		return fmt.Errorf("add selector expr: no package defined: %s", node.Sel.String())
   331  	}
   332  	if builder.curFile == nil {
   333  		return fmt.Errorf("add selector expr: no file defined: %s", node.Sel.String())
   334  	}
   335  	imp, hasImp := builder.curFile.Imports[node.ImportName]
   336  
   337  	if !hasImp {
   338  		return fmt.Errorf("add selector expr: no import defined: %s", node.Sel.String())
   339  	}
   340  
   341  	decl := &internal.Decl{
   342  		Name: node.Sel.String(),
   343  	}
   344  
   345  	if _, ok := imp.ReferencedTypes[decl.Name]; ok {
   346  		// type already registered
   347  		return nil
   348  	}
   349  
   350  	if imp.Package == nil {
   351  		return fmt.Errorf("add selector expr: no package defined here: %s", node.Sel.String())
   352  	}
   353  
   354  	// attempt to find file where declaration is defined
   355  	var foundDecl bool
   356  	var stubFile *internal.File
   357  	for _, file := range imp.Package.Files {
   358  		// track stub file,
   359  		if file.IsStub {
   360  			stubFile = file
   361  			continue
   362  		}
   363  		if !file.HasDecl(decl) {
   364  			continue
   365  		}
   366  		decl.File = file
   367  		foundDecl = true
   368  		break
   369  	}
   370  
   371  	// if not file is found, attempt to add to a stub file
   372  	if !foundDecl {
   373  		if stubFile == nil {
   374  			return fmt.Errorf("add selector expr: no stub file defined: %s", node.Sel.String())
   375  		}
   376  		stubDecl, isDeclInStub := stubFile.Decls[decl.UID()]
   377  		if isDeclInStub {
   378  			decl = stubDecl
   379  		}
   380  		if !isDeclInStub {
   381  			decl.File = stubFile
   382  			// add declaration to stub file
   383  			stubFile.Decls[decl.UID()] = decl
   384  		}
   385  	}
   386  
   387  	if decl.File == nil {
   388  		return fmt.Errorf("add selector expr: missing type declaration: %s", node.Sel.String())
   389  	}
   390  
   391  	imp.ReferencedTypes[decl.Name] = decl
   392  	return nil
   393  }
   394  
   395  func (builder *PrimitiveBuilder) fixupStubDecl(newDecl *internal.Decl) *internal.Decl {
   396  	for fileUID, file := range builder.curPkg.Files {
   397  		// can only fix-up declarations in stub files
   398  		if !file.IsStub {
   399  			continue
   400  		}
   401  		for stubDeclUID, stubDecl := range file.Decls {
   402  			// can only fix-up the same declaration
   403  			if newDecl.UID() != stubDecl.UID() {
   404  				continue
   405  			}
   406  			// everything is already pointing at the stub declaration
   407  			// update stub declaration with values from the new declaration
   408  			copyDeclaration(stubDecl, newDecl)
   409  
   410  			// remove declaration from stub file
   411  			delete(file.Decls, stubDeclUID)
   412  
   413  			// if there are no more declarations in the stub file remove it
   414  			if len(file.Decls) == 0 {
   415  				delete(builder.curPkg.Files, fileUID)
   416  				delete(builder.filesByUID, fileUID)
   417  			}
   418  			return stubDecl
   419  		}
   420  	}
   421  	return newDecl
   422  }
   423  
   424  func buildPackage(
   425  	modulePath,
   426  	moduleRootDir,
   427  	dirName, name string,
   428  	fileCount int,
   429  ) *internal.Package {
   430  	pkg := &internal.Package{
   431  		DirName:    dirName,
   432  		ModulePath: modulePath,
   433  		ModuleRoot: moduleRootDir,
   434  		Name:       name,
   435  		Files:      make(map[string]*internal.File, fileCount),
   436  	}
   437  	return pkg
   438  }
   439  
   440  func buildStubFile(pkg *internal.Package) *internal.File {
   441  	return &internal.File{
   442  		Package:  pkg,
   443  		FileName: "stub.go",
   444  		AbsPath: fmt.Sprintf(
   445  			"STUB://%s/stub.go",
   446  			pkg.UID(),
   447  		),
   448  		Imports: make(map[string]*internal.Import),
   449  		Decls:   make(map[string]*internal.Decl),
   450  		IsStub:  true,
   451  	}
   452  }
   453  
   454  func copyPackage(to, from *internal.Package) {
   455  	to.DirName = from.DirName
   456  	to.ModuleRoot = from.ModuleRoot
   457  	to.Name = from.Name
   458  	if from.Files != nil {
   459  		for uid, file := range from.Files {
   460  			to.Files[uid] = file
   461  		}
   462  	}
   463  	to.IsStub = from.IsStub
   464  	to.InImportCycle = from.InImportCycle
   465  }
   466  
   467  func copyDeclaration(to, from *internal.Decl) {
   468  	to.File = from.File
   469  	to.ReceiverDecl = from.ReceiverDecl
   470  	to.Name = from.Name
   471  }
   472  
   473  type FileStack struct {
   474  	indexByUID map[string]int
   475  	stack      []*internal.File
   476  }
   477  
   478  func NewFileStack() *FileStack {
   479  	return &FileStack{
   480  		indexByUID: make(map[string]int),
   481  		stack:      make([]*internal.File, 0),
   482  	}
   483  }
   484  
   485  func (s *FileStack) Push(f *internal.File) {
   486  	s.stack = append(s.stack, f)
   487  	s.indexByUID[f.UID()] = len(s.stack) - 1
   488  }
   489  
   490  func (s *FileStack) Pop() {
   491  	delete(s.indexByUID, s.stack[len(s.stack)-1].UID())
   492  	s.stack = s.stack[0 : len(s.stack)-1]
   493  }
   494  
   495  func (s *FileStack) Top() *internal.File {
   496  	if len(s.stack) == 0 {
   497  		return nil
   498  	}
   499  	return s.stack[len(s.stack)-1]
   500  }
   501  
   502  func (s *FileStack) At(i int) *internal.File {
   503  	if len(s.stack) == 0 {
   504  		return nil
   505  	}
   506  	if i > (len(s.stack) - 1) {
   507  		return nil
   508  	}
   509  	return s.stack[i]
   510  }
   511  
   512  func (s *FileStack) Contains(f *internal.File) bool {
   513  	_, ok := s.indexByUID[f.UID()]
   514  	return ok
   515  }
   516  
   517  func (s *FileStack) Len() int {
   518  	return len(s.stack)
   519  }