github.com/go-toolsmith/pkgload@v1.2.3-0.20240512101226-e704f55998f1/pkgload.go (about)

     1  // Package pkgload is a set of utilities for `go/packages` load-related operations.
     2  package pkgload
     3  
     4  import (
     5  	"fmt"
     6  	"sort"
     7  	"strings"
     8  
     9  	"golang.org/x/tools/go/packages"
    10  )
    11  
// Unit is a set of packages that form a logical group:
// a package together with its test variants.
//
// Units produced by VisitUnits always have at least 1 non-nil field.
type Unit struct {
	// Base is a standard (normal) package.
	//
	// Note: it can be nil for packages that only have external
	// tests, for example.
	Base *packages.Package

	// Test is a package compiled for test.
	// VisitUnits recognizes it by an ID that contains ".test]".
	// Can be nil.
	Test *packages.Package

	// ExternalTest is a "_test" compiled package.
	// VisitUnits recognizes it by a package name that ends with "_test".
	// Can be nil.
	ExternalTest *packages.Package

	// TestBinary is a test binary.
	// VisitUnits recognizes it by the "main" package name combined
	// with an ID that ends with ".test".
	// Expected to be non-nil if Test or ExternalTest are present.
	TestBinary *packages.Package
}
    33  
    34  // NonNil returns first non-nil field (package) of the unit.
    35  //
    36  //  1. If Base is not nil, return Base.
    37  //  2. If Test is not nil, return Test.
    38  //  3. If ExternalTest is not nil, return ExternalTest.
    39  //  4. Otherwise return TestBinary.
    40  //
    41  // If all unit fields are nil, method panics.
    42  // This should never happen for properly-loaded units.
    43  func (u *Unit) NonNil() *packages.Package {
    44  	switch {
    45  	case u.Base != nil:
    46  		return u.Base
    47  	case u.Test != nil:
    48  		return u.Test
    49  	case u.ExternalTest != nil:
    50  		return u.ExternalTest
    51  	case u.TestBinary != nil:
    52  		return u.TestBinary
    53  	default:
    54  		panic("all Unit fields are nil")
    55  	}
    56  }
    57  
    58  // LoadPackages with a given config and patterns.
    59  func LoadPackages(cfg *packages.Config, patterns []string) ([]*packages.Package, error) {
    60  	pkgs, err := packages.Load(cfg, patterns...)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	result := pkgs[:0]
    66  	VisitUnits(pkgs, func(u *Unit) {
    67  		if u.ExternalTest != nil {
    68  			result = append(result, u.ExternalTest)
    69  		}
    70  
    71  		switch {
    72  		// Prefer tests to the base package, if present.
    73  		case u.Test != nil:
    74  			result = append(result, u.Test)
    75  		case u.Base != nil:
    76  			result = append(result, u.Base)
    77  		}
    78  	})
    79  
    80  	sort.SliceStable(result, func(i, j int) bool {
    81  		return result[i].PkgPath < result[j].PkgPath
    82  	})
    83  	return result, nil
    84  }
    85  
    86  // Deduplicate returns a copy of pkgs slice where all duplicated
    87  // package entries are removed.
    88  //
    89  // Packages are considered equal if all conditions below are satisfied:
    90  //   - Same ID
    91  //   - Same Name
    92  //   - Same PkgPath
    93  //   - Equal GoFiles
    94  func Deduplicate(pkgs []*packages.Package) []*packages.Package {
    95  	type pkgKey struct {
    96  		id    string
    97  		name  string
    98  		path  string
    99  		files string
   100  	}
   101  
   102  	pkgSet := make(map[pkgKey]*packages.Package)
   103  	for _, pkg := range pkgs {
   104  		sort.Strings(pkg.GoFiles)
   105  		key := pkgKey{
   106  			id:    pkg.ID,
   107  			name:  pkg.Name,
   108  			path:  pkg.PkgPath,
   109  			files: strings.Join(pkg.GoFiles, ";"),
   110  		}
   111  		pkgSet[key] = pkg
   112  	}
   113  
   114  	list := make([]*packages.Package, 0, len(pkgSet))
   115  	for _, pkg := range pkgSet {
   116  		list = append(list, pkg)
   117  	}
   118  	return list
   119  }
   120  
   121  // VisitUnits traverses potentially unsorted pkgs list as a set of units.
   122  // All related packages from the slice are passed into visit func as a single unit.
   123  // Units are visited in a sorted order (import path).
   124  //
   125  // All packages in a slice must be non-nil.
   126  func VisitUnits(pkgs []*packages.Package, visit func(*Unit)) {
   127  	pkgs = Deduplicate(pkgs)
   128  	units := make(map[string]*Unit)
   129  
   130  	internUnit := func(key string) *Unit {
   131  		u, ok := units[key]
   132  		if !ok {
   133  			u = &Unit{}
   134  			units[key] = u
   135  		}
   136  		return u
   137  	}
   138  
   139  	// Sanity check.
   140  	// Panic should never trigger if this library is correct.
   141  	mustBeNil := func(pkg *packages.Package) {
   142  		if pkg != nil {
   143  			panic(fmt.Sprintf("nil assertion failed for ID=%q Path=%q",
   144  				pkg.ID, pkg.PkgPath))
   145  		}
   146  	}
   147  
   148  	withoutSuffix := func(s, suffix string) string {
   149  		return s[:len(s)-len(suffix)]
   150  	}
   151  
   152  	for _, pkg := range pkgs {
   153  		switch {
   154  		case strings.HasSuffix(pkg.Name, "_test"):
   155  			key := withoutSuffix(pkg.PkgPath, "_test")
   156  			u := internUnit(key)
   157  			mustBeNil(u.ExternalTest)
   158  			u.ExternalTest = pkg
   159  		case strings.Contains(pkg.ID, ".test]"):
   160  			u := internUnit(pkg.PkgPath)
   161  			mustBeNil(u.Test)
   162  			u.Test = pkg
   163  		case pkg.Name == "main" && strings.HasSuffix(pkg.ID, ".test"):
   164  			key := withoutSuffix(pkg.PkgPath, ".text")
   165  			u := internUnit(key)
   166  			mustBeNil(u.TestBinary)
   167  			u.TestBinary = pkg
   168  		case pkg.Name == "":
   169  			// Empty package. Skip.
   170  		default:
   171  			u := internUnit(pkg.PkgPath)
   172  			mustBeNil(u.Base)
   173  			u.Base = pkg
   174  		}
   175  	}
   176  
   177  	unitList := make([]*Unit, 0, len(units))
   178  	for _, u := range units {
   179  		unitList = append(unitList, u)
   180  	}
   181  	sort.SliceStable(unitList, func(i, j int) bool {
   182  		return unitList[i].NonNil().PkgPath < unitList[j].NonNil().PkgPath
   183  	})
   184  	for _, u := range unitList {
   185  		visit(u)
   186  	}
   187  }